Skip to content
Snippets Groups Projects
Commit 22a6af40 authored by Raymond Chia's avatar Raymond Chia
Browse files

median fs

parent 56a0b294
Branches
No related merge requests found
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
No preview for this file type
......@@ -396,7 +396,7 @@ def df_win_task(w_inds, df, i, cols):
t0, t1 = time[w_inds][0], time[w_inds][-1]
diff = time[w_inds[1:]] - time[w_inds[0:-1]]
fs_est = 1/np.mean(diff)
fs_est = 1/np.median(diff)
if fs_est > 70 and 'acc_x' in cols: fs = IMU_FS
elif fs_est < 70 and 'bvp' in cols: fs = PPG_FS
......@@ -964,6 +964,85 @@ class EvalHandler():
def save_eval_history(self):
self.eval_hist.to_csv(self.eval_history_file, index=False)
# save evaluation metrics in single file that handles the models for the
# subject and config
class DSPEvalHandler():
"""
Handles the evaluation metric for each subject and DSP sensor.
...
Attributes
----------
y_true : numpy.ndarray
a numpy array of the respiration rate ground truth values from the
bioharness
y_pred : numpy.ndarray
a numpy array of the predicted respiration rate
subject : str
the subject in format Pilot01, S01 etc.
pfh : ProjectFileHandler
custom class detailing the directories, metafile, and configurations
sens_str : str
a string to inform what sensor was used
overwrite : bool
overwrites the evaluations (default False)
Methods
-------
load_eval_history()
loads the evaluation file
save_eval_history()
saves the evaluation file
update_eval_history()
updates the evaluation file using the new entry if there is no matching
model or configuration for the given subject
"""
def __init__(self, y_true, y_pred, subject, pfh, sens_str, overwrite=False):
self.subject = subject
self.config = pfh.config
self.parent_directory = join(DATA_DIR, 'subject_specific')
self.fset_id = pfh.fset_id
self.sens_str = sens_str
self.overwrite = overwrite
self.evals = Evaluation(y_true, y_pred)
entry = {'subject': self.subject,
'config_id': self.fset_id,
'sens_str': self.sens_str,
}
self.entry = {**entry, **self.config, **self.evals.get_evals()}
self.eval_history_file = join(self.parent_directory,
'dsp_eval_history.csv')
self.eval_hist = self.load_eval_history()
def load_eval_history(self):
if not exists(self.eval_history_file):
return None
else:
return pd.read_csv(self.eval_history_file)
def update_eval_history(self):
eval_hist = self.eval_hist
if eval_hist is None:
eval_hist = pd.DataFrame([self.entry])
else:
index_list = eval_hist[
(eval_hist['subject'] == self.entry['subject']) &\
(eval_hist['config_id'] == self.entry['config_id']) &\
(eval_hist['sens_str'] == self.entry['sens_str'])
].index.tolist()
if len(index_list) == 0:
print("adding new entry")
eval_hist = eval_hist._append(self.entry, ignore_index=True)
elif index_list is not None and self.overwrite:
eval_hist.loc[index_list[0]] = self.entry
self.eval_hist = eval_hist
def save_eval_history(self):
self.eval_hist.to_csv(self.eval_history_file, index=False)
def imu_rr_dsp(subject,
window_size=12,
window_shift=0.2,
......@@ -1049,35 +1128,39 @@ def imu_rr_dsp(subject,
window_size=window_size,
window_shift=window_shift,
fs=fs,
cols=['gyr_x', 'gyr_y', 'gyr_z'])
cols=['gyro_x', 'gyro_y', 'gyro_z'])
acc_evals = Evaluation(acc_y_dsp_df[lbl_str], acc_dsp_df['pred'])
gyr_evals = Evaluation(gyr_y_dsp_df[lbl_str], gyr_dsp_df['pred'])
print("acc evals: \n", acc_evals.get_evals())
print("gyr evals: \n", gyr_evals.get_evals())
plt.subplot(211)
plt.plot(acc_y_dsp_df[lbl_str]); plt.plot(acc_dsp_df['pred'])
plt.subplot(212)
plt.plot(gyr_y_dsp_df[lbl_str]); plt.plot(gyr_dsp_df['pred'])
plt.show()
# TODO
eval_handle = DSPEvalHandler(y_test.flatten(), preds.flatten(), subject,
pfh, None, overwrite=overwrite)
eval_handle.update_eval_history()
eval_handle.save_eval_history()
acc_eval = DSPEvalHandler(acc_evals.y_true.values.flatten(),
acc_evals.y_pred.values.flatten(),
subject,
pfh, 'acc', overwrite=overwrite)
acc_eval.update_eval_history()
acc_eval.save_eval_history()
gyr_eval = DSPEvalHandler(gyr_evals.y_true.values.flatten(),
gyr_evals.y_pred.values.flatten(),
subject,
pfh, 'gyro', overwrite=overwrite)
gyr_eval.update_eval_history()
gyr_eval.save_eval_history()
pp = PrettyPrinter()
pp.pprint(eval_handle.load_eval_history())
fig, ax = plt.subplots()
fig_title = '_'.join([mdl_str, subject]+[combi_str])
ax.plot(y_test)
ax.plot(preds)
ax.set_title(fig_title)
ax.legend([lbl_str, 'pred'])
pp.pprint(acc_eval.load_eval_history())
fig, ax = plt.subplots(2, 1)
ax[0].plot(acc_y_dsp_df[lbl_str]); plt.plot(acc_dsp_df['pred'])
ax[0].set_title("ACC")
ax[1].plot(gyr_y_dsp_df[lbl_str]); plt.plot(gyr_dsp_df['pred'])
ax[1].set_title("GYRO")
ax[1].legend([lbl_str, 'estimate'])
fig_dir = join(project_dir, 'figures')
if not exists(fig_dir): mkdir(fig_dir)
fig_title = '_'.join([subject, 'dsp'])
fig.savefig(join(fig_dir, fig_title+".png"))
plt.close()
......@@ -1340,7 +1423,7 @@ def arg_parser():
)
parser.add_argument("-s", '--subject', type=int,
default=2,
choices=list(range(1,4))+[-1],
choices=list(range(1,5))+[-1],
)
parser.add_argument("-f", '--feature_method', type=str,
default='minirocket',
......@@ -1371,12 +1454,18 @@ def arg_parser():
help='1 or 0 input, choose if standing data will be '\
'recorded or not'
)
parser.add_argument('--method', type=str,
default='ml',
help="choose between 'ml' or 'dsp' methods for"\
" regression",
choices=['ml', 'dsp']
)
args = parser.parse_args()
return args
if __name__ == '__main__':
np.random.seed(100)
n_subject_max = 3
n_subject_max = 4
args = arg_parser()
# Load command line arguments
......@@ -1390,13 +1479,14 @@ if __name__ == '__main__':
overwrite = args.overwrite
data_input = args.data_input
test_standing = args.test_standing
method = args.method
print(args)
assert train_len>0,"--train_len must be an integer greater than 0"
subject_pre_string = 'Pilot'
if subject > 0:
if subject > 0 and method == 'ml':
subject = subject_pre_string+str(subject).zfill(2)
sens_rr_model(subject,
......@@ -1410,7 +1500,36 @@ if __name__ == '__main__':
test_standing=test_standing,
data_input=data_input,
)
else:
elif subject <= 0 and method == 'ml':
subjects = [subject_pre_string+str(i).zfill(2) for i in \
range(2, n_subject_max+1)]
rr_func = partial(sens_rr_model,
window_size=window_size,
window_shift=window_shift,
lbl_str=lbl_str,
mdl_str=mdl_str,
overwrite=overwrite,
feature_method=feature_method,
train_len=train_len,
test_standing=test_standing,
data_input=data_input,
)
for subject in subjects:
rr_func(subject)
if subject > 0 and method == 'dsp':
subject = subject_pre_string+str(subject).zfill(2)
imu_rr_dsp(subject,
window_size=window_size,
window_shift=window_shift,
lbl_str=lbl_str,
overwrite=overwrite,
train_len=train_len,
test_standing=test_standing)
elif subject <= 0 and method == 'ml':
subjects = [subject_pre_string+str(i).zfill(2) for i in \
range(2, n_subject_max+1)]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment