Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: Allow np.bool scalar in gtfn backend #1870

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -128,6 +128,7 @@ requires-python = '>=3.10, <3.12'
[project.optional-dependencies]
# bundles
all = ['gt4py[dace,formatting,jax,performance,testing]']
all-next = ['gt4py[dace-next,formatting,jax,performance,testing]']
# device-specific extras
cuda11 = ['cupy-cuda11x>=12.0']
cuda12 = ['cupy-cuda12x>=12.0']
@@ -443,9 +444,17 @@ conflicts = [
{extra = 'dace'},
{extra = 'dace-next'}
],
[
{extra = 'all'},
{extra = 'all-next'}
],
[
{extra = 'all'},
{extra = 'dace-next'}
],
[
{extra = 'all-next'},
{extra = 'dace'}
]
]

4 changes: 4 additions & 0 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import diskcache
import factory
import filelock
import numpy as np

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
@@ -34,6 +35,9 @@ def convert_arg(arg: Any) -> Any:
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
if isinstance(arg, np.bool_):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(arg, np.bool_):
if isinstance(arg, np.bool_):

Could you leave a comment why this is needed? Does numpy.float64 and python float work? If so, why does bool not work?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree about adding a comment here and would also suggest to use a ternary operator:

    return bool(arg) if isinstance(arg, np.bool_) else arg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a switch-like if therefore I wouldn't do the ternary, like in the other cases.

# nanobind does not support implicit conversion of `np.bool` to `bool`
return bool(arg)
else:
return arg

Original file line number Diff line number Diff line change
@@ -7,8 +7,10 @@
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

import numpy as np
import pytest

import gt4py.next as gtx
from gt4py.next import (
astype,
@@ -21,24 +23,24 @@
int64,
minimum,
neighbor_sum,
utils as gt_utils,
)
from gt4py.next.ffront.experimental import as_offset
from gt4py.next import utils as gt_utils

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
C2E,
E2V,
V2E,
E2VDim,
Edge,
IDim,
Ioff,
JDim,
KDim,
Koff,
V2EDim,
Vertex,
Edge,
cartesian_case,
unstructured_case,
unstructured_case_3d,
@@ -196,6 +198,21 @@ def testee(a: int32) -> cases.VField:
)


def test_np_bool_scalar_arg(unstructured_case):
"""Test scalar argument being turned into 0-dim field."""

@gtx.field_operator
def testee(a: gtx.bool) -> cases.VBoolField:
return broadcast(not a, (Vertex,))

a = np.bool_(True) # explicitly using a np.bool

ref = np.full([unstructured_case.default_sizes[Vertex]], not a, dtype=np.bool_)
out = cases.allocate(unstructured_case, testee, cases.RETURN)()

cases.verify(unstructured_case, testee, a, out=out, ref=ref)


def test_nested_scalar_arg(unstructured_case):
@gtx.field_operator
def testee_inner(a: int32) -> cases.VField:
85 changes: 50 additions & 35 deletions uv.lock

Large diffs are not rendered by default.