from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from regress_rr import load_tsfresh
from os.path import join
from os import makedirs
import ipdb
import joblib
from cycler import cycler

import pandas as pd

from config import *
from regress_rr import (
    get_activity_log, load_and_sync_xsens, get_respiration_log,
    get_cal_data
)
from regress_rr import get_test_data as get_raw_test_data
from modules.evaluations import Evaluation
from modules.digitalsignalprocessing import (
    movingaverage, butter_lowpass_filter)
from sklearn.preprocessing import (
    PolynomialFeatures, LabelEncoder, StandardScaler
)

# Start from a clean slate: close any figure left open by an earlier run.
plt.close('all')

# Small font sizes throughout, set one rc group at a time.
plt.rc('figure', titlesize=8)
plt.rc('axes', titlesize=7, labelsize=7)
plt.rc('xtick', labelsize=6)
plt.rc('ytick', labelsize=6)
plt.rc('legend', fontsize=6, title_fontsize=6)

# Windowing parameters; they are embedded in the tsfresh directory and file
# names built further down in this script.
window_size = 20
window_shift = 0.05

# Script outline:
#   1. Get the test data for the subject.
#   2. Load the IMU, BVP, and IMU+BVP data for the subject.
#   3. Get the linear-regression data for the subject.
#   4. Plot the PSS against the linear-regression data for the subject.
#   5. Highlight the standing periods.

# Subject under analysis.
sbj = 'S01'

# One entry per sensor set; these three lists are zipped together in the
# main section, so they must stay the same length and in the same order.
sens_list = ['imu', 'bvp', 'imu-bvp']
# cfg_ids = [4, 0, 0]
cfg_ids = [5, 11, 4]
combi_strs = ['combi7.0-10.0-12.0-15.0-17.0' for _ in range(3)]

# Column of the test frame used as the regression target.
lbl_str = 'pss'

# Location of the per-subject models and the file name of each saved model.
mdl_file = 'linr_model.joblib'
mdl_dir = join(DATA_DIR, 'subject_specific', sbj)

# Directory holding the tsfresh features for the current window settings.
tsfresh_dir = join(
    mdl_dir,
    'tsfresh__winsize_{0}__winshift_{1}'.format(window_size, window_shift),
)

from matplotlib.animation import FuncAnimation

# Multiplier used when building figsize tuples (sizes are written as N*cm).
cm = 1 / 2.54

class AnimationPlotter():
    """Scrolling three-panel line animation that is saved to a video file.

    The figure has three stacked axes labelled 'Pressure', 'Accelerometer'
    and 'Gyroscope'. Every frame shows a window of ``buff_len`` consecutive
    samples; the window advances by one sample per frame.

    Parameters
    ----------
    x : array supporting ``x - x[0]`` and slicing
        Sample positions for the x-axis; stored relative to the first
        sample. The unshifted first value is kept in ``self.x0``.
    data : sequence of three 2-D arrays
        ``data[i]`` is drawn on axis ``i`` with one line per column.
        NOTE(review): the callers in this file pass a tuple, not a list.
    duration : int
        Number of frames to render.
    buff_len : int
        Number of samples visible in each frame.
    vname : str
        Output file name handed to ``Animation.save``.
    """

    def __init__(self, x, data:list, duration:int, buff_len:int,
                 vname='test.mp4'):
        self.data = data
        self.fig, self.axs = plt.subplots(3, 1, dpi=200,
                                          figsize=(8*(4/3)*cm, 8*cm))
        self.duration = duration
        self.buff_len = buff_len
        self.x0 = x[0]
        self.x = x - x[0]
        self.vname = vname

        # One colour cycle per axis so the three panels are distinguishable.
        self.cm = [
            cycler('color', sns.color_palette('dark')),
            cycler('color', sns.color_palette('muted')),
            cycler('color', sns.color_palette('pastel')),
        ]
        y_labels = ('Pressure', 'Accelerometer', 'Gyroscope')

        # self.lines[i] is the list of Line2D artists drawn on axis i.
        self.lines = []
        for i, ax in enumerate(self.axs):
            ax.set_prop_cycle(self.cm[i])
            n_series = self.data[i].shape[-1]
            # Placeholder zeros; the real samples are filled in by animate().
            axis_lines = ax.plot(np.arange(0, buff_len),
                                 np.zeros((buff_len, n_series)))
            self.lines.append(axis_lines)

            ax.set_xlabel('Time (s)')
            ax.set_ylabel(y_labels[i])

        self.axs[1].legend(['x', 'y', 'z'], prop={'size': 6}, loc='upper left')
        # Only the bottom panel keeps its x tick labels.
        self.axs[0].set_xticklabels([])
        self.axs[1].set_xticklabels([])

    def animate(self, frame):
        """Show samples ``[frame, frame + buff_len)`` on every axis.

        Slices are used for the window, so a window that runs past the end
        of the data is truncated rather than raising ``IndexError``.
        """
        window = slice(frame, frame + self.buff_len)
        x_win = self.x[window]
        for series, axis_lines in zip(self.data, self.lines):
            for col, line in enumerate(axis_lines):
                line.set_data(x_win, series[window, col])

        # Rescale each panel to the data currently on display.
        for ax in self.axs:
            ax.relim()
            ax.autoscale_view(True, True, True)

        # Redraw this plotter's own figure, not whichever one is current.
        self.fig.canvas.draw_idle()

    def run(self):
        """Render ``duration`` frames, save them to ``vname`` and return the
        ``FuncAnimation`` object.

        The video is written with the 'ffmpeg' writer and the 'h264' codec.
        """
        ani = FuncAnimation(self.fig, self.animate, frames=self.duration,
                            repeat=False, interval=9)
        ani.save(self.vname, writer='ffmpeg', codec='h264')
        return ani

