From 9a7a4192b9b2fc2fc54bffeb84988a352ceb4bd0 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Wed, 31 Jul 2024 21:23:41 +0100 Subject: [PATCH] Revert tensordot change. --- cubed/tests/test_array_api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index b8c68dbe..6ce42db0 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -424,13 +424,11 @@ def test_outer(spec, executor): @pytest.mark.parametrize("axes", [1, (1, 0)]) -@pytest.mark.parametrize("dtypes", [(None, None), (np.float32, xp.float32)]) -def test_tensordot(axes, dtypes): - ntype, xtype = dtypes - x = np.arange(400, dtype=ntype).reshape((20, 20)) - a = xp.asarray(x, chunks=(5, 4), dtype=xtype) - y = np.arange(200, dtype=ntype).reshape((20, 10)) - b = xp.asarray(y, chunks=(4, 5), dtype=xtype) +def test_tensordot(axes): + x = np.arange(400).reshape((20, 20)) + a = xp.asarray(x, chunks=(5, 4)) + y = np.arange(200).reshape((20, 10)) + b = xp.asarray(y, chunks=(4, 5)) assert_array_equal( xp.tensordot(a, b, axes=axes).compute(), np.tensordot(x, y, axes=axes) )