"""Plot ground-truth vs. linear-regression respiration-rate predictions.

For one subject this script loads the tsfresh test features and the fitted
linear-regression models for each sensor set (IMU, BVP, IMU+BVP), plots the
smoothed predictions against the smoothed label signal, shades the periods
logged as 'standing', marks the largest gap in the time column, prints the
evaluation metrics and saves the figure as ``<sbj>cal.png``.
"""
# NOTE: import order is kept as in the original file on purpose:
# ``from config import *`` may shadow names depending on where it appears.
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from regress_rr import load_tsfresh
from os.path import join
from os import makedirs
import ipdb
import joblib
import pandas as pd
from config import *
from regress_rr import get_activity_log
from modules.evaluations import Evaluation
from modules.digitalsignalprocessing import (
    movingaverage, butter_lowpass_filter)
from sklearn.preprocessing import PolynomialFeatures, LabelEncoder

plt.close('all')
plt.rcParams.update({
    'figure.titlesize': 8,
    'axes.titlesize': 7,
    'axes.labelsize': 7,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 6,
    'legend.title_fontsize': 6,
})

window_size = 20
window_shift = 0.05

# Outline of the script:
#   get the test data for the subject
#   load imu, bvp, and imu+bvp data for the subject
#   get the lin reg data for the subject
#   plot the pss vs the linear regression data for the subject
#   highlight standing periods
sbj = 'S01'
sens_list = ['imu', 'bvp', 'imu-bvp']
# One model configuration id per entry of sens_list (previously [4, 0, 0]).
cfg_ids = [5, 11, 4]
lbl_str = 'pss'
mdl_dir = join(DATA_DIR, 'subject_specific', sbj)
mdl_file = 'linr_model.joblib'
tsfresh_dir = join(
    mdl_dir,
    'tsfresh__winsize_{0}__winshift_{1}'.format(window_size, window_shift)
)
combi_strs = ['combi7.0-10.0-12.0-15.0-17.0']*3


def get_model(sens, sbj, cfg_id, combi):
    """Load the fitted linear-regression model for one sensor configuration.

    Args:
        sens: Sensor set name used in the directory layout (e.g. ``'imu'``).
        sbj: Subject identifier; used both for the subject directory and the
            run-directory name.
        cfg_id: Configuration id; zero-padded to two digits in the path.
        combi: Combination string that ends the run-directory name.

    Returns:
        Whatever object ``joblib.load`` yields for the model file.
    """
    run_dir = f'linreg_{sens}_rr_{sbj}_id{cfg_id}_{combi}'
    # Build the subject directory from the ``sbj`` argument instead of the
    # module-level ``mdl_dir`` so a caller passing another subject does not
    # end up with a path that mixes two subjects.
    mdl_fname = join(
        DATA_DIR, 'subject_specific', sbj,
        sens + '_rr', str(cfg_id).zfill(2), run_dir, mdl_file,
    )
    # NOTE: joblib.load unpickles the file; only use it on trusted files.
    return joblib.load(mdl_fname)


def get_test_data():
    """Load the tsfresh test set for the configured window settings.

    Reads the pickle from the module-level ``tsfresh_dir`` and uses the
    module-level ``lbl_str`` as the label column.

    Returns:
        Tuple ``(x_test, y_test, time)``: the feature columns (everything
        except ``'sec'``, ``'br'``, ``'pss'`` and ``'cpm'``), the label
        values reshaped to a single column, and the ``'sec'`` column.
    """
    test_fname = 'test__winsize_{0}__winshift_{1}__tsfresh.pkl'.format(
        window_size, window_shift)
    # NOTE: read_pickle unpickles the file; only use it on trusted files.
    test_df = pd.read_pickle(join(tsfresh_dir, test_fname))

    y_cols = ['sec', 'br', 'pss', 'cpm']
    x_cols = [col for col in test_df.columns.values if col not in y_cols]

    x_test = test_df[x_cols]
    y_test = test_df[lbl_str].values.reshape(-1, 1)
    time = test_df['sec']
    return x_test, y_test, time


def _sensor_columns(sens, columns):
    """Return the feature columns that belong to the sensor set ``sens``."""
    if sens == 'imu':
        return [col for col in columns if 'acc' in col or 'gyr' in col]
    if sens == 'bvp':
        return [col for col in columns if 'bvp' in col]
    # Any other sensor set (here 'imu-bvp') uses every feature column.
    return columns


def _shade_standing_periods(ax, standing_log, time, n_samples):
    """Shade the sample ranges that fall inside logged standing periods.

    Rows of ``standing_log`` are consumed pairwise: even-positioned rows are
    taken as period starts and odd-positioned rows as period ends.
    """
    # Read the upper y-limit once, before shading can trigger autoscaling.
    y_top = ax.get_ylim()[-1]
    fmt = "%d/%m/%Y %H:%M:%S"
    x = np.arange(n_samples)

    starts = standing_log.iloc[::2].iterrows()
    ends = standing_log.iloc[1::2].iterrows()
    for (_, start), (_, end) in zip(starts, ends):
        # NOTE(review): the parsed datetimes are naive, so .timestamp()
        # interprets them in the local timezone -- confirm that this matches
        # how the 'sec' column was produced.
        start_sec = datetime.strptime(start['Timestamps'], fmt).timestamp()
        end_sec = datetime.strptime(end['Timestamps'], fmt).timestamp()
        mask = (time >= start_sec) & (time <= end_sec)
        ax.fill_between(x, y_top, where=mask, facecolor='k', alpha=0.2)


def main():
    """Predict with every sensor model, plot, report metrics, save figure."""
    activity_log = get_activity_log(sbj)
    standing_log = activity_log[activity_log['Activity'] == 'standing']

    x_test, y_test, time = get_test_data()
    x_test_cols = x_test.columns
    metrics_dict = {}

    y_test = movingaverage(y_test, 8)

    cm = 1/2.54  # centimetres to inches
    fig, axs = plt.subplots(figsize=(14*cm, 7.5*cm), dpi=300)
    axs.plot(y_test, label='ground-truth', linewidth=1, c='k')

    # fit_transform refits on every call, so a single instance can be shared.
    poly = PolynomialFeatures(1)
    for sens, cfg_id, combi_str in zip(sens_list, cfg_ids, combi_strs):
        model = get_model(sens, sbj, cfg_id, combi_str)
        cols = _sensor_columns(sens, x_test_cols)

        x_test_sens = poly.fit_transform(x_test[cols])
        y_pred = movingaverage(model.predict(x_test_sens), 64)

        evals = Evaluation(y_test.flatten(), y_pred.flatten())
        metrics_dict[sens] = evals.get_evals()
        axs.plot(y_pred, label=sens.upper(), linewidth=1)

    _shade_standing_periods(axs, standing_log, time, len(y_test))

    # The largest jump between consecutive time values is marked with a red
    # line; the leading 0 keeps the index aligned with the samples.
    new_day_idx = np.insert(np.diff(time), 0, 0).argmax()
    axs.axvline(new_day_idx, c='r', alpha=0.6, linewidth=1)

    print(metrics_dict)

    axs.legend(ncols=4, prop={'size': 6})
    axs.set_title(
        "Multi-Modal RR Predictions during Rest Standing and Sitting")
    axs.set_xlabel("Time (indices)")
    axs.set_ylabel("Respiration Rate (CPM)")
    # File name follows the configured subject ('S01cal.png' for sbj='S01').
    fig.savefig(f'{sbj}cal.png')


if __name__ == '__main__':
    main()