Skip to content

Commit

Permalink
Implement dpnp.isneginf and dpnp.isposinf (#1888)
Browse files Browse the repository at this point in the history
* Implement dpnp.isneginf()

* Add tests for dpnp.isneginf()

* Implement dpnp.isposinf()

* Add tests for dpnp.isposinf()

* Add new functions to gen docs

* Add additional checks

* Add test_infinity_sign_errors

* Add sycl_queue/usm tests for logic functions

* Update tests

* Remove out dtype check

* Add TODO with support different out dtype

* Update test_logic_op_2in

---------

Co-authored-by: Anton <[email protected]>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Jun 26, 2024
1 parent 73ace12 commit 0b7c230
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 11 deletions.
2 changes: 2 additions & 0 deletions doc/reference/logic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Infinities and NaNs
dpnp.isfinite
dpnp.isinf
dpnp.isnan
dpnp.isneginf
dpnp.isposinf


Array type testing
Expand Down
146 changes: 146 additions & 0 deletions dpnp/dpnp_iface_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
"isfinite",
"isinf",
"isnan",
"isneginf",
"isposinf",
"less",
"less_equal",
"logical_and",
Expand Down Expand Up @@ -777,6 +779,150 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
)


def isneginf(x, out=None):
"""
Test element-wise for negative infinity, return result as bool array.
For full documentation refer to :obj:`numpy.isneginf`.
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
Input array.
out : {None, dpnp.ndarray, usm_ndarray}, optional
A location into which the result is stored. If provided, it must have a
shape that the input broadcasts to and a boolean data type.
If not provided or ``None``, a freshly-allocated boolean array
is returned.
Default: ``None``.
Returns
-------
out : dpnp.ndarray
Boolean array of same shape as ``x``.
See Also
--------
:obj:`dpnp.isinf` : Test element-wise for positive or negative infinity.
:obj:`dpnp.isposinf` : Test element-wise for positive infinity,
return result as bool array.
:obj:`dpnp.isnan` : Test element-wise for NaN and
return result as a boolean array.
:obj:`dpnp.isfinite` : Test element-wise for finiteness.
Examples
--------
>>> import dpnp as np
>>> x = np.array(np.inf)
>>> np.isneginf(-x)
array(True)
>>> np.isneginf(x)
array(False)
>>> x = np.array([-np.inf, 0., np.inf])
>>> np.isneginf(x)
array([ True, False, False])
>>> x = np.array([-np.inf, 0., np.inf])
>>> y = np.zeros(x.shape, dtype='bool')
>>> np.isneginf(x, y)
array([ True, False, False])
>>> y
array([ True, False, False])
"""

dpnp.check_supported_arrays_type(x)

if out is not None:
dpnp.check_supported_arrays_type(out)

x_dtype = x.dtype
if dpnp.issubdtype(x_dtype, dpnp.complexfloating):
raise TypeError(
f"This operation is not supported for {x_dtype} values "
"because it would be ambiguous."
)

is_inf = dpnp.isinf(x)
signbit = dpnp.signbit(x)

# TODO: support different out dtype #1717(dpctl)
return dpnp.logical_and(is_inf, signbit, out=out)


def isposinf(x, out=None):
"""
Test element-wise for positive infinity, return result as bool array.
For full documentation refer to :obj:`numpy.isposinf`.
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
Input array.
out : {None, dpnp.ndarray, usm_ndarray}, optional
A location into which the result is stored. If provided, it must have a
shape that the input broadcasts to and a boolean data type.
If not provided or ``None``, a freshly-allocated boolean array
is returned.
Default: ``None``.
Returns
-------
out : dpnp.ndarray
Boolean array of same shape as ``x``.
See Also
--------
:obj:`dpnp.isinf` : Test element-wise for positive or negative infinity.
:obj:`dpnp.isneginf` : Test element-wise for negative infinity,
return result as bool array.
:obj:`dpnp.isnan` : Test element-wise for NaN and
return result as a boolean array.
:obj:`dpnp.isfinite` : Test element-wise for finiteness.
Examples
--------
>>> import dpnp as np
>>> x = np.array(np.inf)
>>> np.isposinf(x)
array(True)
>>> np.isposinf(-x)
array(False)
>>> x = np.array([-np.inf, 0., np.inf])
>>> np.isposinf(x)
array([False, False, True])
>>> x = np.array([-np.inf, 0., np.inf])
>>> y = np.zeros(x.shape, dtype='bool')
>>> np.isposinf(x, y)
array([False, False, True])
>>> y
array([False, False, True])
"""

dpnp.check_supported_arrays_type(x)

if out is not None:
dpnp.check_supported_arrays_type(out)

x_dtype = x.dtype
if dpnp.issubdtype(x_dtype, dpnp.complexfloating):
raise TypeError(
f"This operation is not supported for {x_dtype} values "
"because it would be ambiguous."
)

is_inf = dpnp.isinf(x)
signbit = ~dpnp.signbit(x)

# TODO: support different out dtype #1717(dpctl)
return dpnp.logical_and(is_inf, signbit, out=out)


