# single_subject_cal_plot.py
# Committed by Raymond Chia.
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from regress_rr import load_tsfresh
from os.path import join
from os import makedirs
import ipdb
import joblib

import pandas as pd

from config import *
from regress_rr import get_activity_log
from modules.evaluations import Evaluation
from modules.digitalsignalprocessing import (
    movingaverage, butter_lowpass_filter)
from sklearn.preprocessing import PolynomialFeatures, LabelEncoder

# Close any figures left over from a previous run and use compact font sizes
# for all figure text.
plt.close('all')
plt.rcParams.update({
    'figure.titlesize': 8,
    'axes.titlesize': 7,
    'axes.labelsize': 7,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 6,
    'legend.title_fontsize': 6,
})

# Feature-extraction window parameters. They are embedded in the tsfresh
# feature directory name and the test-file name used below.
# NOTE(review): units are not visible here (presumably seconds) — confirm.
window_size = 20
window_shift = 0.05

# 1. Get the test data for the subject.

# 2. Load the IMU, BVP, and IMU+BVP linear-regression models for the subject.

# 3. Get the linear-regression predictions for the subject.

# 4. Plot the PSS ground truth against the linear-regression predictions.

# 5. Highlight the standing periods.

# Subject and sensor modalities whose models are compared in the plot.
sbj = 'S01'
sens_list = ['imu', 'bvp', 'imu-bvp']
# Model configuration id for each entry of sens_list (same order).
cfg_ids = [5, 11, 4]

# Column of the test set used as the ground-truth label.
lbl_str = 'pss'

mdl_dir = join(DATA_DIR, 'subject_specific', sbj)
mdl_file = 'linr_model.joblib'

tsfresh_dir = join(
    mdl_dir,
    'tsfresh__winsize_{0}__winshift_{1}'.format(window_size, window_shift)
)

# Feature-combination tag for each entry of sens_list. Every model uses the
# same tag, so size the list from sens_list to keep the lists aligned.
combi_strs = ['combi7.0-10.0-12.0-15.0-17.0'] * len(sens_list)

# The plotting loop zips these parallel lists, and zip() silently drops
# trailing entries, so fail loudly if they ever disagree in length.
if len(cfg_ids) != len(sens_list):
    raise ValueError(
        f'cfg_ids has {len(cfg_ids)} entries but sens_list has '
        f'{len(sens_list)}; they must be parallel lists.'
    )

def get_model(sens, sbj, cfg_id, combi):
    """Load the trained linear-regression model for one sensor configuration.

    The model is read with joblib from
    ``DATA_DIR/subject_specific/<sbj>/<sens>_rr/<cfg_id, zero-padded>/``
    ``linreg_<sens>_rr_<sbj>_id<cfg_id>_<combi>/<mdl_file>``.

    Args:
        sens: Sensor modality tag, e.g. 'imu', 'bvp' or 'imu-bvp'.
        sbj: Subject id, e.g. 'S01'.
        cfg_id: Model configuration id. It is zero-padded to at least two
            digits in the directory path.
        combi: Feature-combination tag that appears in the run directory name.

    Returns:
        The object deserialized by ``joblib.load``.
    """
    # Build the root from the ``sbj`` argument rather than the module-level
    # ``mdl_dir`` (fixed to the global subject). That way the directory and
    # the run name always refer to the same subject.
    sbj_dir = join(DATA_DIR, 'subject_specific', sbj)
    run_name = f'linreg_{sens}_rr_{sbj}_id{cfg_id}_{combi}'
    mdl_fname = join(sbj_dir, sens + '_rr', str(cfg_id).zfill(2), run_name,
                     mdl_file)
    # NOTE: joblib.load unpickles the file; only load trusted model files.
    return joblib.load(mdl_fname)

def get_test_data():
    """Load the subject's tsfresh test features and ground-truth label.

    Reads ``test__winsize_<window_size>__winshift_<window_shift>__tsfresh.pkl``
    from ``tsfresh_dir``.

    Returns:
        A ``(x_test, y_test, time)`` tuple:
        x_test: every column except 'sec', 'br', 'pss' and 'cpm'.
        y_test: the ``lbl_str`` column as an (n, 1) array.
        time: the 'sec' column.
    """
    fname = f'test__winsize_{window_size}__winshift_{window_shift}__tsfresh.pkl'
    frame = pd.read_pickle(join(tsfresh_dir, fname))

    non_feature_cols = ('sec', 'br', 'pss', 'cpm')
    feature_cols = [c for c in frame.columns.values if c not in non_feature_cols]

    features = frame[feature_cols]
    labels = frame[lbl_str].values.reshape(-1, 1)
    return features, labels, frame['sec']

