From 569fd2eaad5cc0e03bccc75496e13253a7af96e5 Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 5 Apr 2024 09:40:37 +0200 Subject: [PATCH 1/5] remove fixed version of neural-tangents --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3050110..cb1c088 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ plotly flax tqdm pandas -neural-tangents==0.6.4 +neural-tangents tensorflow-datasets isort tensorflow From 499070e3be3aac5572a5c103dbd6b41cb1e0bb91 Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 5 Apr 2024 09:42:25 +0200 Subject: [PATCH 2/5] Run black --- znnl/training_recording/jax_recording.py | 24 ++--- znnl/visualization/tsne_visualizer.py | 110 ++++++++++++----------- 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 862b7d4..9ba92f8 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True if "loss_derivative" in self._selected_properties: diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py index 856ecdd..d0d613a 100644 --- a/znnl/visualization/tsne_visualizer.py +++ b/znnl/visualization/tsne_visualizer.py @@ -104,45 +104,47 @@ def run_visualization(self): fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]} fig_dict["layout"]["yaxis2"] = {"anchor": "x2"} fig_dict["layout"]["hovermode"] = "closest" - fig_dict["layout"]["updatemenus"] = [{ - "buttons": [ - { - "args": [ - None, - { - "frame": {"duration": 500, "redraw": False}, - "fromcurrent": True, - "transition": { - "duration": 300, - "easing": "quadratic-in-out", + fig_dict["layout"]["updatemenus"] = [ + { + "buttons": [ + { + "args": [ + None, + { + "frame": {"duration": 500, "redraw": False}, + "fromcurrent": True, + "transition": { + "duration": 300, + "easing": "quadratic-in-out", + }, }, - }, - ], - "label": "Play", - "method": "animate", - }, - { - "args": [ - [None], - { - "frame": {"duration": 0, "redraw": False}, - "mode": "immediate", - "transition": {"duration": 0}, - }, - ], - "label": "Pause", - "method": "animate", - }, - ], - "direction": "left", - "pad": {"r": 10, "t": 87}, - "showactive": False, - "type": "buttons", - "x": 0.1, - "xanchor": "right", - "y": 0, - "yanchor": "top", - }] + ], + "label": "Play", + "method": "animate", + }, + { + "args": [ + [None], + { + "frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": "Pause", + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top", + } + ] sliders_dict = { "active": 0, @@ -163,20 +165,24 @@ def run_visualization(self): } # Add initial data - fig_dict["data"].append({ - "x": self.dynamic[0][:, 0], - "y": self.dynamic[0][:, 1], - "mode": "markers", - "name": "Predictor", - }) - fig_dict["data"].append({ - "x": self.reference[0][:, 0], - "y": self.reference[0][:, 1], - "mode": "markers", - "xaxis": "x2", - "yaxis": "y2", - "name": "Target", - }) + fig_dict["data"].append( + { + "x": self.dynamic[0][:, 0], + "y": self.dynamic[0][:, 1], + "mode": "markers", + "name": "Predictor", + } + ) + fig_dict["data"].append( + { + "x": self.reference[0][:, 0], + "y": self.reference[0][:, 1], + "mode": "markers", + "xaxis": "x2", + "yaxis": "y2", + "name": "Target", + } + ) # Make the figure frames. for i, item in enumerate(self.dynamic): From 37f0f50f0f422304529a4cf199bfdca8dae8d3a8 Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 5 Apr 2024 15:03:54 +0200 Subject: [PATCH 3/5] upper bound for nt version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cb1c088..e41aa51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ plotly flax tqdm pandas -neural-tangents +neural-tangents>=0.6.5 tensorflow-datasets isort tensorflow From 42be478a88724ed898a56892cc0b394094069270 Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Mon, 8 Apr 2024 09:48:33 +0200 Subject: [PATCH 4/5] Constraining Jax Version with <=0.4.25 --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e41aa51..65d3737 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,8 @@ tensorflow_probability scipy scikit-learn jaxlib -jax +# Temp fix of version of jax until the next release +jax<=0.4.25 plotly flax tqdm From ac3040f205a6148170bdf99023c0e66ce7a9f65e Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Mon, 8 Apr 2024 09:51:47 +0200 Subject: [PATCH 5/5] constrain requirements of jaxlib --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 65d3737..d31d588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,9 +12,9 @@ nbsphinx tensorflow_probability scipy scikit-learn -jaxlib -# Temp fix of version of jax until the next release +# Temp fix of version of jax and jaxlib until the next release jax<=0.4.25 +jaxlib<=0.4.25 plotly flax tqdm