From dd9ea695d418eaa2b9d5c7c6c857e10fb8c5c2e8 Mon Sep 17 00:00:00 2001 From: Raymond Chia <rqchia@mercury9.ihpc.uts.edu.au> Date: Thu, 23 Nov 2023 19:23:23 +1100 Subject: [PATCH] imu+bvp implemented --- regress_rr.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/regress_rr.py b/regress_rr.py index f600080..2939f6d 100644 --- a/regress_rr.py +++ b/regress_rr.py @@ -39,7 +39,7 @@ from modules.digitalsignalprocessing import imu_signal_processing from modules.digitalsignalprocessing import bvp_signal_processing from modules.digitalsignalprocessing import hernandez_sp, reject_artefact from modules.digitalsignalprocessing import do_pad_fft,\ - pressure_signal_processing, infer_frequency + pressure_signal_processing, infer_frequency, movingaverage from modules.utils import * from modules.evaluations import Evaluation @@ -762,7 +762,8 @@ class EvalHandler(): (eval_hist['subject'] == self.entry['subject']) &\ (eval_hist['config_id'] == self.entry['config_id']) &\ (eval_hist['mdl_str'] == self.entry['mdl_str']) &\ - (eval_hist['cpm'] == self.entry['cpm'])\ + (eval_hist['cpm'] == self.entry['cpm']) &\ + (eval_hist['sens_list'] == self.entry['sens_list'])\ ].index.tolist() if len(index_list) == 0: print("adding new entry") @@ -1053,12 +1054,24 @@ def sens_rr_model(subject, pp = PrettyPrinter() pp.pprint(eval_handle.load_eval_history()) - fig, ax = plt.subplots() - fig_title = '_'.join([mdl_str, subject]+[combi_str]) - ax.plot(y_test) - ax.plot(preds) - ax.set_title(fig_title) - ax.legend([lbl_str, 'pred']) + fig, ax = plt.subplots(2, 1, figsize=(7.3, 4.5)) + fig_title = '_'.join([mdl_str, data_input, subject]+[combi_str]) + fig.suptitle(fig_title) + ax[0].plot(y_test) + ax[0].plot(preds) + ax[0].set_title('raw') + + if lbl_str == 'pss': + br = y_test_df['br'].values + ax[1].plot(movingaverage(y_test, 12), color='tab:blue') + ax[1].plot(br, 'k') + ax[1].plot(movingaverage(preds, 12), color='tab:orange') + ax[1].legend([lbl_str, 'br', 'pred']) + else: + ax[1].plot(y_test, 'k') + ax[1].plot(movingaverage(preds, 12), color='tab:orange') + ax[1].legend([lbl_str, 'pred']) + ax[1].set_title('smoothened') fig_dir = join(project_dir, 'figures') if not exists(fig_dir): mkdir(fig_dir) fig.savefig(join(fig_dir, fig_title+".png")) @@ -1074,7 +1087,7 @@ def arg_parser(): ) parser.add_argument("-s", '--subject', type=int, default=2, - choices=list(range(1,3))+[-1], + choices=list(range(1,4))+[-1], ) parser.add_argument("-f", '--feature_method', type=str, default='minirocket', -- GitLab