diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 5c982aa8..04fa3e18 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -896,13 +896,14 @@ def comprehend_and_validate(arr, t): if u is not None: reshape_control = False for i in range(len(x)): - if len(x[i].shape) != len(np.array(u[i]).shape): + if len(np.array(x[i]).shape) != len(np.array(u[i]).shape): reshape_control = True if reshape_control: try: - shape = np.array(x[0].shape) - shape[x[0].ax_coord] = -1 - u = [np.reshape(u[i], shape) for i in range(len(x))] + shape = [list(np.array(xi).shape) for xi in x] + for i in range(len(shape)): + shape[i][x[i].ax_coord] = -1 + u = [np.reshape(u[i], shape[i]) for i in range(len(x))] except Exception: try: if np.isscalar(u[0]): diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 25c42db8..52cd431e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -811,10 +811,10 @@ def _apply_indexing( def comprehend_axes(x): axes = {} - axes["ax_coord"] = len(x.shape) - 1 - axes["ax_time"] = len(x.shape) - 2 - if x.ndim > 2: - axes["ax_spatial"] = list(range(len(x.shape) - 2)) + axes["ax_coord"] = len(x if isinstance(x, list) else x.shape) - 1 + axes["ax_time"] = len(x if isinstance(x, list) else x.shape) - 2 + if np.array(x).ndim > 2: + axes["ax_spatial"] = list(range(len(x if isinstance(x, list) else x.shape) - 2)) return axes