From 9b24004d062c64b7eff502c16a8e68b68735dc39 Mon Sep 17 00:00:00 2001 From: Yaad Rebhun Date: Sun, 23 Feb 2025 17:41:40 +0200 Subject: [PATCH] Bug fix - to handle "u: list of array-like, shape (n_samples, n_control_features)" input - I'll elaborate more in the PR comment section --- pysindy/pysindy.py | 9 +++++---- pysindy/utils/axes.py | 8 ++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 5c982aa82..04fa3e18d 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -896,13 +896,14 @@ def comprehend_and_validate(arr, t): if u is not None: reshape_control = False for i in range(len(x)): - if len(x[i].shape) != len(np.array(u[i]).shape): + if len(np.array(x[i]).shape) != len(np.array(u[i]).shape): reshape_control = True if reshape_control: try: - shape = np.array(x[0].shape) - shape[x[0].ax_coord] = -1 - u = [np.reshape(u[i], shape) for i in range(len(x))] + shape = [list(np.array(xi).shape) for xi in x] + for i in range(len(shape)): + shape[i][x[i].ax_coord] = -1 + u = [np.reshape(u[i], shape[i]) for i in range(len(x))] except Exception: try: if np.isscalar(u[0]): diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 25c42db8f..52cd431e1 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -811,10 +811,10 @@ def _apply_indexing( def comprehend_axes(x): axes = {} - axes["ax_coord"] = len(x.shape) - 1 - axes["ax_time"] = len(x.shape) - 2 - if x.ndim > 2: - axes["ax_spatial"] = list(range(len(x.shape) - 2)) + axes["ax_coord"] = len(x if isinstance(x, list) else x.shape) - 1 + axes["ax_time"] = len(x if isinstance(x, list) else x.shape) - 2 + if np.array(x).ndim > 2: + axes["ax_spatial"] = list(range(len(x if isinstance(x, list) else x.shape) - 2)) return axes