def do_animation(sect_duration, buffer_duration, sens_list=None):
    """Save scrolling PSS / accelerometer / gyroscope videos for subject ``sbj``.

    Three videos are written to the working directory:

    * ``cal<cpm>.mp4`` for the calibration step at position 3 of the
      calibration data,
    * ``seated.mp4`` for the span between activity-log rows 2 and 3,
    * ``standing.mp4`` for the span between activity-log rows 4 and 5.

    Every clip starts ``20*IMU_FS`` rows into its segment and covers
    ``(sect_duration + buffer_duration) * IMU_FS`` rows. Each plotted column
    is standardised on the clip itself before it is drawn.

    Parameters
    ----------
    sect_duration : int
        Multiplied by ``IMU_FS`` to give the number of frames per video.
    buffer_duration : int
        Multiplied by ``IMU_FS`` to give the number of samples visible in
        each frame.
    sens_list : list of str, optional
        Forwarded to ``load_and_sync_xsens``. Defaults to ``['imu', 'bvp']``.
    """
    if sens_list is None:
        sens_list = ['imu', 'bvp']

    xsens_df = load_and_sync_xsens(sbj, sens_list=sens_list)
    activity_df = get_activity_log(sbj).reset_index(drop=True)
    event_df = get_respiration_log(sbj)

    cal_df = get_cal_data(event_df, xsens_df)

    # The original comment on the last argument reads "include standing or
    # not" -- TODO confirm its exact meaning against regress_rr.get_test_data.
    test_df_tmp = get_raw_test_data(cal_df, activity_df, xsens_df, 1)
    test_df = pd.concat(list(test_df_tmp['data']), axis=0)

    plot_strs = ['PSS']
    acc_cols = ['acc_x', 'acc_y', 'acc_z']
    gyro_cols = ['gyro_x', 'gyro_y', 'gyro_z']

    start_idx = 20*IMU_FS
    buff_len = buffer_duration*IMU_FS
    duration = sect_duration*IMU_FS
    end_idx = start_idx + buff_len + duration

    def run_anim(df, vname):
        """Standardise the plotted columns of ``df`` and save the animation."""
        # Work on a copy: ``df`` is a slice of a larger frame, and assigning
        # columns on a slice is unreliable and would alter the caller's data.
        df = df.copy()
        for cols in (plot_strs, acc_cols, gyro_cols):
            df[cols] = StandardScaler().fit_transform(df[cols])

        data = (df[plot_strs].values,
                df[acc_cols].values,
                df[gyro_cols].values)
        AnimationPlotter(df.sec.values, data, duration, buff_len,
                         vname=vname).run()

    # Calibration step(s) to animate, by position in cal_df.
    for i in [3]:
        cpm = cal_df.cpm.iloc[i]
        cal_clip = cal_df.data.iloc[i].iloc[start_idx:end_idx, :]
        run_anim(cal_clip, 'cal'+str(int(cpm))+'.mp4')

    # Activity-log timestamps converted to epoch seconds (naive local time).
    fmt = "%d/%m/%Y %H:%M:%S"
    activity_sec = activity_df.Timestamps.map(
        lambda x: datetime.strptime(x, fmt).timestamp()
    )
    seated_start, seated_end = activity_sec.iloc[2], activity_sec.iloc[3]
    standing_start, standing_end = activity_sec.iloc[4], activity_sec.iloc[5]

    seated_df = test_df[
        (test_df.sec > seated_start) & (test_df.sec < seated_end)
    ]
    run_anim(seated_df.iloc[start_idx:end_idx, :], 'seated.mp4')

    standing_df = test_df[
        (test_df.sec > standing_start) & (test_df.sec < standing_end)
    ]
    run_anim(standing_df.iloc[start_idx:end_idx, :], 'standing.mp4')

