@@ -13,6 +13,11 @@ def draw(self, *args, **kwargs):
13
13
def __call__ (self , * args , ** kwargs ):
14
14
self .draw (* args , ** kwargs )
15
15
16
+ def set_output_mode (self , mode : str ):
17
+ """Set notebook or script mode - not implemented yet"""
18
+ ...
19
+
20
+
16
21
17
22
class LossSubplot (BaseSubplot ):
18
23
"""To rewrire, this one now won't work"""
@@ -59,6 +64,7 @@ def draw(self, logs):
59
64
plt .title (self .title )
60
65
plt .xlabel ('epoch' )
61
66
plt .legend (loc = 'center right' )
67
+ plt .show ()
62
68
63
69
64
70
class Plot1D (BaseSubplot ):
@@ -77,6 +83,7 @@ def draw(self, *args, **kwargs):
77
83
plt .plot (self .X , self .predict (self .model , self .X ), '-' , label = "Model" )
78
84
plt .title ("Prediction" )
79
85
plt .legend (loc = 'lower right' )
86
+ plt .show ()
80
87
81
88
82
89
class Plot2d (BaseSubplot ):
@@ -119,3 +126,4 @@ def send(self, logger):
119
126
plt .scatter (self .X [:, 0 ], self .X [:, 1 ], c = self .Y , cmap = self .cm_points )
120
127
if self .X_test is not None :
121
128
plt .scatter (self .X_test [:, 0 ], self .X_test [:, 1 ], c = self .Y_test , cmap = self .cm_points , alpha = 0.3 )
129
+ plt .show ()
0 commit comments