From c5e355238ab977d31443a0757c41feb569d5c591 Mon Sep 17 00:00:00 2001 From: Raymond Chia <rqchia@janus0.ihpc.uts.edu.au> Date: Fri, 1 Dec 2023 23:24:35 +1100 Subject: [PATCH] validation split and nsbjs fix --- models/__pycache__/neuralnet.cpython-38.pyc | Bin 12764 -> 12746 bytes models/neuralnet.py | 4 +--- modules/__pycache__/utils.cpython-38.pyc | Bin 11269 -> 11369 bytes modules/utils.py | 24 +++++++++++++++----- regress_rr.py | 2 +- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/models/__pycache__/neuralnet.cpython-38.pyc b/models/__pycache__/neuralnet.cpython-38.pyc index cbe052be930dd43036f0f60dbcc2069a4a430de5..682fd872d624cdd604379504894e89096240c244 100644 GIT binary patch delta 1372 zcmZvbe{2(F7{~iud)I5bma;LnHr=|;vhB9g7God;e+`1n7{?z|6Dy|H-i_T>+VWmU zve^csB%m%ud`PTj0)l2U{^c$)n1C=DNC+{}7-Jg$(8M_ZlfRfT5q+O43^ZJGpL_0k zzR&YM?{jZ|%3jNQpY?j(Hh%5Bq-)X7yem$(oxeTKZ=(l&a%i2bN~-e^q(oI5u_>;) z?}_1IOa|M{XXNYR!QrgLrFSk5H$;BZ(-`A8Y#Rj?RN;_Zah9EBd)e{2P?5``YMXGe z?c5_)Z53Bp7`K(>Q7F5P*v3J%&kAMfU7-pUcfMP_6X&_30>$|)sE&q4B2n3G&if`H zAk^0!$mDc2Q_>6hG=F5+^SlQ|z^(oec<?!Y82$Ys9`q-69p-^q>N3h<gC>_-kb!;z zIBq%Y!r!vw05>>pze(bBGv9Z5XsW2Orwgi<i;wCh1#rdxphUG;2sf>7F}JT>1gN5; zMT0NQ@s^4J$ME;ohrP$RL=o5vq^y!sta0-@^U3u`z+WZrFrkZZ6n|-}K@F$dndj)L zw~8z6F<8cEV662G^1exUi?D}K!w&=9VI;GIFvD?{-v)gj8JYreE1SU^&UHAk)@A+^ z_zO0_O*$(%M%HdZ7vVHV%A;y|LrYKQ^^(EPkozOT$CwDE0`HJ%EfylnT9GPsnx{iA z3PCDx&yMcR4up5ZDb&OLFo(y(37E&F@TV|~#jgErRvL*@?qZ>kOB?#MX7CmN?CSK) zQOSA2Cm4xDU=Vji=62q15>q5PL9lwWaCa^zd7uOMTpZlsb9sQBq^X!Yx+etBIm)_7 zIB$I<+R}UGw^1jwo8<LzD9*a8@bb8pE$Ig|wum=21RuR>r4YU(T;xbKu}oeq>@&0s z%Z@1$yGD_(2>(~(P){ggIqz!DlW?_|(!?8D?&*MYc&2Ad^AZm#?#5|q6r-SIb-!da zeu*nSH+y^609?Q=eUC&Jsh*!OLb$?F_wLNF+|)>^P%P^Cz3c+k`nJh+Nh=nzV+Q`% zw@bXtWj|w7Iq}f<t643R{8a3l%$4*aD`Yjp;JJ#!s>rCZy42V*AI(iA3|Fjve&B&W zFy4O_mT<Yhvzg{l7mG~KmspZYU&Xcog$D-*Ho{FT47|^6cdR2w!jxQMdM20FRc0q` zV>UP#3v$t(*kd5#Z0wdqi}O0CHtq)>HpgG}d`*tq-VfqK5X9f&mH&A2iLDUE`NTFb z@z+Fe@P5aMM#(PYhU8{(ftxjaDH(wUd@Y%T>u4sw@z6hw(XnC+IJ4;~QTJ>9SN;V) Cm}r#% delta 1558 zcmZvcOKcle6hQsP<G-Dle5hj_IlrcHo!Abl3ROW=f<mZ7x(HohQdtwvq?yFA&6`O@ zf~}w?B2m+Pnu~yx0;!-sSRgcAAfV6|8VJ7yR4fb&6p2Mwz=|&T-8)Vbs4}B>=Fa`S z`|g`JU#5Rb`wsbhZVO&BdapA1p6?HbTfnz{IJ6e1aE5#~**<AW+GIzMCE_&pIB(j< z(+)J>zPu>dad>cO8N`&OpUF%XCY?oR!C9~t#0jD3D%fPpkc&Tp8rz(u=q?C@mV&FF z72E~;v}KUV*3;I4<7Huv6+PJ&`L+i}_X}kY^b%wdy<F>V1&?vQVu-a`OO;2Zj4DgI zn#-o}A(iY(K9?TQ_)}O<#LB1e0a1U5d7;Z6W>b*%-~aeDDo^9k+7K4zb-4v6qg%jR z{7V#tKb;Pdx#?~F-EJSNdbGPc)is(|_!Bu<$;A3q9oGE!vT5i{1dNWmt}-?czg8*i zL*r0&QDBSkeeDC=7jcG&z@NqkpC^PiP*2hl;)b#7?rG*X&E2GIAS}Sw+g@Qypai&Q zVM99)Yk_ul6DsQlYLAn4hH!$=OIU&vbxq+HNi_4%qVxd{jdse(ymii6J`2A&0>&?O z|FGRB$<YjKC+Q%ef$%0G>5-MJrldx)s;=?3NPU)Y4uYX%9gt{FWG=~^jWRVDv!TO6 zkP`6C)%-VF3?F1?;K}e#b^*r2adr{j4ZkD4h5@h;Xn|u52lkoK6q9u4bGb}PQ^ypI zpQD`55=IF3klod<W^}Veeu0!132%e9v5`FtvBvZJx9ZX;IlVzJYd5jkvt=}d8TL#> z0ejZWm$73tMyzQ_XueEQR|y{zHsd!M@T&1qi-QF&k+}j9ag<!SzCk6et49@n4Sw7n z9ADi?@CiwjFX`n{R?ZD;ip0|c5j+2iET0m#@GjN#Nu!FUO4(dqSBKRx^9*2=>AM+f z7P&D;#9OW!er~O?o2f6uT<d+l>o`lqUH(c}A>UfxTsG3=WV>^jpP}x3qiRPbV^^T5 z?ZK97<m4yx5I#qgeETJy8ST+?`MjDv!mq$Y+g@AwM6R?wW?Mm@HK>ZbK2C?U98x`k zpWE<AMpyGZmsT_l10$koGwUeX%iZBLA1>BBR>vrOg75+%N}vMnEcyr3{Z89Tby`>j z-Xa}&i0?exNvli-jwpJnPto+$0G;!^#HC?HSGZ<wL&=usYF6h7bmgmHi$>UWh(|lv zEjS!Kh4QUveUSQ=bn0A{GAUK&R`M<H0mR#b=yag{A?ASN?SI?N#;(F#$1}{o7<=CH z1^KQkXJTC}1UF*E|7m;Udsrj96yM7(!FTZ;!L2@R+8VwB)rs9U1I1I&ooHkR<P!<@ d4V+1Q>!CjaUnUs%=h_A*J9pWJ{EGj%e*v2_i0J?T diff --git a/models/neuralnet.py b/models/neuralnet.py index e1f3290..6a07c4e 100644 --- a/models/neuralnet.py +++ b/models/neuralnet.py @@ -153,7 +153,7 @@ class FNN_HyperModel(kt.HyperModel): self.model = model return model - def fit(self, hp, model, x, y, validation_data, epochs, **kwargs): + def fit(self, hp, model, x, y, validation_data=None, **kwargs): def make_ds(x, y): ds_x = tf.data.Dataset.from_tensor_slices(x)\ .batch(self.batch_size, drop_remainder=True) @@ -169,14 +169,12 @@ class FNN_HyperModel(kt.HyperModel): val_ds = make_ds(*validation_data) history = model.fit(train_ds, validation_data=val_ds, - epochs=epochs, verbose=self.verbose, **kwargs ) else: val_ds = None history = model.fit(x, y, - epochs=epochs, verbose=self.verbose, batch_size=self.batch_size, **kwargs diff --git a/modules/__pycache__/utils.cpython-38.pyc b/modules/__pycache__/utils.cpython-38.pyc index 0ae60611424d313f7060707ed9e97db8ad855b88..b33d8e0b67828c2e0c10591ec40526fbe33e3837 100644 GIT binary patch delta 1332 zcmaJ>UuauZ7(d_1&FxLhlJw7#dy}+jyC!Lyty{avCKPd-AfgX~h^fLt-DygiHI<Xg zisRh&Ayx!gjXn4<88h{DFsku&Lm%uAL}mVi;6#NY;)C-=d|Lg^*~X&i67F}-`F_9O zcfa$y_uRF@uZ8e)d-eng{w`cC&uyG|KirOlQ@FJ7I*C_ptMzAO9pEpy*1ZEaWThv4 zG$6W}eUW-GVK7n&G$_Bqh=?&~np7<ox-%1f3l3)Y3J|nph}b&FcY6*WPgH2b;4ip7 zmaGF10SkU+ImC;*!*$|CMO;LSrm!$}VT_k3&i1hYMoX9tA1}z>-Xd(tT5oRtYqw8C z7WT4moYbLV)Q;fFW)aqX9y}aT6T?hg1XvU^-QwSU8$9|SeMfcQ_CNa0t=h=6OY~VH z2utZbyzY(&pLmG%pQoOKy%~Ik4G5>0shTXU#s*YBME0O&Z{v=K)Q;V3rwVyF)S9(v zH{wDwD(EKE!5uA*q53SUhf(e4r&dwEPs_(eunt~AB${~vo8Tqg{T&7UNKOf-d0>;E zxz55s(-BE5O4X`4p|aIu=R!EQ)%<x<9S0mW`T)($Dyh-}5eZfDt_vf&vlVPCqZ>PT zzm4iO>^!USONi$*ey9V#g!rh&uOU9B@yQN+1@Uo>|A2VD4y$O~R0rV{683AtZ6qAf z1rB%Mk0L&SxXA|)KdA8|9rys^hcrHo_@u@kzK;t(Dob~K^lV`_=N$DuXd7*OX84@% znt9?}&pk7hH=d1_PIOwT<^Q9e(XQTmv&N=c^RX_7QNDo(uN^o&d*;kbv!^RH{uZuV zug1eZNJ*Kv8fphiWw%&ap5^m3`AqUD*pL^KZ>K2GUDFqYAhoCgl<5C1rS!JplP6P; zKbUYKwFcWV-S0-qfal)8dlImKUvItFzYVb5DyLrrSd^a(J_0x8?%?J44P=b@c>#yE zugq%vD9(6Wems<cH{~xwGw`!KKHNL?lS+>(R^Q@mWv;kfuCk@&TCv2J7K>%Ju5PEk zHv~dm76gIsv|bzj2;d$0#K<JvYP~XI1AHboGEYXoP^HeQ@RbVR$lo$g!yockb|UaC zCfe{a|6ML+AA=9%RyGeHzt4USU&_n5Jp3rX%T2<kGBBD1vy~am`O@2}#&s2Ls<5NN bEfw&b5^l;${*6e%GEDR}IAnz^I~e~PlSX5F delta 1204 zcmaKrOK1~O6o&7aJUWRrlRlD}BsDc@^J;3fzR(I%wF=UL3lUpip?2CuTZ2po7v-jf z#x6uE!D~07l5X5|qsEPbq6>>4w)oakQK2prv4S9OM9;k{4fx0~|9yPtKRI(|?z{HU zcGsSzOPvIN$IS!XpPMhcW))WemkzwQ;XTk__@bNwc*;i0-oho;S6=O@?Dgq{dMHqd zTB(DU>DpKcyIsC!i?2`0Sq-mg(fwHfLZR2D%RZ9vn&B+TxT%-AQ+6t&mY{ljB%Iag z4pf=ia~5MgTUU{SQ*68<w*87(NtIrY?ovq>a+ZNlxYBe*FB^l5sw?R6;*Lrkx*K=O zs9RK}{!vwISjB499%(nBL=T7xTg+N&p-c3tF)5?rKU<9^eHqnKk+fY8@-YYLs~BA^ zqDSy~mh<N^qs1<ai=o#x@YPgI$h)aD0a-BPsRX)@p*w_bo8ccudsx`(s51*0ANA#v z6eb|!H|vWM+OUl0RLifJAlN#sH<<TPKPHv<D&N3&(;Eg$ajx<CUXrc@90mQ4r(>9; zrFKGnJms%0)QHNOu&`QD*viE+(l@d4h~WE>M+IM1#4jL^2|kECE_h=RA3~lG{4Dai zEDU4gRu=^v!hoa*xQ+q!BEgy>z8U!n!Do>-2)?$62a&H7{44TRg0EZT)Rwaw;h%lX zs8r2&`}j8@+Nk0?!}t5EnOg_Te>ao&CZl3Y8;V;h?Ei3gh_3$rvJ#DjNpA^wEyfu< zG{$L^Sw~lAfB)gmuD*f7UDaiQ0DI=Y!SaDt7-Am+<K<S6h`9I&(l@(}ov&*7+xEAr zQMkjh%S`1U;E}Xqb^<aOQw4i)3ZSoWzj{AFCrgAj!W`=gO~M1Fggx+rMZyutves~i z=MAr`nad6?2`;bMwQ#Uvgy)oTO`x<y1}7Nvh4<lGfah?swh<NzFKab`d+bYOr+b>) zj&XU&<uTh8ZG%tjT(rUQ1UF{ksPU0KiZ;VVRvAkIVa>5=m|;ph39s0ucq80kN8$mn r6;8xsmg+k^#}t=2F7sR#xZp=k3>%wCo>baos~x28&X8OxYfkkC3)?=T diff --git a/modules/utils.py b/modules/utils.py index 295d2b4..f633da5 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -224,11 +224,17 @@ def model_training(mdl_str, x_train, y_train, marker, lstm_mdl = tuner.load_model(is_training=True) lstm_hypermodel.verbose = True callbacks = tuner.get_callbacks(epochs=extra_train) + fit_kwargs = {'epochs': extra_train, + 'callbacks': callbacks, + } + if validation_data is not None: + fit_kwargs['validation_split'] = None + else: + fit_kwargs['validation_split'] = 0.2 + history = lstm_hypermodel.fit( None, lstm_mdl, x_train, y_train, - validation_data=validation_data, epochs=extra_train, - callbacks=callbacks - ) + **fit_kwargs,) tuner.save_weights_to_path() tuner.load_model(is_training=False) @@ -267,11 +273,17 @@ def model_training(mdl_str, x_train, y_train, marker, hypermodel.verbose = True callbacks = tuner.get_callbacks(epochs=extra_train) + fit_kwargs = {'epochs': extra_train, + 'callbacks': callbacks, + } + if validation_data is not None: + fit_kwargs['validation_split'] = None + else: + fit_kwargs['validation_split'] = 0.2 + history = hypermodel.fit( None, mdl, x_train, y_train, - validation_data=validation_data, epochs=extra_train, - callbacks=callbacks, - ) + **fit_kwargs,) tuner.save_weights_to_path() tuner.load_model(is_training=False) diff --git a/regress_rr.py b/regress_rr.py index 5063968..b4dce1c 100644 --- a/regress_rr.py +++ b/regress_rr.py @@ -1424,7 +1424,7 @@ def arg_parser(): ) parser.add_argument("-s", '--subject', type=int, default=1, - choices=list(range(1,N_SUBJECT_MAX))+[-1], + choices=list(range(1,N_SUBJECT_MAX+1))+[-1], ) parser.add_argument("-f", '--feature_method', type=str, default='minirocket', -- GitLab