Skip to content

Commit 935ce79

Browse files
committed
Tune down TestMeandDtype.test_mean_custom_dtype
1 parent 3eea7d0 commit 935ce79

File tree

1 file changed

+47
-43
lines changed

1 file changed

+47
-43
lines changed

tests/tensor/test_math.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3210,52 +3210,56 @@ def test_mean_default_dtype(self):
32103210
# TODO FIXME: This is a bad test
32113211
f(data)
32123212

3213-
@pytest.mark.slow
3214-
def test_mean_custom_dtype(self):
3213+
@pytest.mark.parametrize(
3214+
"input_dtype",
3215+
(
3216+
"bool",
3217+
"uint16",
3218+
"int8",
3219+
"int64",
3220+
"float16",
3221+
"float32",
3222+
"float64",
3223+
"complex64",
3224+
"complex128",
3225+
),
3226+
)
3227+
@pytest.mark.parametrize(
3228+
"sum_dtype",
3229+
(
3230+
"bool",
3231+
"uint16",
3232+
"int8",
3233+
"int64",
3234+
"float16",
3235+
"float32",
3236+
"float64",
3237+
"complex64",
3238+
"complex128",
3239+
),
3240+
)
3241+
@pytest.mark.parametrize("axis", [None, ()])
3242+
def test_mean_custom_dtype(self, input_dtype, sum_dtype, axis):
32153243
# Test the ability to provide your own output dtype for a mean.
32163244

3217-
# We try multiple axis combinations even though axis should not matter.
3218-
axes = [None, 0, 1, [], [0], [1], [0, 1]]
3219-
idx = 0
3220-
for input_dtype in map(str, ps.all_types):
3221-
x = matrix(dtype=input_dtype)
3222-
for sum_dtype in map(str, ps.all_types):
3223-
axis = axes[idx % len(axes)]
3224-
# If the inner sum cannot be created, it will raise a
3225-
# TypeError.
3226-
try:
3227-
mean_var = x.mean(dtype=sum_dtype, axis=axis)
3228-
except TypeError:
3229-
pass
3230-
else:
3231-
# Executed if no TypeError was raised
3232-
if sum_dtype in discrete_dtypes:
3233-
assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
3234-
else:
3235-
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
3236-
if (
3237-
"complex" in input_dtype or "complex" in sum_dtype
3238-
) and input_dtype != sum_dtype:
3239-
continue
3240-
f = function([x], mean_var)
3241-
data = np.random.random((3, 4)) * 10
3242-
data = data.astype(input_dtype)
3243-
# TODO FIXME: This is a bad test
3244-
f(data)
3245-
# Check that we can take the gradient, when implemented
3246-
if "complex" in mean_var.dtype:
3247-
continue
3248-
try:
3249-
grad(mean_var.sum(), x, disconnected_inputs="ignore")
3250-
except NotImplementedError:
3251-
# TrueDiv does not seem to have a gradient when
3252-
# the numerator is complex.
3253-
if mean_var.dtype in complex_dtypes:
3254-
pass
3255-
else:
3256-
raise
3245+
x = matrix(dtype=input_dtype)
3246+
# If the inner sum cannot be created, it will raise a TypeError.
3247+
mean_var = x.mean(dtype=sum_dtype, axis=axis)
3248+
if sum_dtype in discrete_dtypes:
3249+
assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
3250+
else:
3251+
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
32573252

3258-
idx += 1
3253+
f = function([x], mean_var, mode="FAST_COMPILE")
3254+
data = np.ones((2, 1)).astype(input_dtype)
3255+
if axis != ():
3256+
expected_res = np.array(2).astype(sum_dtype) / 2
3257+
else:
3258+
expected_res = data
3259+
np.testing.assert_allclose(f(data), expected_res)
3260+
3261+
if "complex" not in mean_var.dtype:
3262+
grad(mean_var.sum(), x, disconnected_inputs="ignore")
32593263

32603264
def test_mean_precision(self):
32613265
# Check that the default accumulator precision is sufficient

0 commit comments

Comments
 (0)