import glob
from os import makedirs, mkdir
from os.path import join, exists
import pandas as pd
import numpy as np
import json
import re
import pickle
import sys
import time
from zipfile import ZipFile
import argparse
from datetime import datetime, timedelta, timezone
import pytz
import matplotlib.pyplot as plt
from functools import partial
from collections import Counter
from itertools import repeat, chain, combinations
from multiprocessing import Pool, cpu_count

import tensorflow as tf
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from sklearn.preprocessing import PolynomialFeatures, LabelEncoder
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score
from tsfresh.feature_extraction import extract_features
from tsfresh.feature_extraction import settings as tsfresh_settings
from tsfresh.utilities.string_manipulation import get_config_from_string

from modules.datapipeline import get_file_list, load_and_snip, load_data, \
    load_split_data, load_harness_data
from modules.digitalsignalprocessing import vectorized_slide_win as vsw
from modules.digitalsignalprocessing import imu_signal_processing
from modules.digitalsignalprocessing import do_pad_fft, \
    pressure_signal_processing, infer_frequency
from modules.utils import *
from modules.evaluations import Evaluation
from modules.datapipeline import get_windowed_data, DataSynchronizer, \
    parallelize_dataframe
from modules.datapipeline import ProjectFileHandler
from models.ardregression import ARDRegressionClass
from models.knn import KNNClass
from models.svm import SVMClass
from models.lda import LDAClass
from models.svr import SVRClass
from models.logisticregression import LogisticRegressionClass
from models.linearregression import LinearRegressionClass
from models.neuralnet import FNN_HyperModel, LSTM_HyperModel, TunerClass, \
    CNN1D_HyperModel
from models.ridgeclass import RidgeClass
from models.resnet import Regressor_RESNET, Classifier_RESNET
from models.xgboostclass import XGBoostClass

from pprint import PrettyPrinter

from sktime.transformations.panel.rocket import (
    MiniRocket,
    MiniRocketMultivariate,
    MiniRocketMultivariateVariable,
)

from config import WINDOW_SIZE, WINDOW_SHIFT, IMU_FS, DATA_DIR, BR_FS

IMU_COLS = ['acc_x', 'acc_y', 'acc_z', 'gyr_x', 'gyr_y', 'gyr_z']


def utc_to_local(utc_dt, tz=None):
    return utc_dt.replace(tzinfo=timezone.utc).astimezone(tz=tz)


def datetime_from_utc_to_local(utc_datetime):
    now_timestamp = time.time()
    offset = datetime.fromtimestamp(now_timestamp) \
        - datetime.utcfromtimestamp(now_timestamp)
    return utc_datetime + offset


# Load data
def load_bioharness_file(f: str, skiprows=0, skipfooter=0, **kwargs):
    method = partial(pd.read_csv, skipinitialspace=True,
                     skiprows=list(range(1, skiprows+1)),
                     skipfooter=skipfooter,
                     header=0, **kwargs)
    df = method(f)
    if 'Time' not in df.columns.values:
        # build a timestamp from the Day/Month/Year columns plus millisecond
        # offset, then normalise to the bioharness string format
        df['Time'] = pd.to_datetime(
            df.rename(columns={'Date': 'Day'})[['Day', 'Month', 'Year']]) \
            + pd.to_timedelta(df['ms'], unit='ms')
        if pd.isna(df['Time']).any():
            df['Time'].interpolate(inplace=True)
        df['Time'] = pd.to_datetime(df['Time'],
                                    format="%d/%m/%Y %H:%M:%S.%f")
        df['Time'] = df['Time'].dt.strftime("%d/%m/%Y %H:%M:%S.%f")
    return df


def load_bioharness_files(f_list: list, skiprows=0, skipfooter=0, **kwargs):
    df_list = []
    for f in f_list:
        df_list.append(load_bioharness_file(f, skiprows=skiprows,
                                            skipfooter=skipfooter,
                                            **kwargs))
    df = pd.concat(df_list, ignore_index=True)
    return df


def bioharness_datetime_to_seconds(val):
    fmt = "%d/%m/%Y %H:%M:%S.%f"
    dstr = datetime.strptime(val, fmt)
    seconds = dstr.timestamp()
    return seconds


def load_imu_file(imu_file: str):
    hdr_file = imu_file.replace('imudata.gz', 'recording.g3')
    df = pd.read_json(imu_file, lines=True, compression='gzip')
    hdr = pd.read_json(hdr_file, orient='index')
    hdr = hdr.to_dict().pop(0)
    if df.empty:
        return df, hdr

    data_df = pd.DataFrame(df['data'].tolist())
    df = pd.concat([df.drop('data', axis=1), data_df], axis=1)

    iso_tz = hdr['created']
    tzinfo = pytz.timezone(hdr['timezone'])
    # strip the trailing 'Z' so fromisoformat accepts the string, then
    # adjust for UTC
    start_time = datetime.fromisoformat(iso_tz[:-1])
    start_time = utc_to_local(start_time, tz=tzinfo).astimezone(tzinfo)

    na_inds = df.loc[pd.isna(df['accelerometer']), :].index.values
    df.drop(index=na_inds, inplace=True)

    imu_times = df['timestamp'].values
    df['timestamp_interp'] = imu_times
    df['timestamp_interp'] = df['timestamp_interp'].interpolate()
    imu_times = df['timestamp_interp'].values
    imu_datetimes = [start_time + timedelta(seconds=val)
                     for val in imu_times]
    imu_s = np.array([dt.timestamp() for dt in imu_datetimes])
    df['sec'] = imu_s

    # drop samples recorded more than three hours after the start
    time_check_thold = df['sec'].min() + 3*3600
    mask = df['sec'] > time_check_thold
    if np.any(mask):
        df = df[np.logical_not(mask)]
    return df, hdr


def load_imu_files(f_list: list):
    data, hdr = [], []
    tmp = []
    for f in f_list:
        tmp.append(load_imu_file(f))
    for l in tmp:
        data.append(l[0])
        hdr.append(l[1])
    data_df = pd.concat(data, axis=0)
    return data_df, hdr


def load_e4_file(e4_file: str):
    ''' First row is the initial time of the session as unix time.
    Second row is the sample rate in Hz. '''
    zip_file = ZipFile(e4_file)
    dfs = {csv_file.filename: pd.read_csv(zip_file.open(csv_file.filename),
                                          header=None)
           for csv_file in zip_file.infolist()
           if csv_file.filename.endswith('.csv')}
    bvp = dfs["BVP.csv"]
    t0 = bvp.iloc[0].values[0]
    fs = bvp.iloc[1].values[0]
    nsamples = len(bvp) - 2

    t0_datetime = datetime.utcfromtimestamp(t0)
    t0_local = datetime_from_utc_to_local(t0_datetime)
    # per-sample timestamps; NaN placeholders keep the two header rows aligned
    sec = [t0_local.timestamp() + ind*(1/fs) for ind in range(nsamples)]
    sec = [np.nan, np.nan] + sec

    bvp.rename(columns={0: "bvp"}, inplace=True)
    bvp['sec'] = np.array(sec)

    bvp.drop(inplace=True, index=[0, 1])
    # session start (unix time) and sample rate from the first two rows
    hdr = {'start_time': t0,
           'fs': fs}
    return bvp, hdr


def load_e4_files(f_list: list):
    tmp = []
    data = []
    hdr = []
    for f in f_list:
        tmp.append(load_e4_file(f))
    for d, h in tmp:
        data.append(d)
        hdr.append(h)
    data_df = pd.concat(data, axis=0)
    return data_df, hdr


# Synchronising data
def sync_to_ref(df0, df1):
    dsync0 = DataSynchronizer()
    dsync1 = DataSynchronizer()

    time0 = df0['sec'].values
    time1 = df1['sec'].values

    t0 = max((time0[0], time1[0]))
    t1 = min((time0[-1], time1[-1]))

    dsync0.set_bounds(time0, t0, t1)
    dsync1.set_bounds(time1, t0, t1)

    return dsync0.sync_df(df0), dsync1.sync_df(df1)
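
# A minimal usage sketch (hypothetical paths): load an IMU recording and an
# E4 zip for one session, then trim both to their overlapping time range
# with sync_to_ref.
#
#   imu_df, imu_hdr = load_imu_file('data/Pilot01/imudata.gz')
#   bvp_df, bvp_hdr = load_e4_file('data/Pilot01/e4.zip')
#   imu_df, bvp_df = sync_to_ref(imu_df, bvp_df)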
def pss_br_calculations(win, pss_df=None, br_df=None):
    n_out = 5
    if win[-1] == 0:
        return [None]*n_out

    dsync = DataSynchronizer()
    pss_fs = BR_FS
    pss_col = [col for col in pss_df.columns.values
               if 'breathing' in col.lower()][0]
    pss_ms = pss_df['ms'].values
    br_ms = br_df['ms'].values
    t0, t1 = pss_ms[win][0], pss_ms[win][-1]

    # reject windows containing gaps longer than 60 s
    diff = pss_ms[win][1:] - pss_ms[win][:-1]
    mask = np.abs(diff/1e3) > 60
    if np.any(mask):
        return [None]*n_out

    # Get pressure estimate for window
    pss_win = pss_df.iloc[win]
    pss_data = pss_win[pss_col]
    pss_filt = pressure_signal_processing(pss_data, fs=pss_fs)
    xf, yf = do_pad_fft(pss_filt, fs=pss_fs)
    pss_est = xf[yf.argmax()]*60

    # Sync and get summary br output
    dsync.set_bounds(br_ms, t0, t1)
    br_win = dsync.sync_df(br_df)
    br_out = np.median(br_win['BR'].values)

    # Get subject and condition; assumes pss_df carries a 'condition'
    # column, matching the columns assembled in get_pss_br_estimates
    sbj_out = pss_win['subject'].values[0]
    cond_out = pss_win['condition'].values[0]
    time_out = np.median(pss_win['sec'].values)

    return time_out, pss_est, br_out, sbj_out, cond_out


def get_pss_br_estimates(pss_df, br_df, window_size=12, window_shift=1):
    pss_fs = BR_FS
    pss_ms = pss_df['sec'].values
    br_ms = br_df['sec'].values

    inds = np.arange(0, len(pss_ms))
    vsw_out = vsw(inds, len(inds),
                  sub_window_size=int(window_size*pss_fs),
                  stride_size=int(window_shift*pss_fs))

    func = partial(pss_br_calculations, pss_df=pss_df, br_df=br_df)
    with Pool(cpu_count()) as p:
        tmp = p.map(func, vsw_out)

    time_out, pss_est, br_out, sbj_out, cond_out = zip(*tmp)
    time_array = np.array(time_out)
    pss_est_array = np.array(pss_est)
    br_out_array = np.array(br_out)
    sbj_out_array = np.array(sbj_out)
    cond_out_array = np.array(cond_out)

    df = pd.DataFrame(
        np.array([time_array, sbj_out_array, cond_out_array,
                  pss_est_array, br_out_array]).T,
        columns=['ms', 'subject', 'condition', 'pss', 'br'])
    df.dropna(inplace=True)
    return df


# Multiprocessing task for windowing dataframe
def imu_df_win_task(w_inds, df, i, cols):
    time = df['sec'].values
    if w_inds[-1] == 0:
        return
    w_df = df.iloc[w_inds]
    t0, t1 = time[w_inds][0], time[w_inds][-1]

    # reject windows containing time gaps larger than 20 s
    diff = time[w_inds[1:]] - time[w_inds[0:-1]]
    mask = np.abs(diff) > 20
    if np.any(mask):
        return

    if cols is None:
        cols = ['acc_x', 'acc_y', 'acc_z',
                'gyr_x', 'gyr_y', 'gyr_z']

    data = w_df[cols].values

    # DSP: standardise each channel, then filter
    sd_data = (data - np.mean(data, axis=0))/np.std(data, axis=0)
    filt_data = imu_signal_processing(sd_data, IMU_FS)
    x_out = pd.DataFrame(filt_data,
                         columns=['acc_x', 'acc_y', 'acc_z',
                                  'gyro_x', 'gyro_y', 'gyro_z'])

    sm_out = w_df['BR'].values
    ps_out = w_df['PSS'].values

    x_vec_time = np.median(time[w_inds])
    fs = 1/np.mean(diff)
    ps_freq = int(get_max_frequency(ps_out, fs=fs))

    y_tmp = np.array([x_vec_time, np.nanmedian(sm_out), ps_freq])
    x_out['sec'] = x_vec_time
    x_out['id'] = i
    y_out = pd.DataFrame([y_tmp], columns=['sec', 'br', 'pss'])

    return x_out, y_out


def get_max_frequency(data, fs=IMU_FS):
    data = pressure_signal_processing(data, fs=fs)
    xf, yf = do_pad_fft(data, fs=fs)
    max_freq = xf[yf.argmax()]*60
    return max_freq


def convert_to_float(df):
    # in-place dtype coercion for the numeric columns
    cols = df.columns.values
    if 'sec' in cols:
        df['sec'] = df['sec'].astype(float)
    if 'pss' in cols:
        df['pss'] = df['pss'].astype(float)
    if 'br' in cols:
        df['br'] = df['br'].astype(float)
    if 'subject' in cols:
        df['subject'] = df['subject'].astype(float)
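
# Windowing sketch: get_df_windows (imported via modules.utils) is expected
# to slide imu_df_win_task over a synced dataframe; the call below mirrors
# its use in imu_rr_model and load_tsfresh.
#
#   x_df, y_df = get_df_windows(xsens_df, imu_df_win_task,
#                               window_size=WINDOW_SIZE,
#                               window_shift=WINDOW_SHIFT,
#                               fs=IMU_FS)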
def load_and_sync_xsens(subject):
    # load imu
    imu_list = get_file_list('imudata.gz', sbj=subject)
    imu_df_all, imu_hdr_df_all = load_imu_files(imu_list)

    # load bioharness
    pss_list = get_file_list('*Breathing.csv', sbj=subject)
    if len(pss_list) == 0:
        pss_list = get_file_list('BR*.csv', sbj=subject)
    br_list = get_file_list('*Summary*', sbj=subject)

    # load e4 wristband
    e4_list = get_file_list('*.zip', sbj=subject)
    bvp_df_all, bvp_hdr = load_e4_files(e4_list)
    bvp_fs = bvp_hdr[0]['fs']

    xsens_list = []
    # skip the first and last half minute of each recording
    minutes_to_skip = .5
    br_skiprows = br_skipfooter = int(minutes_to_skip*60)
    pss_skiprows = pss_skipfooter = int(minutes_to_skip*60*BR_FS)

    # load each bioharness file and sync the imu to it
    for pss_file, br_file in zip(pss_list, br_list):
        pss_df = load_bioharness_file(pss_file, skiprows=pss_skiprows,
                                      skipfooter=pss_skipfooter,
                                      engine='python')
        pss_df['sec'] = pss_df['Time']\
            .map(bioharness_datetime_to_seconds).values

        br_df = load_bioharness_file(br_file, skiprows=br_skiprows,
                                     skipfooter=br_skipfooter,
                                     engine='python')
        br_df['sec'] = br_df['Time']\
            .map(bioharness_datetime_to_seconds).values

        # sync
        br_df, imu_df = sync_to_ref(br_df, imu_df_all.copy())
        pss_df, _ = sync_to_ref(pss_df, imu_df_all.copy())
        bvp_df, _ = sync_to_ref(bvp_df_all.copy(), pss_df.copy())

        # extract relevant data and resample everything onto the imu clock
        acc_data = np.stack(imu_df['accelerometer'].values)
        gyr_data = np.stack(imu_df['gyroscope'].values)
        x_time = imu_df['sec'].values.reshape(-1, 1)

        br_col = [col for col in pss_df.columns.values
                  if 'breathing' in col.lower()][0]
        pss_data = pss_df[br_col].values
        pss_data = np.interp(x_time, pss_df['sec'].values, pss_data)\
            .reshape(-1, 1)

        br_data = br_df['BR'].values
        br_data = np.interp(x_time, br_df['sec'].values, br_data)\
            .reshape(-1, 1)

        bvp_data = bvp_df['bvp'].values
        bvp_data = np.interp(x_time, bvp_df['sec'].values, bvp_data)\
            .reshape(-1, 1)

        xsens_data = np.concatenate(
            (x_time, br_data, pss_data, bvp_data, acc_data, gyr_data),
            axis=1)
        columns = ['sec', 'BR', 'PSS', 'BVP',
                   'acc_x', 'acc_y', 'acc_z',
                   'gyr_x', 'gyr_y', 'gyr_z']
        xsens_df_tmp = pd.DataFrame(xsens_data, columns=columns)
        xsens_list.append(xsens_df_tmp)

    if len(xsens_list) > 1:
        xsens_df = pd.concat(xsens_list, axis=0, ignore_index=True)
        xsens_df.reset_index(drop=True, inplace=True)
    else:
        xsens_df = xsens_list[0]

    return xsens_df


def load_tsfresh(subject, project_dir,
                 window_size=12, window_shift=0.2, fs=IMU_FS,
                 overwrite=False):
    pkl_file = join(project_dir, 'tsfresh.pkl')
    if exists(pkl_file) and not overwrite:
        return pd.read_pickle(pkl_file)

    xsens_df = load_and_sync_xsens(subject)
    x_df, y_df = get_df_windows(xsens_df,
                                imu_df_win_task,
                                window_size=window_size,
                                window_shift=window_shift,
                                fs=fs)
    x_features_df = extract_features(
        x_df, column_sort='sec',
        column_id='id')
    x_features_df.fillna(0, inplace=True)

    df_out = pd.concat([y_df, x_features_df], axis=1)
    df_out.to_pickle(pkl_file)
    return df_out


def get_activity_log(subject):
    activity_list = get_file_list('activity*.csv', sbj=subject)
    activity_dfs = [pd.read_csv(f) for f in activity_list]
    return pd.concat(activity_dfs, axis=0)


def get_respiration_log(subject):
    log_list = get_file_list('*.json', sbj=subject)
    log_dfs = [pd.read_json(f) for f in log_list]
    return pd.concat(log_dfs, axis=0)
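
# Data-assembly sketch (hypothetical subject id); these steps mirror the
# start of imu_rr_model below.
#
#   xsens_df = load_and_sync_xsens('Pilot01')
#   activity_df = get_activity_log('Pilot01')
#   event_df = get_respiration_log('Pilot01')
#   cal_df = get_cal_data(event_df, xsens_df)
#   test_df = get_test_data(cal_df, activity_df, xsens_df)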
def get_cal_data(event_df, xsens_df):
    cal_list = []
    start_sec = 0
    stop_sec = 0
    for index, row in event_df.iterrows():
        event = row['eventTag']
        timestamp = row['timestamp']
        inhalePeriod = row['inhalePeriod']
        exhalePeriod = row['exhalePeriod']

        # guided breathing rate in cycles per minute
        cpm = np.round(60/(inhalePeriod + exhalePeriod))
        sec = timestamp.to_pydatetime().timestamp()

        if event == 'Start':
            start_sec = sec
            continue
        elif event == 'Stop':
            stop_sec = sec

            dsync = DataSynchronizer()
            dsync.set_bounds(xsens_df['sec'].values, start_sec, stop_sec)

            sync_df = dsync.sync_df(xsens_df.copy())
            cal_data = {'cpm': cpm, 'data': sync_df}
            cal_list.append(cal_data)

            assert np.round(sync_df.sec.iloc[0]) == np.round(start_sec), \
                "error with start sync"
            assert np.round(sync_df.sec.iloc[-1]) == np.round(stop_sec), \
                "error with stop sync"

    return pd.DataFrame(cal_list)


def get_test_data(cal_df, activity_df, xsens_df):
    fmt = "%d/%m/%Y %H:%M:%S"
    start_time = cal_df.iloc[-1]['data'].sec.values[-1]
    data_df = xsens_df[xsens_df.sec > start_time]
    activity_start = 0
    activity_stop = 0

    activity_list = []
    for index, row in activity_df.iterrows():
        # convert the logged timestamp to unix seconds to match data_df['sec']
        sec = datetime.strptime(row['Timestamps'], fmt).timestamp()
        if row['Event'] == 'start':
            activity_start = sec
        elif row['Event'] == 'stop':
            activity_stop = sec

            dsync = DataSynchronizer()
            dsync.set_bounds(data_df['sec'].values,
                             activity_start, activity_stop)

            sync_df = dsync.sync_df(data_df.copy())
            activity_data = {'activity': row['Activity'], 'data': sync_df}
            activity_list.append(activity_data)

    return pd.DataFrame(activity_list)


# save evaluation metrics in a single file that handles the models for the
# subject and config
class EvalHandler():
    def __init__(self, y_true, y_pred, subject, pfh, mdl_str,
                 overwrite=False):
        self.subject = subject
        self.config = pfh.config
        self.parent_directory = join(DATA_DIR, 'subject_specific')
        self.fset_id = pfh.fset_id
        self.mdl_str = mdl_str
        self.overwrite = overwrite

        self.evals = Evaluation(y_true, y_pred)

        entry = {'subject': self.subject,
                 'config_id': self.fset_id,
                 'mdl_str': self.mdl_str,
                 }
        self.entry = {**entry, **self.config, **self.evals.get_evals()}

        self.eval_history_file = join(self.parent_directory,
                                      'eval_history.csv')
        self.eval_hist = self.load_eval_history()

    def load_eval_history(self):
        if not exists(self.eval_history_file):
            return None
        else:
            return pd.read_csv(self.eval_history_file)

    def update_eval_history(self):
        eval_hist = self.eval_hist
        if eval_hist is None:
            eval_hist = pd.DataFrame([self.entry])
        else:
            # one row per (subject, config_id, mdl_str)
            index_list = eval_hist[
                (eval_hist['subject'] == self.entry['subject']) &
                (eval_hist['config_id'] == self.entry['config_id']) &
                (eval_hist['mdl_str'] == self.entry['mdl_str'])
            ].index.tolist()
            if len(index_list) == 0:
                print("adding new entry")
                eval_hist = pd.concat(
                    [eval_hist, pd.DataFrame([self.entry])],
                    ignore_index=True)
            elif self.overwrite:
                eval_hist.loc[index_list[0]] = self.entry
        self.eval_hist = eval_hist

    def save_eval_history(self):
        self.eval_hist.to_csv(self.eval_history_file, index=False)
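
# EvalHandler usage sketch: evaluate predictions, then append (or, with
# overwrite=True, replace) the matching row in eval_history.csv.
#
#   eh = EvalHandler(y_test.flatten(), preds.flatten(), 'Pilot01',
#                    pfh, 'knn', overwrite=True)
#   eh.update_eval_history()
#   eh.save_eval_history()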
# Train IMU - RR models across subjects
def imu_rr_model(subject,
                 window_size=12,
                 window_shift=0.2,
                 lbl_str='pss',
                 mdl_str='knn',
                 overwrite=False,
                 feature_method='tsfresh',
                 train_len: int = 3,
                 test_standing=False,
                 ):
    cal_str = 'cpm'
    fs = IMU_FS
    imu_cols = ['acc_x', 'acc_y', 'acc_z',
                'gyro_x', 'gyro_y', 'gyro_z']

    do_minirocket = False
    use_tsfresh = False
    overwrite_tsfresh = True

    if feature_method == 'tsfresh':
        use_tsfresh = True
    elif feature_method == 'minirocket':
        do_minirocket = True

    config = {'window_size': window_size,
              'window_shift': window_shift,
              'lbl_str': lbl_str,
              'do_minirocket': do_minirocket,
              'use_tsfresh': use_tsfresh,
              'train_len': train_len,
              }

    pfh = ProjectFileHandler(config)
    pfh.set_home_directory(join(DATA_DIR, 'subject_specific', subject))
    pfh.set_parent_directory('imu_rr')
    id_check = pfh.get_id_from_config()
    if id_check is None:
        pfh.set_project_directory()
        pfh.save_metafile()
    else:
        pfh.set_id(int(id_check))
        pfh.set_project_directory()
        print('Using pre-set data id: ', pfh.fset_id)
    project_dir = pfh.project_directory

    marker = f'imu_rr_{subject}_id{pfh.fset_id}'

    if not use_tsfresh:
        xsens_df = load_and_sync_xsens(subject)
    else:
        xsens_df = load_tsfresh(subject,
                                project_dir,
                                window_size=window_size,
                                window_shift=window_shift,
                                fs=IMU_FS,
                                overwrite=overwrite_tsfresh)

    activity_df = get_activity_log(subject)
    event_df = get_respiration_log(subject)

    cal_df = get_cal_data(event_df, xsens_df)

    # include standing or not
    test_df = get_test_data(cal_df, activity_df, xsens_df)

    # train on every combination of `train_len` calibration recordings
    for combi in combinations(cal_df[cal_str].values, train_len):
        config[cal_str] = combi
        train_df = pd.concat(
            [cal_df[cal_df[cal_str] == cpm]['data'] for cpm in combi],
            axis=0)

        assert not np.isin(train_df.index.values,
                           test_df.index.values).any(), \
            "overlapping test and train data"

        print("train")
        print(train_df.shape)
        print("test")
        print(test_df.shape)

        if do_minirocket:
            x_train_df, y_train_df = get_df_windows(
                train_df, imu_df_win_task,
                window_size=window_size,
                window_shift=window_shift,
                fs=fs)
            x_test_df, y_test_df = get_df_windows(
                test_df, imu_df_win_task,
                window_size=window_size,
                window_shift=window_shift,
                fs=fs)

            x_train = make_windows_from_id(x_train_df, imu_cols)
            x_test = make_windows_from_id(x_test_df, imu_cols)
            y_train = y_train_df[lbl_str].values.reshape(-1, 1)
            y_test = y_test_df[lbl_str].values.reshape(-1, 1)

            print("minirocket transforming...")
            x_train = np.swapaxes(x_train, 1, 2)
            x_test = np.swapaxes(x_test, 1, 2)
            minirocket = MiniRocketMultivariate()
            x_train = minirocket.fit_transform(x_train)
            x_test = minirocket.transform(x_test)
        elif use_tsfresh:
            x_train = train_df.iloc[:, 3:].values
            y_train = train_df[lbl_str].values.reshape(-1, 1)
            x_test = test_df.iloc[:, 3:].values
            y_test = test_df[lbl_str].values.reshape(-1, 1)
        else:
            x_train_df, y_train_df = get_df_windows(
                train_df, imu_df_win_task,
                window_size=window_size,
                window_shift=window_shift,
                fs=fs)
            x_test_df, y_test_df = get_df_windows(
                test_df, imu_df_win_task,
                window_size=window_size,
                window_shift=window_shift,
                fs=fs)

            x_train = make_windows_from_id(x_train_df, imu_cols)
            x_test = make_windows_from_id(x_test_df, imu_cols)
            y_train = y_train_df[lbl_str].values.reshape(-1, 1)
            y_test = y_test_df[lbl_str].values.reshape(-1, 1)

        transforms, model = model_training(
            mdl_str, x_train, y_train, marker,
            validation_data=None,
            overwrite=overwrite,
            is_regression=True,
            project_directory=project_dir,
            window_size=int(window_size*fs),
            extra_train=200)

        if transforms is not None:
            x_test = transforms.transform(x_test)
        preds = model.predict(x_test)

        eval_handle = EvalHandler(y_test.flatten(), preds.flatten(),
                                  subject, pfh, mdl_str,
                                  overwrite=overwrite)
        eval_handle.update_eval_history()
        eval_handle.save_eval_history()

        pp = PrettyPrinter()
        pp.pprint(eval_handle.load_eval_history())

        fig, ax = plt.subplots()
        # combi is a tuple of cpm values; cast to str for the title
        fig_title = ' '.join([mdl_str, subject, str(combi)])
        ax.plot(y_test)
        ax.plot(preds)
        ax.set_title(fig_title)
        ax.legend([lbl_str, 'pred'])
        fig_dir = join(project_dir, 'figures')
        if not exists(fig_dir):
            mkdir(fig_dir)
        fig.savefig(join(fig_dir, mdl_str))


def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", '--model', type=str,
                        default='linreg',
                        choices=['linreg', 'ard', 'xgboost', 'knn', 'svr',
                                 'cnn1d', 'fnn', 'lstm', 'ridge', 'elastic'],
                        )
    parser.add_argument("-s", '--subject', type=int,
                        default=2,
                        choices=list(range(1, 3))+[-1],
                        )
parser.add_argument("-f", '--feature_method', type=str, default='minirocket', choices=['tsfresh', 'minirocket', 'None'] ) parser.add_argument("-o", '--overwrite', type=int, default=0, ) parser.add_argument('--win_size', type=int, default=12, ) parser.add_argument('--win_shift', type=float, default=0.2, ) parser.add_argument('-l', '--lbl_str', type=str, default='pss', ) parser.add_argument('-tl', '--train_len', type=int, default=3, help='minutes of data to use for calibration' ) args = parser.parse_args() return args if __name__ == '__main__': # choose either intra or inter subject features to use for model training # '[!M]*' np.random.seed(100) n_subject_max = 2 args = arg_parser() mdl_str = args.model subject = args.subject feature_method = args.feature_method window_size = args.win_size window_shift = args.win_shift lbl_str = args.lbl_str train_len = args.train_len overwrite = args.overwrite print(args) assert train_len>0,"--train_len must be an integer greater than 0" subject_pre_string = 'Pilot' if subject > 0: subject = subject_pre_string+str(subject).zfill(2) imu_rr_model(subject, window_size=window_size, window_shift=window_shift, lbl_str=lbl_str, mdl_str=mdl_str, overwrite=overwrite, feature_method=feature_method, train_len=train_len ) else: subjects = [subject_pre_string+str(i).zfill(2) for i in \ range(1, n_subject_max+1) if i not in imu_issues] imu_rr_func = partial(imu_rr_model, window_size=window_size, window_shift=window_shift, lbl_str=lbl_str, mdl_str=mdl_str, overwrite=overwrite, feature_method=feature_method, train_len=train_len ) if mdl_str in ['fnn', 'lstm', 'cnn1d', 'elastic', 'ard', 'xgboost']: for subject in subjects: imu_rr_func(subject) else: ncpu = min(len(subjects), cpu_count()) with Pool(ncpu) as p: p.map(imu_rr_func, subjects) print(args)