From dd9ea695d418eaa2b9d5c7c6c857e10fb8c5c2e8 Mon Sep 17 00:00:00 2001
From: Raymond Chia <rqchia@mercury9.ihpc.uts.edu.au>
Date: Thu, 23 Nov 2023 19:23:23 +1100
Subject: [PATCH] imu+bvp implemented

---
 regress_rr.py | 31 ++++++++++++++++++++++---------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/regress_rr.py b/regress_rr.py
index f600080..2939f6d 100644
--- a/regress_rr.py
+++ b/regress_rr.py
@@ -39,7 +39,7 @@ from modules.digitalsignalprocessing import imu_signal_processing
 from modules.digitalsignalprocessing import bvp_signal_processing
 from modules.digitalsignalprocessing import hernandez_sp, reject_artefact
 from modules.digitalsignalprocessing import do_pad_fft,\
-        pressure_signal_processing, infer_frequency
+        pressure_signal_processing, infer_frequency, movingaverage
 from modules.utils import *
 
 from modules.evaluations import Evaluation
@@ -762,7 +762,8 @@ class EvalHandler():
                 (eval_hist['subject'] == self.entry['subject']) &\
                 (eval_hist['config_id'] == self.entry['config_id']) &\
                 (eval_hist['mdl_str'] == self.entry['mdl_str']) &\
-                (eval_hist['cpm'] == self.entry['cpm'])\
+                (eval_hist['cpm'] == self.entry['cpm']) &\
+                (eval_hist['sens_list'] == self.entry['sens_list'])\
             ].index.tolist()
             if len(index_list) == 0:
                 print("adding new entry")
@@ -1053,12 +1054,24 @@ def sens_rr_model(subject,
         pp = PrettyPrinter()
         pp.pprint(eval_handle.load_eval_history())
 
-        fig, ax = plt.subplots()
-        fig_title = '_'.join([mdl_str, subject]+[combi_str])
-        ax.plot(y_test)
-        ax.plot(preds)
-        ax.set_title(fig_title)
-        ax.legend([lbl_str, 'pred'])
+        fig, ax = plt.subplots(2, 1, figsize=(7.3, 4.5))
+        fig_title = '_'.join([mdl_str, data_input, subject]+[combi_str])
+        fig.suptitle(fig_title)
+        ax[0].plot(y_test)
+        ax[0].plot(preds)
+        ax[0].set_title('raw')
+
+        if lbl_str == 'pss':
+            br  = y_test_df['br'].values
+            ax[1].plot(movingaverage(y_test, 12), color='tab:blue')
+            ax[1].plot(br, 'k')
+            ax[1].plot(movingaverage(preds, 12), color='tab:orange')
+            ax[1].legend([lbl_str, 'br', 'pred'])
+        else:
+            ax[1].plot(y_test, 'k')
+            ax[1].plot(movingaverage(preds, 12), color='tab:orange')
+            ax[1].legend([lbl_str, 'pred'])
+        ax[1].set_title('smoothened')
         fig_dir = join(project_dir, 'figures')
         if not exists(fig_dir): mkdir(fig_dir)
         fig.savefig(join(fig_dir, fig_title+".png"))
@@ -1074,7 +1087,7 @@ def arg_parser():
                        )
     parser.add_argument("-s", '--subject', type=int,
                         default=2,
-                        choices=list(range(1,3))+[-1],
+                        choices=list(range(1,4))+[-1],
                        )
     parser.add_argument("-f", '--feature_method', type=str,
                         default='minirocket',
-- 
GitLab