Newer
Older
import argparse
import ipdb
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns
sns.set_theme()
from tqdm import tqdm
import glob
import re
import pandas as pd
import json
import cv2
import tsfresh
from tsfresh import extract_features
from tsfresh.feature_selection import relevance as tsfresh_relevance
from tsfresh.feature_extraction import settings as tsfresh_settings
from datetime import datetime, timedelta
from os import makedirs, listdir
from os.path import isdir, getsize
from os.path import join as path_join
from os.path import exists as path_exists
from pprint import PrettyPrinter
from multiprocessing import Pool, cpu_count
from functools import partial
from ast import literal_eval
from modules.animationplotter import AnimationPlotter, AnimationPlotter2D
from modules.digitalsignalprocessing import imu_signal_processing
from modules.digitalsignalprocessing import vectorized_slide_win
from modules.digitalsignalprocessing import get_video_features
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
,sec_to_datetime, DataSynchronizer
from modules.datapipeline import get_file_list, load_files_conditions
from modules.evaluations import Evaluation
from modules.utils import map_condition_to_tlx
# from skimage.util import view_as_blocks
from config import DEBUG, NROWS, MARKER_FS, BR_FS, N_MARKERS , WINDOW_SIZE \
,WINDOW_SHIFT, ACC_THOLD, FS_RESAMPLE, MIN_RESP_RATE, MAX_RESP_RATE \
,TRAIN_VAL_TEST_SPLIT, IMU_FS, DATA_DIR
# Relative directory presumably meant for evaluation output; it is not
# referenced anywhere else in the visible part of this file.
DIR = "./evaluations/"
def load_df(subject, condition):
    """Load the Xsens IMU data of one subject under one condition.

    Files are located with ``get_file_list`` using the glob
    ``*{condition}_xsens_df`` and loaded with ``load_files_conditions``
    (``skip_ratio=0.0``, single process).

    Args:
        subject: Subject identifier, forwarded to ``get_file_list`` as ``sbj``.
        condition: Condition tag embedded in the file names.

    Returns:
        Whatever ``load_files_conditions`` returns for the matched files
        (presumably a DataFrame -- confirm), or ``None`` when nothing matches.
    """
    data_glob = f"*{condition}_xsens_df"
    data_list = get_file_list(data_glob, sbj=subject)
    # Use len() rather than truthiness in case get_file_list returns a
    # NumPy array (whose truth value is ambiguous for >1 element).
    if len(data_list) == 0:
        return None
    return load_files_conditions(data_list, skip_ratio=0.0,
                                 do_multiprocess=False)