def _sensor_columns(sens, columns):
    """Return the feature columns fed to the model for modality *sens*.

    'imu' keeps columns whose name contains 'acc' or 'gyr'. 'bvp' keeps
    columns containing 'bvp'. Any other modality (e.g. 'imu-bvp') keeps
    *columns* unchanged.
    """
    if sens == 'imu':
        return [col for col in columns if 'acc' in col or 'gyr' in col]
    if sens == 'bvp':
        return [col for col in columns if 'bvp' in col]
    return columns


def _shade_standing_periods(ax, standing_log, time, n_samples,
                            fmt="%d/%m/%Y %H:%M:%S"):
    """Shade every standing interval listed in *standing_log* on *ax*.

    Rows of *standing_log* are paired as (start, end): rows at even
    positions open an interval and rows at odd positions close it. An
    unpaired trailing row is ignored. Samples whose *time* value lies inside
    an interval are filled from the current top of the y-axis down to 0.
    """
    y_top = ax.get_ylim()[-1]
    x = np.arange(n_samples)
    starts = standing_log['Timestamps'].iloc[::2]
    ends = standing_log['Timestamps'].iloc[1::2]
    for start_ts, end_ts in zip(starts, ends):
        # NOTE(review): strptime yields a naive datetime, so .timestamp()
        # interprets it in the machine's local timezone. Confirm this matches
        # how the 'sec' column was produced.
        start_sec = datetime.strptime(start_ts, fmt).timestamp()
        end_sec = datetime.strptime(end_ts, fmt).timestamp()
        mask = np.asarray((time >= start_sec) & (time <= end_sec))
        ax.fill_between(x, y_top, where=mask, facecolor='k', alpha=0.2)


def _largest_gap_index(time):
    """Return the index of the sample that follows the largest step in *time*.

    A 0 is prepended to the differences so the result indexes *time* itself.
    The script treats this position as the boundary between recording days.
    """
    gaps = np.insert(np.diff(time), 0, 0)
    return gaps.argmax()


def main():
    """Plot the calibration figure for subject ``sbj`` and print its metrics.

    Overlays the smoothed ground truth and each sensor model's smoothed
    predictions, shades the standing periods, and marks the largest time gap
    with a red line. Saves the figure as ``<sbj>cal.png``.
    """
    activity_log = get_activity_log(sbj)
    standing_log = activity_log[activity_log['Activity'] == 'standing']

    x_test, y_test, time = get_test_data()
    y_test = movingaverage(y_test, 8)

    cm = 1 / 2.54  # matplotlib figsize is given in inches
    fig, ax = plt.subplots(figsize=(14 * cm, 7.5 * cm), dpi=300)
    ax.plot(y_test, label='ground-truth', linewidth=1, c='k')

    metrics_dict = {}
    for sens, cfg_id, combi_str in zip(sens_list, cfg_ids, combi_strs):
        model = get_model(sens, sbj, cfg_id, combi_str)
        x_test_sens = x_test[_sensor_columns(sens, x_test.columns)]
        # Degree-1 PolynomialFeatures only prepends a constant column of ones.
        # Presumably the models were trained with the same transform — confirm.
        x_test_sens = PolynomialFeatures(1).fit_transform(x_test_sens)

        y_pred = movingaverage(model.predict(x_test_sens), 64)
        metrics_dict[sens] = Evaluation(
            y_test.flatten(), y_pred.flatten()).get_evals()
        ax.plot(y_pred, label=sens.upper(), linewidth=1)

    _shade_standing_periods(ax, standing_log, time, len(y_test))
    ax.axvline(_largest_gap_index(time), c='r', alpha=0.6, linewidth=1)

    print(metrics_dict)
    ax.legend(ncols=4, prop={'size': 6})
    ax.set_title("Multi-Modal RR Predictions during Rest Standing and Sitting")
    ax.set_xlabel("Time (indices)")
    ax.set_ylabel("Respiration Rate (CPM)")
    fig.savefig(f'{sbj}cal.png')


if __name__ == '__main__':
    main()