_LESS_DOCSTRING = """
Computes the less-than test results for each element `x1_i` of
the input array `x1` with the respective element `x2_i` of the input array `x2`.
Expand Down
47 changes: 47 additions & 0 deletions tests/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .helper import (
get_all_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
)


Expand Down Expand Up @@ -432,3 +433,49 @@ def test_finite(op, data, dtype):
dpnp_res = getattr(dpnp, op)(x, out=dp_out)
assert dp_out is dpnp_res
assert_equal(dpnp_res, np_res)


@pytest.mark.parametrize("func", ["isneginf", "isposinf"])
@pytest.mark.parametrize(
"data",
[
[dpnp.inf, -1, 0, 1, dpnp.nan, -dpnp.inf],
[[dpnp.inf, dpnp.nan], [dpnp.nan, 0], [1, -dpnp.inf]],
],
ids=[
"1D array",
"2D array",
],
)
@pytest.mark.parametrize("dtype", get_float_dtypes())
def test_infinity_sign(func, data, dtype):
x = dpnp.asarray(data, dtype=dtype)
np_res = getattr(numpy, func)(x.asnumpy())
dpnp_res = getattr(dpnp, func)(x)
assert_equal(dpnp_res, np_res)

dp_out = dpnp.empty(np_res.shape, dtype=dpnp.bool)
dpnp_res = getattr(dpnp, func)(x, out=dp_out)
assert dp_out is dpnp_res
assert_equal(dpnp_res, np_res)


@pytest.mark.parametrize("func", ["isneginf", "isposinf"])
def test_infinity_sign_errors(func):
data = [dpnp.inf, 0, -dpnp.inf]

# unsupported data type
x = dpnp.asarray(data, dtype="c8")
x_np = dpnp.asnumpy(x)
assert_raises(TypeError, getattr(dpnp, func), x)
assert_raises(TypeError, getattr(numpy, func), x_np)

# unsupported type
assert_raises(TypeError, getattr(dpnp, func), data)
assert_raises(TypeError, getattr(dpnp, func), x_np)

# unsupported `out` data type
x = dpnp.asarray(data, dtype=dpnp.default_float_type())
out = dpnp.empty_like(x, dtype="int32")
with pytest.raises(ValueError):
getattr(dpnp, func)(x, out=out)
83 changes: 83 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,40 @@ def test_1in_1out(func, data, device):
assert_sycl_queue_equal(result_queue, expected_queue)


@pytest.mark.parametrize(
"op",
[
"all",
"any",
"isfinite",
"isinf",
"isnan",
"isneginf",
"isposinf",
"logical_not",
],
)
@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_logic_op_1in(op, device):
x = dpnp.array(
[-dpnp.inf, -1.0, 0.0, 1.0, dpnp.inf, dpnp.nan], device=device
)
result = getattr(dpnp, op)(x)

x_orig = dpnp.asnumpy(x)
expected = getattr(numpy, op)(x_orig)
assert_dtype_allclose(result, expected)

expected_queue = x.get_array().sycl_queue
result_queue = result.get_array().sycl_queue

assert_sycl_queue_equal(result_queue, expected_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
Expand Down Expand Up @@ -705,6 +739,55 @@ def test_2in_1out(func, data1, data2, device):
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)


@pytest.mark.parametrize(
"op",
[
"equal",
"greater",
"greater_equal",
# TODO: unblock when dpnp.isclose() is updated
# "isclose",
"less",
"less_equal",
"logical_and",
"logical_or",
"logical_xor",
"not_equal",
],
)
@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_logic_op_2in(op, device):
x1 = dpnp.array(
[-dpnp.inf, -1.0, 0.0, 1.0, dpnp.inf, dpnp.nan], device=device
)
x2 = dpnp.array(
[dpnp.inf, 1.0, 0.0, -1.0, -dpnp.inf, dpnp.nan], device=device
)
# Remove NaN value from input arrays because numpy raises RuntimeWarning
if op in [
"greater",
"greater_equal",
"less",
"less_equal",
]:
x1 = x1[:-1]
x2 = x2[:-1]
result = getattr(dpnp, op)(x1, x2)

x1_orig = dpnp.asnumpy(x1)
x2_orig = dpnp.asnumpy(x2)
expected = getattr(numpy, op)(x1_orig, x2_orig)

assert_dtype_allclose(result, expected)

assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)


@pytest.mark.parametrize(
"func, data, scalar",
[
Expand Down
34 changes: 23 additions & 11 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,32 @@ def test_tril_triu(func, usm_type):
@pytest.mark.parametrize(
"op",
[
"equal",
"greater",
"greater_equal",
"less",
"less_equal",
"logical_and",
"logical_or",
"logical_xor",
"not_equal",
"all",
"any",
"isfinite",
"isinf",
"isnan",
"isneginf",
"isposinf",
"logical_not",
],
ids=[
)
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
def test_coerced_usm_types_logic_op_1in(op, usm_type_x):
x = dp.arange(-10, 10, usm_type=usm_type_x)
res = getattr(dp, op)(x)

assert x.usm_type == res.usm_type == usm_type_x


@pytest.mark.parametrize(
"op",
[
"equal",
"greater",
"greater_equal",
# TODO: unblock when dpnp.isclose() is updated
# "isclose",
"less",
"less_equal",
"logical_and",
Expand All @@ -381,7 +393,7 @@ def test_tril_triu(func, usm_type):
)
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
def test_coerced_usm_types_logic_op(op, usm_type_x, usm_type_y):
def test_coerced_usm_types_logic_op_2in(op, usm_type_x, usm_type_y):
x = dp.arange(100, usm_type=usm_type_x)
y = dp.arange(100, usm_type=usm_type_y)[::-1]

Expand Down
13 changes: 13 additions & 0 deletions tests/third_party/cupy/logic_tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,16 @@ def test_isinf(self):

def test_isnan(self):
self.check_unary_nan("isnan")


class TestUfuncLike(unittest.TestCase):
@testing.numpy_cupy_array_equal()
def check_unary(self, name, xp):
a = xp.array([-3, xp.inf, -1, -xp.inf, 0, 1, 2, xp.nan])
return getattr(xp, name)(a)

def test_isneginf(self):
self.check_unary("isneginf")

def test_isposinf(self):
self.check_unary("isposinf")

0 comments on commit 0b7c230

Please sign in to comment.