def get_subject_tsfresh(fs, data_df, df_path, tlx_df,
                        window_size=WINDOW_SIZE,
                        window_shift=WINDOW_SHIFT):
    """Window one recording, extract tsfresh features and pickle the result.

    The IMU channels of ``data_df`` are cut into sliding windows of
    ``int(window_size*fs)`` samples with a stride of ``int(window_shift*fs)``
    samples (seconds, if ``fs`` is in Hz). Windows whose ``sec`` column has a
    jump larger than 20 between consecutive samples are skipped. Each kept
    window is z-scored per channel, passed through ``imu_signal_processing``
    and labelled with the median BR/PSS of that window. Then
    ``extract_features`` computes one feature row per window.

    Args:
        fs: Sampling rate used to convert window size/shift into samples.
        data_df: Recording with columns ``sec``, ``acc_x/y/z``,
            ``gyr_x/y/z``, ``BR``, ``PSS``, ``condition`` and ``subject``.
        df_path: Path the output frame is pickled to.
        tlx_df: Table such that ``tlx_df['S<subject>']`` selects a subject;
            read from ``seated_nasa_tlx.csv`` (transposed) when ``None``.
        window_size: Window length, in units of ``1/fs``.
        window_shift: Window stride, in units of ``1/fs``.

    Returns:
        DataFrame of per-window labels (``sec``, ``br``, ``pss``,
        ``condition``, ``subject`` and whatever ``map_condition_to_tlx``
        adds) followed by the tsfresh feature columns, one row per kept
        window. Returns ``None`` when no window survives the gap check.
    """
    if tlx_df is None:
        tlx_df = pd.read_csv('seated_nasa_tlx.csv', index_col=[0]).T

    x_time = data_df['sec'].values
    inds = np.arange(len(data_df))
    wins = vectorized_slide_win(inds, len(inds),
                                sub_window_size=int(window_size*fs),
                                stride_size=int(window_shift*fs))

    x, y, ids = [], [], []
    for i, w_inds in enumerate(wins):
        # NOTE(review): presumably an all-zero tail marks padding windows
        # produced by vectorized_slide_win -- confirm.
        if w_inds[-1] == 0:
            break
        win_time = x_time[w_inds]
        # Skip windows that straddle a gap in the recording.
        if np.any(np.diff(win_time) > 20):
            continue

        x_df = data_df.iloc[w_inds]
        data = x_df[['acc_x', 'acc_y', 'acc_z',
                     'gyr_x', 'gyr_y', 'gyr_z']].values
        # DSP: per-channel z-score, then the project's IMU filter chain.
        # NOTE(review): a constant channel has std 0 and yields NaN/inf here.
        sd_data = (data - np.mean(data, axis=0))/np.std(data, axis=0)
        filt_data = imu_signal_processing(sd_data, fs)
        df_sp = pd.DataFrame(filt_data,
                             columns=['acc_x', 'acc_y', 'acc_z',
                                      'gyro_x', 'gyro_y', 'gyro_z'])

        # Labels are taken from this window's rows, not the whole recording.
        # A window with no BR/PSS sample gets NaN (numpy emits a warning).
        sm_out = x_df['BR'].values
        ps_out = x_df['PSS'].values
        cd_out = x_df['condition'].values
        sbj = x_df['subject'].values
        x_vec_time = np.median(win_time)
        # NOTE(review): np.array over these mixed values coerces them to one
        # dtype; if condition/subject are strings, sec/br/pss become strings
        # too -- confirm that is what downstream code expects.
        y_arr = np.array([x_vec_time, np.nanmedian(sm_out),
                          np.nanmedian(ps_out), cd_out[0], sbj[0]])
        y_out = pd.DataFrame([y_arr], columns=['sec', 'br', 'pss',
                                               'condition', 'subject'])
        # Return value is ignored, so this presumably fills y_out in place.
        map_condition_to_tlx(y_out, tlx_df['S'+str(sbj[0])])

        df_sp['id'] = i
        # tsfresh sorts each series by `column_sort`. Give every sample its
        # own increasing key. A single per-window time value would leave the
        # order undefined under an unstable sort.
        df_sp['time'] = np.arange(len(df_sp))
        x.append(df_sp)
        y.append(y_out)
        ids.append(i)

    if not x:
        return None

    # Concatenate once instead of growing a frame inside the loop.
    x_df_out = pd.concat(x)
    y_df = pd.concat(y)
    y_df.reset_index(drop=True, inplace=True)

    # Pass default_fc_parameters=tsfresh_settings.MinimalFCParameters() for
    # a much smaller / faster feature set.
    x_features_df = extract_features(x_df_out, column_sort='time',
                                     column_id='id', n_jobs=1)
    # The features are indexed by window id, and ids have gaps where windows
    # were skipped. Reorder them to match y_df and drop that index so the
    # column-wise concat pairs rows one-to-one.
    x_features_df = x_features_df.loc[ids].reset_index(drop=True)
    x_features_df.fillna(0, inplace=True)

    df_out = pd.concat([y_df, x_features_df], axis=1)
    df_out.to_pickle(df_path)
    return df_out
