From 95a2cca49258db5b237c3facc142e3001c721433 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Oct 2022 14:59:38 -0700 Subject: [PATCH] Satisfy stricter input requirement for `jnp.linalg.norm` The previous version of the code caused an error under the new more restrictive API of `jnp.linalg.norm` (see https://github.com/google/jax/pull/12670). PiperOrigin-RevId: 484080138 --- jax_cfd/base/initial_conditions.py | 2 +- jax_cfd/base/initial_conditions_test.py | 2 +- jax_cfd/collocated/initial_conditions.py | 2 +- jax_cfd/collocated/initial_conditions_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_cfd/base/initial_conditions.py b/jax_cfd/base/initial_conditions.py index c1d4135..c2bc579 100644 --- a/jax_cfd/base/initial_conditions.py +++ b/jax_cfd/base/initial_conditions.py @@ -65,7 +65,7 @@ def _log_normal_pdf(x, mode, variance=.25): def _max_speed(v): - return jnp.linalg.norm([u.data for u in v], axis=0).max() + return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() def filtered_velocity_field( diff --git a/jax_cfd/base/initial_conditions_test.py b/jax_cfd/base/initial_conditions_test.py index 76259ca..e92f3fc 100644 --- a/jax_cfd/base/initial_conditions_test.py +++ b/jax_cfd/base/initial_conditions_test.py @@ -50,7 +50,7 @@ def test_filtered_velocity_field( self, seed, grid, maximum_velocity, peak_wavenumber): v = ic.filtered_velocity_field( jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber) - actual_maximum_velocity = jnp.linalg.norm([u.data for u in v], axis=0).max() + actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() max_divergence = fd.divergence(v).data.max() # Assert that initial velocity is divergence free diff --git a/jax_cfd/collocated/initial_conditions.py b/jax_cfd/collocated/initial_conditions.py index cb1f4ff..0918be9 100644 --- a/jax_cfd/collocated/initial_conditions.py +++ b/jax_cfd/collocated/initial_conditions.py @@ -44,7 +44,7 @@ def _log_normal_pdf(x, mode, variance=.25): def _max_speed(v): - return jnp.linalg.norm([u.data for u in v], axis=0).max() + return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() def filtered_velocity_field( diff --git a/jax_cfd/collocated/initial_conditions_test.py b/jax_cfd/collocated/initial_conditions_test.py index 006a4af..f5bfe20 100644 --- a/jax_cfd/collocated/initial_conditions_test.py +++ b/jax_cfd/collocated/initial_conditions_test.py @@ -43,7 +43,7 @@ def test_filtered_velocity_field( self, seed, grid, maximum_velocity, peak_wavenumber): v = initial_conditions.filtered_velocity_field( jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber) - actual_maximum_velocity = jnp.linalg.norm([u.data for u in v], axis=0).max() + actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() max_divergence = fd.centered_divergence(v).data.max() # Assert that initial velocity is divergence free