def get_model(sens, sbj, cfg_id, combi):
    """Load the saved regression model for one sensor/configuration pair.

    The file is looked up at
    ``<mdl_dir>/<sens>_rr/<cfg_id zero-padded to 2>/<run name>/<mdl_file>``
    where the run name is ``linreg_<sens>_rr_<sbj>_id<cfg_id>_<combi>``.

    Note that ``mdl_dir`` is the module-level directory; the ``sbj`` argument
    only affects the run-name component of the path.
    """
    run_name = f'linreg_{sens}_rr_{sbj}_id{cfg_id}_{combi}'
    path_parts = (
        mdl_dir,
        sens + '_rr',
        str(cfg_id).zfill(2),
        run_name,
        mdl_file,
    )
    return joblib.load(join(*path_parts))

def get_test_data():
    """Read the pickled tsfresh test frame and split it into parts.

    Returns
    -------
    tuple
        ``(x_test, y_test, time)`` where ``x_test`` holds every column other
        than 'sec', 'br', 'pss' and 'cpm', ``y_test`` is the ``lbl_str``
        column as an ``(n, 1)`` array, and ``time`` is the 'sec' column.
    """
    pickle_name = 'test__winsize_{0}__winshift_{1}__tsfresh.pkl'.format(
        window_size, window_shift)
    frame = pd.read_pickle(join(tsfresh_dir, pickle_name))

    # Everything that is not a label/bookkeeping column is a feature.
    label_cols = ['sec', 'br', 'pss', 'cpm']
    feature_cols = [
        name for name in frame.columns.values if name not in label_cols
    ]

    features = frame[feature_cols]
    target = frame[lbl_str].values.reshape(-1, 1)
    return features, target, frame['sec']

if __name__ == '__main__':
    # Script entry point: save the animations, then plot the predictions of
    # each sensor set's model against the smoothed ground truth and write
    # the figure to 'S01cal.png'.

    # Activity-log rows labelled 'standing'; used below to shade the plot.
    activity_log = get_activity_log(sbj)
    standing_log = activity_log[activity_log['Activity'] == 'standing']

    do_animation(20, 15)

    x_test, y_test, time = get_test_data()
    metrics_dict = {}
    x_test_cols = x_test.columns
    y_test = movingaverage(y_test, 8)

    cm = 1/2.54
    fig, axs = plt.subplots(figsize=(14*cm, 7.5*cm), dpi=300)
    axs.plot(y_test, label='ground-truth', linewidth=1, c='k')

    # Degree-1 expansion; fit_transform re-fits on every call, so a single
    # instance can be shared by all sensor sets.
    poly = PolynomialFeatures(1)
    for sens, cfg_id, combi_str in zip(sens_list, cfg_ids, combi_strs):
        model = get_model(sens, sbj, cfg_id, combi_str)

        # Pick the feature columns that belong to this sensor set.
        if sens == 'imu':
            cols = [col for col in x_test_cols if 'acc' in col or 'gyr' in col]
        elif sens == 'bvp':
            cols = [col for col in x_test_cols if 'bvp' in col]
        else:
            cols = x_test.columns

        x_test_sens = poly.fit_transform(x_test[cols])

        y_pred = movingaverage(model.predict(x_test_sens), 64)

        evals = Evaluation(y_test.flatten(), y_pred.flatten())
        metrics_dict[sens] = evals.get_evals()

        axs.plot(y_pred, label=sens.upper(), linewidth=1)

    # Shade the standing periods. Rows are paired even/odd, i.e. this
    # assumes the standing rows alternate start, end, start, end, ...
    y0 = axs.get_ylim()[-1]
    fmt = "%d/%m/%Y %H:%M:%S"
    x = np.arange(len(y_test))
    standing_stamps = standing_log['Timestamps']
    for start_str, end_str in zip(standing_stamps.iloc[::2],
                                  standing_stamps.iloc[1::2]):
        start_sec = datetime.strptime(start_str, fmt).timestamp()
        end_sec = datetime.strptime(end_str, fmt).timestamp()
        mask = (time >= start_sec) & (time <= end_sec)
        axs.fill_between(x, y0, where=mask, facecolor='k', alpha=0.2)

    # Mark the largest jump between consecutive 'sec' values with a red
    # line (presumably the break between recording days -- TODO confirm).
    diff = np.insert(np.diff(time), 0, 0)
    new_day_idx = diff.argmax()
    axs.axvline(new_day_idx, c='r', alpha=0.6, linewidth=1)

    print(metrics_dict)
    axs.legend(ncols=4, prop={'size': 6})

    axs.set_title("Multi-Modal RR Predictions during Rest Standing and Sitting")
    axs.set_xlabel("Time (sec)")
    axs.set_ylabel("Respiration Rate (CPM)")
    fig.savefig('S01cal.png')