def plot_video_features(video_fname, nframes=1000):
    """Display Lucas-Kanade feature tracks over the first frames of a video.

    Features are detected with ``get_video_features`` on the first frame. They
    are re-detected every ``20*fps`` frames (``fps`` is assumed to be 25) and
    whenever no tracked point survives. Every later frame is drawn in
    grayscale with the successfully tracked points overlaid in red. An
    ``ipdb`` breakpoint is hit at the end so the figure can be inspected.

    Args:
        video_fname: Path passed to ``cv2.VideoCapture``.
        nframes: Maximum number of frames to process. Stops early when the
            video has fewer frames.
    """
    fps = 25  # assumed frame rate; only sets the re-detection period
    redetect_every = 20*fps
    cap = cv2.VideoCapture(video_fname)
    _, ax = plt.subplots()
    try:
        old_gray = None
        old_pts = None
        for f_id in range(nframes):
            ok, frame = cap.read()
            if not ok:  # end of video or read error
                break
            # cv2.VideoCapture delivers frames in BGR channel order.
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if old_gray is None or old_pts is None or len(old_pts) == 0:
                # First frame, or nothing left to track: (re)detect features.
                old_gray = frame_gray
                old_pts = get_video_features(old_gray, use_shi=True)
                continue

            next_pts, pt_status, _err = cv2.calcOpticalFlowPyrLK(
                old_gray, frame_gray, old_pts, None, winSize=(15, 15))
            # reshape/ravel (not squeeze) so a single point keeps 2-D shape.
            good = next_pts.reshape(-1, 2)[pt_status.ravel() == 1]

            ax.imshow(frame_gray, cmap='gray')
            # Plot only the points whose flow was found.
            if len(good) > 0:
                ax.scatter(good[:, 0], good[:, 1], color=(1, 0, 0), marker='o')

            # Carry the tracked positions forward together with the frame
            # they were measured in.
            old_gray = frame_gray
            old_pts = good.reshape(-1, 1, 2).astype(np.float32)
            if f_id % redetect_every == 0:
                # TODO: extract harmonic data from here. Pixel displacement / dt
                old_pts = get_video_features(old_gray, use_shi=True)
            plt.pause(0.01)
            plt.cla()
    finally:
        cap.release()
    ipdb.set_trace()
def map_imu_tsfresh_subject(subject,
                            window_size=5, window_shift=0.2):
    """Build and pickle tsfresh feature frames for one subject.

    NOTE(review): this function cannot run as it appears here. Lines seem
    to be missing after ``set_home_directory``:
      * ``tsfresh_pkl``, ``condition``, ``fs`` and ``tlx_df`` are used but
        never assigned in this body;
      * both ``continue`` statements need an enclosing loop (presumably one
        iteration per condition);
      * the file-name template has three placeholders ({0}, {1}, {2}) but
        ``.format`` receives only two arguments, which raises IndexError;
      * ``join`` and ``ProjectFileHandler`` are not imported in the visible
        code (``os.path.join`` is imported under the name ``path_join``).
    Restore the missing lines before relying on this function.
    """
    pfh = ProjectFileHandler({})
    pfh.set_home_directory(join(DATA_DIR, 'subject_specific', subject))
    # NOTE(review): the start of this expression is missing; it presumably
    # read something like `tsfresh_pkl = path_join(` -- confirm.
    pfh.home_directory,
    "{0}__winsize_{1}__winshift_{2}_tsfresh_df.pkl"\
        .format(window_size, window_shift))
    # Skip work whose output pickle already exists.
    if path_exists(tsfresh_pkl): continue
    print(f"trying {subject} for {condition}")
    data_df = load_df(subject, condition)
    # load_df returns None when no file matched this subject/condition.
    if data_df is None: continue
    get_subject_tsfresh(fs, data_df, tsfresh_pkl, tlx_df,
                        window_size=window_size,
                        window_shift=window_shift)
def arg_parser():
    """Parse the ``-w/--win_size`` command-line option.

    The option takes a single window size or a comma-separated list of sizes
    (e.g. ``-w 5,10,15``). Empty entries, such as a trailing comma, are
    ignored.

    Returns:
        list[int]: The requested window sizes; ``[50]`` when ``-w`` is not
        given.
    """
    parser = argparse.ArgumentParser()
    # The default must be a string. argparse leaves non-string defaults
    # untouched, and an int default made len()/split() fail when -w was
    # omitted.
    parser.add_argument('-w', '--win_size', type=str,
                        default='50',
                        help='window size or comma-separated list of sizes',
                        )
    args = parser.parse_args()
    # split(',') covers the single-value case too ("7" -> ["7"]).
    return [int(item) for item in args.win_size.split(',') if item.strip()]
if __name__ == '__main__':
# TODO
window_sizes = [5, 7, 10, 12, 15, 17, 20, 22, 25, 27, 30]
subjects = ['Pilot'+str(i).zfill(2) for i in range(2,4)] # mars13
for window_shift in window_shifts:
for window_size in window_sizes:
for subject in subjects:
map_imu_tsfresh_subject(subject,