@@ -3210,52 +3210,56 @@ def test_mean_default_dtype(self):
         # TODO FIXME: This is a bad test
         f(data)
 
-    @pytest.mark.slow
-    def test_mean_custom_dtype(self):
+    @pytest.mark.parametrize(
+        "input_dtype",
+        (
+            "bool",
+            "uint16",
+            "int8",
+            "int64",
+            "float16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        ),
+    )
+    @pytest.mark.parametrize(
+        "sum_dtype",
+        (
+            "bool",
+            "uint16",
+            "int8",
+            "int64",
+            "float16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        ),
+    )
+    @pytest.mark.parametrize("axis", [None, ()])
+    def test_mean_custom_dtype(self, input_dtype, sum_dtype, axis):
         # Test the ability to provide your own output dtype for a mean.
 
-        # We try multiple axis combinations even though axis should not matter.
-        axes = [None, 0, 1, [], [0], [1], [0, 1]]
-        idx = 0
-        for input_dtype in map(str, ps.all_types):
-            x = matrix(dtype=input_dtype)
-            for sum_dtype in map(str, ps.all_types):
-                axis = axes[idx % len(axes)]
-                # If the inner sum cannot be created, it will raise a
-                # TypeError.
-                try:
-                    mean_var = x.mean(dtype=sum_dtype, axis=axis)
-                except TypeError:
-                    pass
-                else:
-                    # Executed if no TypeError was raised
-                    if sum_dtype in discrete_dtypes:
-                        assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
-                    else:
-                        assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
-                    if (
-                        "complex" in input_dtype or "complex" in sum_dtype
-                    ) and input_dtype != sum_dtype:
-                        continue
-                    f = function([x], mean_var)
-                    data = np.random.random((3, 4)) * 10
-                    data = data.astype(input_dtype)
-                    # TODO FIXME: This is a bad test
-                    f(data)
-                    # Check that we can take the gradient, when implemented
-                    if "complex" in mean_var.dtype:
-                        continue
-                    try:
-                        grad(mean_var.sum(), x, disconnected_inputs="ignore")
-                    except NotImplementedError:
-                        # TrueDiv does not seem to have a gradient when
-                        # the numerator is complex.
-                        if mean_var.dtype in complex_dtypes:
-                            pass
-                        else:
-                            raise
+        x = matrix(dtype=input_dtype)
+        # If the inner sum cannot be created, it will raise a TypeError.
+        mean_var = x.mean(dtype=sum_dtype, axis=axis)
+        if sum_dtype in discrete_dtypes:
+            assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
+        else:
+            assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
 
-                idx += 1
+        f = function([x], mean_var, mode="FAST_COMPILE")
+        data = np.ones((2, 1)).astype(input_dtype)
+        if axis != ():
+            expected_res = np.array(2).astype(sum_dtype) / 2
+        else:
+            expected_res = data
+        np.testing.assert_allclose(f(data), expected_res)
+
+        if "complex" not in mean_var.dtype:
+            grad(mean_var.sum(), x, disconnected_inputs="ignore")
 
     def test_mean_precision(self):
         # Check that the default accumulator precision is sufficient
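
For context, a minimal standalone sketch (not taken from the PR; the short dtype lists and the test name are illustrative assumptions) of how stacked pytest.mark.parametrize decorators expand into the Cartesian product of parameter values, which is what the rewritten test relies on in place of the removed nested for-loops:

# Standalone sketch, not part of the diff above.
import pytest


@pytest.mark.parametrize("input_dtype", ("int8", "float64"))
@pytest.mark.parametrize("sum_dtype", ("float32", "complex128"))
@pytest.mark.parametrize("axis", [None, ()])
def test_parametrize_cross_product_sketch(input_dtype, sum_dtype, axis):
    # pytest collects 2 * 2 * 2 = 8 independent cases from these decorators;
    # the real test's two 9-element dtype lists and 2 axes yield 9 * 9 * 2 = 162.
    assert isinstance(input_dtype, str) and isinstance(sum_dtype, str)

Each generated case runs and is reported separately, so a failure pinpoints the exact (input_dtype, sum_dtype, axis) combination instead of aborting one large loop-based test.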