Skip to content

Commit f2532d9

Browse files
committed
Fix error when switching data on already trained model
New dataset didn't have the predictions column yet which errored when trying to serialize and send to anchorviz. Fixed by adding a "data_changed" event that the model class listens for and calls `fit()`
1 parent 837f8e0 commit f2532d9

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

icat/data.py

+15
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def __init__(
295295
self._data_label_callbacks: list[Callable] = []
296296
self._row_selected_callbacks: list[Callable] = []
297297
self._sample_changed_callbacks: list[Callable] = []
298+
self._data_changed_callbacks: list[Callable] = []
298299

299300
super().__init__(**params) # required for panel components
300301
# Note that no widgets can be declared _after_ the above, or their values won't be
@@ -452,6 +453,14 @@ def on_row_selected(self, callback: Callable):
452453
"""
453454
self._row_selected_callbacks.append(callback)
454455

456+
def on_data_changed(self, callback: Callable):
457+
"""Register a callback function for the "data changed" event, when the
458+
active_data dataframe is switched out.
459+
460+
Callbacks for this event should take no parameters.
461+
"""
462+
self._data_changed_callbacks.append(callback)
463+
455464
@param.depends("sample_indices", watch=True)
456465
def fire_on_sample_changed(self):
457466
for callback in self._sample_changed_callbacks:
@@ -465,6 +474,10 @@ def fire_on_row_selected(self, index: int):
465474
for callback in self._row_selected_callbacks:
466475
callback(index)
467476

477+
def fire_on_data_changed(self):
478+
for callback in self._data_changed_callbacks:
479+
callback()
480+
468481
# ============================================================
469482
# INTERNAL FUNCTIONS
470483
# ============================================================
@@ -698,6 +711,8 @@ def set_data(self, data: pd.DataFrame):
698711
if self.label_col not in self.active_data:
699712
self.active_data[self.label_col] = -1
700713

714+
self.fire_on_data_changed()
715+
701716
self.set_random_sample()
702717
# TODO: seems weird to handle this here
703718
self._apply_filters()

icat/model.py

+6
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self.anchor_list.on_anchor_removed(self._on_anchor_remove)
8181
self.anchor_list.on_anchor_changed(self._on_anchor_change)
8282
self.data.on_data_labeled(self._on_data_label)
83+
self.data.on_data_changed(self._on_data_changed)
8384
self.view.on_selected_points_change(self._on_selected_points_change)
8485

8586
self._last_anchor_names: dict[str, str] = []
@@ -88,6 +89,11 @@ def __init__(
8889

8990
self.anchor_list.build_tfidf_features()
9091

92+
def _on_data_changed(self):
93+
"""Event handler for when set_data in datamanager is called."""
94+
# self.data.active_data = self.featurize(self.data.active_data, normalize=False)
95+
self.fit()
96+
9197
def _on_data_label(self, index: int | list[int], new_label: int | list[int]):
9298
"""Event handler for datamanager.
9399

0 commit comments

Comments
 (0)