From 338721cb7d595579ef7c8572eab01939b337f583 Mon Sep 17 00:00:00 2001 From: YeongHyeon Park Date: Tue, 28 Dec 2021 03:14:21 +0900 Subject: [PATCH] 0.2.1 --- dist/whiteboxlayer-0.2.1-py3-none-any.whl | Bin 0 -> 11256 bytes setup.py | 2 +- whiteboxlayer/extensions/attention.py | 90 ++++++++++++++++++++++ whiteboxlayer/layers.py | 40 ++++++++++ 4 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 dist/whiteboxlayer-0.2.1-py3-none-any.whl create mode 100644 whiteboxlayer/extensions/attention.py diff --git a/dist/whiteboxlayer-0.2.1-py3-none-any.whl b/dist/whiteboxlayer-0.2.1-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..45e9bb6248d506693efb0f6174e8e27b415d1430 GIT binary patch literal 11256 zcmai)WmH_*wzdoR;7)LNhu~HOcXxMpcMI+gA-EGXxVyW%OK`XF(f5A$c6a)mb9aqh zwd==v#$02pJ>O|^l3?Iy0000Qux6N{`2GL@{p0KX!hXM9&CH#Q4Q$=4^xcge7<6^b zZOol?b?NQg!2ojqyD-H+LlFNAtO_BIE&2 z<{GHZ$_E3r>%=Hu6vl;DtLAkJ5g1AA32qk1w{8Nyc_NlP_cbJvi}|U6(0<|J;Ym67 zB1PYa^r4K)&%O{2eby8KkC&-3btDFV5_4fi3+Z_Hx;))_bc>`Vc}zHBfS-_T3_YAh z1S=+{MJ4+KMF!ntpKw+-gR7%tkq;@q27SxOPqWtLPg`sS}=H(J;IENAE zNoyUBWt9E0J+Xze&U*oh2npOUZfp@GN;;V=d*V{uzLts?Wk=YRzN@6)$2v+3gUL5x z8$mfJ{p4%r(iTipx zbJZyR{QEH_jB3QWX{N+NU+H2{+7IJ|55!H?hc@V;jJV!vo()DwOB$C)VgjeghZ{&_ z2}(Mwv~^@;3vvdfFtovmYJ@tU;n3tO$rrRdk5L;KS?Z+84@{X zb4Z9{$y=_pBJLk*7~f#9Se7eGo$7ChjOqw+rq0`yyhX|?!0SO zs~M7VmuhJdrb0sCbbH5uWj;dp3Xt3T+W|qyE0HgD&d9V`5kt3bTsXQi>u<~f3sS2E zyXPQ3t{|0_NY1Dk*XDEasfAHRv8JdI~|6L`mw=EIn98 zk&!V>g>d$Ndc{|zw(mew3#!e?yC2P4&Ph>aU?!hlBE6~>E{2fNa7?+1bVP&eqsp|5 zaviakeK%Lsb2wxr>&VwN&Y8_&4sl+V7&C{pf08?k-ig@JY?2%RoNSmRVIY9-2&|xb z@RiVBVE5L}2V}@~z)1P$2FAKiAHMm7MzZMd)6jR4tcT|OdaQaSTJEdb_YdE)zRxO}S@E;ya@vitf{s6N+4+38F}MQNuq8GVkk!5hf-VM9}1+UVgD zW^j^Be)E}!q7w?FNB|xK8oP!p(F`U7M4dkEM|WU{@AJ(-4g_z1v}s*xZ(HI|yav6< zg7P%llTem@2upQfMYBQRt!*2NGjY=bghFV%r?#FrX*Hg*I11v~ueOh~z+*FIRo>SE zO0ADk-8XZhTN5(+oMBtCq)%{Oul9ssDK-EO>}W&Xh(%|tP?FLDu!55t-ozy|Gq<6} zLNT>y^5$RU*TSh2L+HMQb{NIBI2NXJ^*L~&A}hRocb&`f7rC#wAW+zrD|U(gk)(x! zq8Onn<=u4_!C9^uo=j$Hl@0QckEHH<3_8+sF=I$>n(9_gHK<-GEgyuQ#q19{C(X+l zotW$jIjKj(SEoz|KPzGoP93n+a#9ou2{#XkR6ay9 zwa2@|sa-iKK6;{W{=&U1@nUy%DdM|jx$zUWS(1fWcc`I{J0e63HD$#%xh?0^+h~PH zvgfg*1I!}c_p8CgI_Sc!&H=H~@KY|e_S+HUv+WnjhEFJ((W*ekBHqEWZ&;>dkK>?f z9<#}H0&EQsIn+46UaWHX)G1I2em(GsP`hhPT74$34LFM?do^^rO2tNuhla9mSAf{6 zYgmk-i^sDQDXld+AXxvh=kzf+Q(*bK!=Ykf3N zZ~)*D0RSNTPYivhcbwBZ?p;?Zl%;KV*pb?vRllo)7rLN-y30Ne)AnyvIf3-Z<>zuAJ}t1GKh z@{BQsL1;8?F-4+#)&o^JqNX5HbO@SZw5wP#9RZ#J>Fd}`fl~>DOWo2CJ>CptFT3F0 z-X>A?tOM2LLl} zWJ=Ouz&N=6bnMAYJ?EppX@F74UVs8^qmPz?6QSn!^cXup_*!xXUf2i{B}F>~n!!^# z%tG1h)k{I@1mr`Z#)wcxdNAe*$e5s6?!hkdvzszTTA~Ot(J$Hh1eTU~mR;A7(o}y| z(#N(5I&qLXjL)^9SNnmE7RJ;E*#p-F!sGE~{nb5VorClwcSH)Z9U*sgE3eP<)-X@PgsYLkLY$!ffAAq%x9s=V&=;J z3dG~b+D5pQVxlI`han&GU))&nPU2p_aIPZl1J#$#PJ|sdP2VPGhg;=F0Xc5ImNZueBS7T;0P7trxITC?vWV} zz<)?st6yq(U3Ec!1X1p2a=sf1X4gj5lhfYoo~&f?;!X}XzeHsGQM|*BVG?+SR25*) zs@ow?t6EbPGaJtz&elk7D|*uXO$VvT z8zWkTyb#i5klLPd@O=o`roY*-U$uQalgXERqln8y%;`i)fm>~kanUL)-!3(;&QsRL z^A^P8Usip%F%%R1wB17gPyk11oIf&}QxdTm9<|NuS+>yDQ__LiGj6tQQOdG9ZDg9M z)6fsNER2YtoBWKE8mI?~*MB@L_*MB=Q1^GieibndF+SU(mw{_`8w{P9#xLYtQO#zt za*LANQR0sj<1L}Ok#RxG=1t`H=9^9xRFX*1q4L}!ni9G))4m5b&tANT2cH^hI2yCF zxs}!@6CKvav~17fQ}Yy5>&?4A%D#>Lg8I7$)cwHnw)*b9B5?r#qW|Q*Oda&?%>Iao z)vRMy*x`R=>k=7pX0~w&CV0bbg7kbpow4bhv{l!mZ1`DRZ!D{@ezp*ko4+`|ciLRp zc9stIAzF$re=szkSck2CR;0YlM7EeX7u&8SKTWQFbi~7Q-kA@}h?XKae~en(vKmIB z*_4^MC2TS*U%YJ|`O#8fvo@bgn=Ch3`GtlK9qfzJHNspg#-%Wwp5SD7-~B!E5@S<2 zdc?*M8=NWFtSl-$9_g1M8!g7HVsvXYr05a}iApD=@y`iBk;X|7*bo(GIM_ohYU|=% zC-oXY{f;}7+g27E0UZ^lDAL-+`dW#Rr|cQ%H68rvcxC4XeCtnu=71}p;@TeO%9yFL zxN_bY(6#1Lr*|~?%{!i@w1WDuyq|R>fUcO1+v2-RxyvwC{#5s;6gT|FpFu3sVyMSk zws~KP?nxUJw!{9&je*o$8LEx1bfdVWO2rbgj#Pbt#2WYo;(D`h)vz>q1 zZ38+F-Dtp?mhOG6pWHHhKHNmE8*wT9h9URk^J4^BNI*cAF|81@uEC`fvUctRi`rP< zwfL)T3=L&*BIQ=mAn2ORPe3B_B8TeLv|T9cZ;p|BkXrUKv$2ylz?a- zDZJG%Tda07su2Zj&h*Aigc*4t-hxis-4#8BQjwld2JyR!w-{gQE}_c)OrzDR+ef*R zvJ^c@sxLESC}hchAY*9QE`{433^uPpH*FOVaem;E4|8(aiB3C1@R0*bIdLepm}ap& zsXKk+64p*jH*6HAj%`MhkFn#J@(aiz-XL@P7B<`mmppnNND3qvCtyX+C-P4WAOgX% zMM-^aJoP^AS&n7N*cXI>U!_3}$~r1bl4e!+=bffdCpw~M==J9-oQWG=vGRl zUY5&0pJg$}`&P%DTVA*wX2fR%nYiG04BB!b5}qx15kK@;>D@eyLm;V%=?=lMOp!{e zgT>trfVNmws-JSrqV8SYM5@LX=hKSJlazgC^te_RS*Mr!M59up#O=1y&z_l!Md{naLW3MYW>A#o8#u&73n2@tJd^SsJFkA z2>u8TMZD(;ZdFM06^3+7v{s3V-8v4I3Y78Sqm;Lal^{? z3&hl509BEUUWKP~))VBDID+UKy+pvsF{)6U^uZVB;-jYR>I|ln*B0!veI8VsdDKl0 z$mWO@HxOK66}(n$>C-fQpX<$drUDmg%y6p8)pTpjr*Z%8{PpiWJ=wtk0A|GhBXo5z zHgtAyc#mG+H#w#{WwSzv<~^aBBPRdlSi~Bi#+@PIuI|*>dxThok=2ad&2Pp{mmj{L zvjq3EM$%}w0d`N@LAK-fDtCDaW543?8b~!Td6`B6W+k%{HG16)`yTRyNak2eMM9x# zy#2;PY7i*WT}x1xVkM`qfaz>w=nJ&=n{?Pnp3T%wYMyuRF9 z;XXs~Zdx0Io#0GM3;55HWeMpSQ{}10nclt~_@{h7%L{E-&paTA0-Wn(D0PYaxUc*n z<9H9XLvwhh_Ff8+u=1tZL1Rry`S3LkG8H&kh{;840}qlB%{``MTMwBw13rg;!2Y6v zu*+Qo^)v(L0_zd#48UR}d;BJVL2I5Wwadf`|BKng8+S9bjWeXyjZEOUt~Xli94C)t zKenC!%Cp>HM-8XZ6W>qy`X%XEla!J%n|$^bDQosz_LEr?Rr!plQaan*hE^$Scar)& zV~Y{wdgG#o*6?KqCF^bEZfA5hxjZzwFhOnzoE}gzBiQ(i!Q88s7O#+Kiu}Y4g_XxxS8FH;8{^ z+asLzXcabOgy->Bu-D(bIf*&2-K@1P0LeUZXArTKrO<1BFgjHoAYv%UEmQz?+G|Y6 zm^+#I2&0|HH=Fpi;#PS1Ch$kNyxQUPdJx8EbKd02iZy(c@5>JX%sPb9zAH9GB zSD1f)@gOvAT;1QRkNF<}0FwXYTAZECt<0U=-(R}|l@*&6P9*PpRU~>)$;#o`AwD#a zy3U3SG~kSNwf)z+_e!Kvd6X0~sj=qkQx=i)x9yt-HqcP=0|!n=laM|cmEl+kS=b*a zw=(QO=EO!nj2kDDhDVP5p2@%NmM*)uD6yMp!lctQ?1W12PL`v1R!Uivq9rdg>a&_x zIGcJF!Zy%1;fohMM9NwM=#h@!XDEn5W`y0N7pJ(QCe5XZ-*CI+*2#jQIg z9MdK6Knv5z?+Ox0=?OWA(u?6GKpe*3Zg1@PJ_o@yHQ|zCM z=o?PkE!d}EkH!~)0z@tWLvL5cGmBYrPn@YENEk426BS0$ltl=YP9(0~N{-h}YW>Or zfqAKuC^BFefg!sVm=u`E>N%IOJS6_%gG9uZu7UgAB^7By4OmWe*n}1WJk{$LcNnX~ z<~{!Do=+)7HW?A99SQVcbT=2_V!vJcvZe8Xg(X%?d#50&3Jrn!!4c{&d=CztT z8Nrmm+zj*p`)d$*gi%+r{yfR(bkTVZIq2ut3GmA>^5D!FsN(TmBZ31smyMHHQ zjI9575R};z0bwyOjpj3;z?J=hpNTOBDyurCQad1qV4NuJ2T1qCeSPYDUmJ9Y2$@e^ z)eFId_iWiWpJtkdC9a{xYd5D2QyfzL?dXK>;f@Js z3n@!7{LK-m4^^W}FRk{c#OH{EM;RAVAOn3Dw!y4yQ}w%$RwTM(y-BH*lQ79h9}_ga z5*NjE7l2Lv($F7S%*mRp-@0p!1D}))PrLnNPF-gMqgo&$57)@ERZwuej9-tjm^?5= zIAXze?KAp^0Xy7U9!Nni0Z0Pqf+-$Y_@U;U)J!=k5j%~`2}`uZwf7IVARX)53l0si z1n2}mWK2wNm??(R+M8YNlG{Uads6CA|?RIn3#$&8s>*r#5#VDa*`K5IN${s?J zntfS21zc9v3%@@^V{v91%K6Ke^tZ*P+w7eX{1gbm`A^ZfJH1-dbwze} z@^SU{d^Jg-L_a_3B9&YgCa+Z|7717apN3HQ7vd!c0=dMN11-3B?cR4~{esbh z_6ocF8pCb#3Ph>zo|T|z;e&;8chyt9Nrh@qH@Zl$PvL%8&RZutkp6)bADg{5wYnqD zstv>Knz@2Z>K^%>7nOBffpyoyOn2z*s2c>~bj$Y1Tj&E@Ti0r3EBJ}N#L45#k#dTL zB{_I}AcSNWqc|^*1`T{Hfr4xhG=UdRH=az{FMr9GdE+t^e~tzaJH7{mH19d{KdOCn zjP%U(O!P+Pj!tyuHYTGy9Zwe#_-)M$>%(*^&G#gkHnOf^>^)3l4pXc6(rsaB9F##q~F_ zC24G2e0B^P3gK)@=0veGvnXRTFuUe;;SjbA=zyx@LdLR?Gl`IM5_ zf>fFT0_?+wqtw@LqT zYNdfnf+B)Sf~Tt1jw}7hzoK=k3r>)ah&nNcDo2nP^OM)|tI=Xe782yyZjjYtRY6=2 z*91vlPg*>r-JE*&b<`Qv`MiH5;?V zplHMU04p%~tq$KhJ>iD5)t;Z*9VD&EzlU~0>7U`ahXuzzg0nX(L|odrv3otP1puws z-55~A!?k>xpN@_XCRDqEbxO1%SYK`~UpXL|k;v7DJ@K$cf9ic9t`8Ja3nhP|9@(T= zEw2+lwQu6v?KWOUd&XMd?bb*W&(G(>YGcf$%-87j{Dy&!4oyzFA?FZYB@-C)a4vyU zr%#vJ$;yX7KS?f{9Lp+X z#e{@h>vJqf!kV3s&zT9Df3`=Peq5Dsphi`P-N&S1>^t6sX?gOdJO}BJlzt>QfxGUuacnZyBkt`j&w0k{*x?sN`OQ@gy#QVT3 z=rxV$CaYmA=!oJ){y8;4kQ%OX&_R*Rnl`~NYO`5{q=Z<^h)5c@>z3rwUM?x#1`LkQ z!5~vG7X<=Za&)A8_&JJ#SaU-Wwl=c)Gh7M#A}ri5m7zp5s;H<(8c|J0M= z(?H`Va;`o;CZdu>ztCHAbv08<(_r@91D5Pnk~E1cYO;*v>q2dY1h8un<$;%{hr8#9 zksomk%PaMp%L^Z^Ye)Z};Vf|D)@qFtye+<<#!)k@N^spQ2HUY!B9n(3wU zO7-Lunz`pwdK#(&gy_0*=H-=A0ergri{{jpaE~ES+d2*gY2#lof0CyhkBer<%0{hD zn9yq%*IuAVE2NDja94{D*)W)VQ9!QN9pj$qP~8h*%g7b*X*W!YGRE*zG&Ce8L1-bQQWA8uFDc_J73Ny2 zo_D1UH)){|Wu9Jl23GA+N!NdFCjIkih=Gg5+1B+IKi8A;RbC|Cd4&R zO5C{r8N5W*EU3q(MUW%G&S@o=ro(8Sj4jL%>yF%KKwNnX#Hg(z?p&wW(d>AkU+&`1f_62?Ij!q+0a z85o}Pn&JuPL*A*{&knhTNg4;v3a)&+m+aglDTrVe?hL~lHy`UhqN*Qi>4u}UXbHE6 z+0MsxG-4|}wbQeC)~qHRbo5FeXtk%w--SoBTcyX39yW#}4j&`@`jzWKSwT7xl0{xo zRQec&e40FH+v`|>jxWmML7l#{rD03J|DuTbQT?XwR}x}zXv!ldynrc^|xO`gu z=*jS+-3mQAyX>N0@NWq-@x>Om3e-^Zrgyfv` zXxZ*I=wIrOe~5O+lr|$(1@8sr0GciajE338s&dkky*9>bFUpI`-`u&jbe#pFnIt{jrzmN!4#;_8RNTukV&oMJ5z zCt)PBPkw%OAAS@Ou!zh46!%yiEmSD0qW1kNl&Dqu4SFEo9T8Mb2o1c}gLIkI2=LYg z6K*lLIAnb?FXB=!^&ajk4b0`7D z+f;>Wk?)R|1Z_+K5F$GX+QEDY!+{CsYPTL7n9t=JKU(TEhXu<@r>B%!kQ!7(@~I+Fn%3K)J_iR!a^G?1;=2z6*JAED|08qTw39(&HebB zONL|POfEpw!-ckqKxLamPlXi4PG-@4oqxJ57H_HMqXwzfb&{at+qPq%Uo35{MGtU$ zhLDp40YwA*_rrYeNAtg~eE&aw{&xKTI*{>K%)c7#`)|m5vCaR}KQRA0==WF1znTpB z6T<%of`Qo@(&|_Z{Gfi6omT^M*dGL_jlay4Ua!@ zTJL`7-|qQ;+aJI4es3@Q$>V#^L;fxAug!+vk-ryC|3vwRq`QJVG-}e0vGyhAreh2<8E`I`7 w|4-oGq~>?<@8a@5;5rbt_j<{Hw)DTHMotpq{e=Jk5Z*tS?-Dga@yEaa2L!||bpQYW literal 0 HcmV?d00001 diff --git a/setup.py b/setup.py index 8c0c205..fa14cd9 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name = 'whiteboxlayer', - version = '0.2.0', + version = '0.2.1', description = 'TensorFlow based custom layers', author = 'YeongHyeon Park', author_email = 'young200405@gmail.com', diff --git a/whiteboxlayer/extensions/attention.py b/whiteboxlayer/extensions/attention.py new file mode 100644 index 0000000..cf6c7a8 --- /dev/null +++ b/whiteboxlayer/extensions/attention.py @@ -0,0 +1,90 @@ +import numpy as np +import tensorflow as tf + +def embedding(layer, x, dim_model, name='emb', verbose=True): + + emb = layer.fully_connected(x=x, c_out=dim_model, \ + batch_norm=False, activation=None, name="%s" %(name), verbose=verbose) + + return emb + +def feed_forward_network(layer, x, dim_ff, dim_model, name='ffn', verbose=True): + + ff1 = layer.fully_connected(x=x, c_out=dim_ff, \ + batch_norm=False, activation='relu', name="%s_0" %(name), verbose=verbose) + ff2 = layer.fully_connected(x=ff1, c_out=dim_model, \ + batch_norm=False, activation=None, name="%s_1" %(name), verbose=verbose) + + return ff2 + +def get_angles(pos, i, dim_model): + # https://www.tensorflow.org/text/tutorials/transformer?hl=en + + angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(dim_model)) + + return pos * angle_rates + +def positional_encoding(position, dim_model): + angle_rads = get_angles(np.arange(position)[:, np.newaxis], + np.arange(dim_model)[np.newaxis, :], + dim_model) + + # apply sin to even indices in the array; 2i + angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) + + # apply cos to odd indices in the array; 2i+1 + angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) + + pos_encoding = angle_rads[np.newaxis, ...] + + return tf.cast(pos_encoding, dtype=tf.float32) + +def concat_heads(x, verbose=True): + # https://www.tensorflow.org/text/tutorials/transformer?hl=en + + [d_n, d_s, d_h, d_fh] = x.shape + xc = tf.reshape(x, (d_n, d_s, -1)) + + if(verbose): print("Concat Head", x.shape, "->", xc.shape) + return xc + +def self_attention(layer, x_query, x_key, x_value, num_head=1, mask_idx=-1, udmask=False, name='enc', verbose=True): + + [_, d_s, d_f] = x_query.shape + + enc_query = layer.fully_connected(x=x_query, c_out=d_f, \ + batch_norm=False, activation=None, name="%s-query" %(name), verbose=verbose) + enc_key = layer.fully_connected(x=x_key, c_out=d_f, \ + batch_norm=False, activation=None, name="%s-key" %(name), verbose=verbose) + enc_value = layer.fully_connected(x=x_value, c_out=d_f, \ + batch_norm=False, activation=None, name="%s-value" %(name), verbose=verbose) + + sq_dk = tf.math.sqrt(float(d_f)) + enc_qk = [] + if(num_head != 1): + list_query = tf.split(enc_query, num_or_size_splits=num_head, axis=2) + list_key = tf.split(enc_key, num_or_size_splits=num_head, axis=2) + list_value = tf.split(enc_value, num_or_size_splits=num_head, axis=2) + + for idx_query, _ in enumerate(list_query): + enc_qk.append(tf.matmul(a=list_query[idx_query], b=list_key[idx_query], transpose_a=False, transpose_b=True) / sq_dk) + + enc_qk = tf.stack(enc_qk) + else: + enc_qk = tf.matmul(a=enc_query, b=enc_key, transpose_a=False, transpose_b=True) / sq_dk + + if(udmask): # upper diagonal masking + enc_qk = tf.where(tf.linalg.band_part(enc_qk, -1, 0)==0, -1e+9, enc_qk) + enc_smax_qk = tf.nn.softmax(enc_qk, axis=-1) + + if(num_head != 1): + enc_qkv = [] + for idx_value, _ in enumerate(list_value): + enc_qkv.append(tf.matmul(enc_smax_qk[idx_value], list_value[idx_value])) + enc_qkv = tf.transpose(tf.stack(enc_qkv), [1, 2, 0, 3]) + enc_qkv = concat_heads(x=enc_qkv, verbose=verbose) + else: + enc_qkv = tf.matmul(enc_smax_qk, enc_value) + + if(verbose): print("Self-Attn (Head: %d)" %(num_head), x_query.shape, "->", enc_qkv.shape) + return {'query':enc_query, 'key':enc_key, 'value':enc_value, 'attention':enc_smax_qk, 'output':enc_qkv} diff --git a/whiteboxlayer/layers.py b/whiteboxlayer/layers.py index 250c42a..6466245 100644 --- a/whiteboxlayer/layers.py +++ b/whiteboxlayer/layers.py @@ -55,6 +55,12 @@ def activation(self, x, activation=None, name=''): return tf.nn.swish(x, name='%s' %(name)) else: return x + def dropout(self, x, rate=0.5, name=''): + + y = tf.nn.dropout(x=x, rate=rate, name=name) + + return y + def batch_normalization(self, x, trainable=True, name='', verbose=True): # https://arxiv.org/pdf/1502.03167.pdf @@ -79,6 +85,40 @@ def batch_normalization(self, x, trainable=True, name='', verbose=True): if(verbose): print("BN (%s)" %(name), x.shape, "->", y.shape) return y + def layer_normalization(self, x, trainable=True, name='', verbose=True): + + len_xdim = len(x.shape) + if(len_xdim == 2): x = tf.transpose(x, [1, 0]) + elif(len_xdim == 3): x = tf.transpose(x, [2, 1, 0]) + elif(len_xdim == 4): x = tf.transpose(x, [3, 1, 2, 0]) + elif(len_xdim == 5): x = tf.transpose(x, [4, 1, 2, 3, 0]) + + mean, variance = tf.nn.moments(x=x, axes=[0], keepdims=True, name="%s_mmt" %(name)) + + c_in = x.get_shape().as_list()[-1] + offset = self.get_variable(shape=[c_in], constant=0, \ + trainable=trainable, name="%s_ofs" %(name)) + scale = self.get_variable(shape=[c_in], constant=1, \ + trainable=trainable, name="%s_sce" %(name)) + + y = tf.nn.batch_normalization( + x=x, + mean=mean, + variance=variance, + offset=offset, + scale=scale, + variance_epsilon=1e-12, + name=name + ) + + if(len_xdim == 2): y = tf.transpose(y, [1, 0]) + elif(len_xdim == 3): y = tf.transpose(y, [2, 1, 0]) + elif(len_xdim == 4): y = tf.transpose(y, [3, 1, 2, 0]) + elif(len_xdim == 5): y = tf.transpose(y, [4, 1, 2, 3, 0]) + + if(verbose): print("LN (%s)" %(name), x.shape, "->", y.shape) + return y + def maxpool(self, x, ksize=2, strides=1, \ padding='SAME', name='', verbose=True):