From d2ec2704cf6110417de1432f072c622ed98cda85 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 10:24:08 +0100 Subject: [PATCH 01/19] step 1 --- src/gt4py/cartesian/backend/__init__.py | 3 + src/gt4py/cartesian/backend/base.py | 4 - src/gt4py/cartesian/backend/debug_backend.py | 80 +++++++ src/gt4py/cartesian/gtc/common.py | 3 + src/gt4py/cartesian/gtc/debug/__init__.py | 13 ++ .../cartesian/gtc/debug/debug_codegen.py | 196 ++++++++++++++++++ src/gt4py/cartesian/utils/__init__.py | 2 + src/gt4py/cartesian/utils/field.py | 80 +++++++ .../multi_feature_tests/test_debug_backend.py | 74 +++++++ 9 files changed, 451 insertions(+), 4 deletions(-) create mode 100644 src/gt4py/cartesian/backend/debug_backend.py create mode 100644 src/gt4py/cartesian/gtc/debug/__init__.py create mode 100644 src/gt4py/cartesian/gtc/debug/debug_codegen.py create mode 100644 src/gt4py/cartesian/utils/field.py create mode 100644 tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index 4296e3b389..3179f116ac 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -19,6 +19,7 @@ register, ) from .cuda_backend import CudaBackend +from .debug_backend import DebugBackend from .gtcpp_backend import GTCpuIfirstBackend, GTCpuKfirstBackend, GTGpuBackend from .module_generator import BaseModuleGenerator from .numpy_backend import NumpyBackend @@ -32,9 +33,11 @@ "BasePyExtBackend", "CLIBackendMixin", "CudaBackend", + "DebugBackend", "GTCpuIfirstBackend", "GTCpuKfirstBackend", "GTGpuBackend", + "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5bab0453a9..9a638ef0ee 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -338,10 +338,6 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]: source = self.make_module_source(ir=self.builder.gtir) return {str(file_name): source} - def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: - """Pure python backends typically will not support bindings.""" - return super().generate_bindings(language_name) - class BasePyExtBackend(BaseBackend): @property diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py new file mode 100644 index 0000000000..7c430c9fc4 --- /dev/null +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -0,0 +1,80 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Type, Union + +from gt4py import storage +from gt4py.cartesian.backend.base import BaseBackend, CLIBackendMixin, register +from gt4py.cartesian.backend.numpy_backend import ModuleGenerator +from gt4py.cartesian.gtc.debug.debug_codegen import DebugCodeGen +from gt4py.cartesian.gtc.gtir_to_oir import GTIRToOIR +from gt4py.cartesian.gtc.passes.oir_pipeline import OirPipeline +from gt4py.eve.codegen import format_source + + +if TYPE_CHECKING: + from gt4py.cartesian.stencil_object import StencilObject + + +def recursive_write(root_path: Path, tree: dict[str, Union[str, dict]]): + root_path.mkdir(parents=True, exist_ok=True) + for key, value in tree.items(): + if isinstance(value, dict): + recursive_write(root_path / key, value) + else: + src_path = root_path / key + src_path.write_text(value) + + +@register +class DebugBackend(BaseBackend, CLIBackendMixin): + """Debug backend using plain python loops.""" + + name = "debug" + options: ClassVar[dict[str, Any]] = { + "oir_pipeline": {"versioning": True, "type": OirPipeline}, + # TODO: Implement this option in source code + "ignore_np_errstate": {"versioning": True, "type": bool}, + } + storage_info = storage.layout.NaiveCPULayout + languages = {"computation": "python", "bindings": ["python"]} + MODULE_GENERATOR_CLASS = ModuleGenerator + + def generate_computation(self) -> dict[str, Union[str, dict]]: + computation_name = ( + self.builder.caching.module_prefix + + "computation" + + self.builder.caching.module_postfix + + ".py" + ) + oir = GTIRToOIR().visit(self.builder.gtir) + source_code = DebugCodeGen().visit(oir) + + if self.builder.options.format_source: + source_code = format_source("python", source_code) + + return {computation_name: source_code} + + def generate_bindings(self, language_name: str) -> dict[str, Union[str, dict]]: + super().generate_bindings(language_name) + return {self.builder.module_path.name: self.make_module_source()} + + def generate(self) -> Type["StencilObject"]: + self.check_options(self.builder.options) + src_dir = self.builder.module_path.parent + if not self.builder.options._impl_opts.get("disable-code-generation", False): + src_dir.mkdir(parents=True, exist_ok=True) + recursive_write(src_dir, self.generate_computation()) + return self.make_module() diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 60236a3e97..9ce0df19bd 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -317,6 +317,9 @@ def zero(cls) -> CartesianOffset: def to_dict(self) -> Dict[str, int]: return {"i": self.i, "j": self.j, "k": self.k} + def to_str(self) -> str: + return f"i + {self.i}, j + {self.j}, k + {self.k}" + class VariableKOffset(eve.GenericNode, Generic[ExprT]): k: ExprT diff --git a/src/gt4py/cartesian/gtc/debug/__init__.py b/src/gt4py/cartesian/gtc/debug/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/cartesian/gtc/debug/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py new file mode 100644 index 0000000000..b72860525e --- /dev/null +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -0,0 +1,196 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py import eve +from gt4py.cartesian import utils +from gt4py.cartesian.gtc.common import ( + AxisBound, + DataType, + FieldAccess, + HorizontalInterval, + HorizontalMask, + LevelMarker, +) +from gt4py.cartesian.gtc.definitions import Extent +from gt4py.cartesian.gtc.oir import ( + AssignStmt, + BinaryOp, + Cast, + Decl, + FieldDecl, + HorizontalExecution, + HorizontalRestriction, + Interval, + Literal, + Stencil, + Temporary, +) +from gt4py.cartesian.gtc.passes.oir_optimizations.utils import StencilExtentComputer +from gt4py.eve import codegen + + +class DebugCodeGen(codegen.TemplatedGenerator, eve.VisitorWithSymbolTableTrait): + def __init__(self) -> None: + self.body = utils.text.TextBlock() + + def visit_VerticalLoop(self): + pass + + def generate_field_decls(self, declarations: list[Decl]) -> None: + for declaration in declarations: + if isinstance(declaration, FieldDecl): + self.body.append( + f"{declaration.name} = Field({declaration.name}, _origin_['{declaration.name}'], " + f"({', '.join([str(x) for x in declaration.dimensions])}))" + ) + + def visit_FieldAccess(self, field_access: FieldAccess, **_): + full_string = field_access.name + "[" + field_access.offset.to_str() + "]" + return full_string + + def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): + self.body.append( + self.visit(assignment_statement.left) + "=" + self.visit(assignment_statement.right) + ) + + def visit_BinaryOp(self, binary: BinaryOp, **_): + return self.visit(binary.left) + str(binary.op) + self.visit(binary.right) + + def visit_Literal(self, literal: Literal, **_): + return str(literal.value) + + def visit_Cast(self, cast: Cast, **_): + return self.visit(cast.expr) + + def visit_HorizontalExecution(self, horizontal_execution: HorizontalExecution, **_): + for stmt in horizontal_execution.body: + self.visit(stmt) + + def visit_HorizontalMask(self, horizontal_mask: HorizontalMask, **_): + i_min, i_max = self.visit(horizontal_mask.i, var="i") + j_min, j_max = self.visit(horizontal_mask.j, var="j") + conditions = [] + if i_min is not None: + conditions.append(f"({i_min}) <= i") + if i_max is not None: + conditions.append(f"i < ({i_max})") + if j_min is not None: + conditions.append(f"({j_min}) <= j") + if j_max is not None: + conditions.append(f"j < ({j_max})") + assert len(conditions) + if_code = f"if( {' and '.join(conditions)} ):" + self.body.append(if_code) + + def visit_HorizontalInterval(self, horizontal_interval: HorizontalInterval, **kwargs): + return self.visit( + horizontal_interval.start, **kwargs + ) if horizontal_interval.start else None, self.visit( + horizontal_interval.end, **kwargs + ) if horizontal_interval.end else None + + def visit_HorizontalRestriction(self, horizontal_restriction: HorizontalRestriction, **_): + self.visit(horizontal_restriction.mask) + self.body.indent() + self.visit(horizontal_restriction.body) + self.body.dedent() + + @staticmethod + def compute_extents(node: Stencil, **_) -> tuple[dict[str, Extent], dict[int, Extent]]: + ctx: StencilExtentComputer.Context = StencilExtentComputer().visit(node) + return ctx.fields, ctx.blocks + + def generate_temp_decls( + self, temporary_declarations: list[Temporary], field_extents: dict[str, Extent] + ): + for declaration in temporary_declarations: + self.body.append(self.visit(declaration, field_extents=field_extents)) + + def visit_Temporary(self, temporary_declaration: Temporary, **kwargs): + field_extents = kwargs["field_extents"] + local_field_extent = field_extents[temporary_declaration.name] + i_padding: int = local_field_extent[0][1] - local_field_extent[0][0] + j_padding: int = local_field_extent[1][1] - local_field_extent[1][0] + shape: list[str] = [f"i_size + {i_padding}", f"j_size + {j_padding}", "k_size"] + data_dimensions: list[str] = [str(dim) for dim in temporary_declaration.data_dims] + shape = shape + data_dimensions + shape_decl = ", ".join(shape) + dtype = self.visit(temporary_declaration.dtype) + field_offset = tuple(-ext[0] for ext in local_field_extent) + offset = [str(off) for off in field_offset] + ["0"] * ( + 1 + len(temporary_declaration.data_dims) + ) + return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}))" + + def visit_DataType(self, data_type: DataType, **_): + if data_type not in {DataType.BOOL}: + return f"np.{data_type.name.lower()}" + else: + return data_type.name.lower() + + def visit_Stencil(self, stencil: Stencil, **_): + field_extents, block_extents = self.compute_extents(stencil) + self.body.append("from gt4py.cartesian.utils import Field") + self.body.append("import numpy as np") + + function_signature = "def run(*" + args = [] + for param in stencil.params: + args.append(self.visit(param)) + function_signature = ",".join([function_signature, *args]) + function_signature += ",_domain_, _origin_):" + self.body.append(function_signature) + self.body.indent() + self.body.append("# ===== Domain Description ===== #") + self.body.append("i_0, j_0, k_0 = 0,0,0") + self.body.append("i_size, j_size, k_size = _domain_") + self.body.empty_line() + self.body.append("# ===== Temporary Declaration ===== #") + self.generate_temp_decls(stencil.declarations, field_extents) + self.body.empty_line() + self.body.append("# ===== Field Declaration ===== #") + self.generate_field_decls(stencil.params) + self.body.empty_line() + + for loop in stencil.vertical_loops: + for section in loop.sections: + loop_bounds = self.visit(section.interval, var="k") + loop_code = "for k in range(" + loop_bounds + "):" + self.body.append(loop_code) + self.body.indent() + for execution in section.horizontal_executions: + extents = block_extents[id(execution)] + i_loop = f"for i in range(i_0 + {extents[0][0]} , i_size + {extents[0][1]}):" + self.body.append(i_loop) + self.body.indent() + j_loop = f"for j in range(j_0 + {extents[1][0]} , j_size + {extents[1][1]}):" + self.body.append(j_loop) + self.body.indent() + self.visit(execution) + self.body.dedent() + self.body.dedent() + self.body.dedent() + return self.body.text + + def visit_FieldDecl(self, field_decl: FieldDecl, **_): + return str(field_decl.name) + + def visit_AxisBound(self, axis_bound: AxisBound, **kwargs): + if axis_bound.level == LevelMarker.START: + return f"{kwargs['var']}_0 + {axis_bound.offset}" + if axis_bound.level == LevelMarker.END: + return f"{kwargs['var']}_size + {axis_bound.offset}" + + def visit_Interval(self, interval: Interval, **kwargs): + return ",".join([self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)]) diff --git a/src/gt4py/cartesian/utils/__init__.py b/src/gt4py/cartesian/utils/__init__.py index 626d29b167..e8e00275e8 100644 --- a/src/gt4py/cartesian/utils/__init__.py +++ b/src/gt4py/cartesian/utils/__init__.py @@ -35,6 +35,7 @@ shashed_id, slugify, ) +from .field import Field __all__ = [ # noqa: RUF022 `__all__` is not sorted @@ -51,6 +52,7 @@ "classmethod_to_function", "classproperty", "compose", + "Field", "flatten", "flatten_iter", "get_member", diff --git a/src/gt4py/cartesian/utils/field.py b/src/gt4py/cartesian/utils/field.py new file mode 100644 index 0000000000..b2f28decc0 --- /dev/null +++ b/src/gt4py/cartesian/utils/field.py @@ -0,0 +1,80 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numbers +from typing import Tuple + +import numpy as np + + +class Field: + def __init__( + self, field: np.ndarray, offsets: Tuple[int, ...], dimensions: Tuple[bool, bool, bool] + ): + ii = iter(range(3)) + self.idx_to_data = tuple( + [next(ii) if has_dim else None for has_dim in dimensions] + + list(range(sum(dimensions), len(field.shape))) + ) + + shape = [field.shape[i] if i is not None else 1 for i in self.idx_to_data] + self.field_view = np.reshape(field.data, shape).view(np.ndarray) + + self.offsets = offsets + + @classmethod + def empty(cls, shape, dtype, offset): + return cls(np.empty(shape, dtype=dtype), offset, (True, True, True)) + + def shim_key(self, key): + new_args = [] + if not isinstance(key, tuple): + key = (key,) + for index in self.idx_to_data: + if index is None: + new_args.append(slice(None, None)) + else: + idx = key[index] + offset = self.offsets[index] + if isinstance(idx, slice): + new_args.append( + slice(idx.start + offset, idx.stop + offset, idx.step) if offset else idx + ) + else: + new_args.append(idx + offset) + if not isinstance(new_args[2], (numbers.Integral, slice)): + new_args = self.broadcast_and_clip_variable_k(new_args) + return tuple(new_args) + + def broadcast_and_clip_variable_k(self, new_args: list): + assert isinstance(new_args[0], slice) and isinstance(new_args[1], slice) + if np.max(new_args[2]) >= self.field_view.shape[2] or np.min(new_args[2]) < 0: + new_args[2] = np.clip(new_args[2].copy(), 0, self.field_view.shape[2] - 1) + new_args[:2] = np.broadcast_arrays( + np.expand_dims( + np.arange(new_args[0].start, new_args[0].stop), + axis=tuple(i for i in range(self.field_view.ndim) if i != 0), + ), + np.expand_dims( + np.arange(new_args[1].start, new_args[1].stop), + axis=tuple(i for i in range(self.field_view.ndim) if i != 1), + ), + ) + return new_args + + def __getitem__(self, key): + return self.field_view.__getitem__(self.shim_key(key)) + + def __setitem__(self, key, value): + return self.field_view.__setitem__(self.shim_key(key), value) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py new file mode 100644 index 0000000000..bab315e782 --- /dev/null +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py @@ -0,0 +1,74 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +from gt4py import storage as gt_storage +from gt4py.cartesian import gtscript +from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, computation, interval + + +def test_simple_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + with computation(BACKWARD): + with interval(-2, -1): # block 1 + field_out = field_in + with interval(0, -2): # block 2 + field_out = field_in + with computation(BACKWARD): + with interval(-1, None): # block 3 + field_out = 2 * field_in + with interval(0, -1): # block 4 + field_out = 3 * field_in + + stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0:-1], 3) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, -1], 2) + + +def test_tmp_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + with computation(PARALLEL): + with interval(...): + tmp = field_in + 1 + with computation(PARALLEL): + with interval(...): + field_out = tmp[-1, 0, 0] + tmp[1, 0, 0] + + stencil(field_in, field_out, origin=(1, 1, 0), domain=(4, 4, 6)) + + # the inside of the domain is 4 + np.testing.assert_allclose(field_out.view(np.ndarray)[1:-1, 1:-1, :], 4) + # the rest is 0 + np.testing.assert_allclose(field_out.view(np.ndarray)[0:1, :, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[-1:, :, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, 0:1, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, -1:, :], 0) From d693eb2e29e99b97907e3b693ed2462ea7aa2d42 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 10:29:23 +0100 Subject: [PATCH 02/19] step2 --- .gitignore | 2 +- src/gt4py/cartesian/config.py | 1 + .../cartesian/frontend/gtscript_frontend.py | 20 +- .../cartesian/gtc/debug/debug_codegen.py | 280 ++++++++++++------ .../multi_feature_tests/test_debug_backend.py | 189 ++++++++++++ 5 files changed, 397 insertions(+), 95 deletions(-) diff --git a/.gitignore b/.gitignore index ebbbfaebeb..34992bf9f7 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ _local /src/__init__.py /tests/__init__.py .gt_cache/ -.gt4py_cache/ +.gt_cache*/ .gt_cache_pytest*/ # DaCe diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index 5aa32506b7..4719fed945 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -65,6 +65,7 @@ "extra_link_args": extra_link_args, "parallel_jobs": multiprocessing.cpu_count(), "cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH), + "literal_floating_point_precision": os.environ.get("GT4PY_LITERAL_PRECISION", None), } if GT4PY_USE_HIP: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib") diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 4d8ac98529..f56e23564c 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -21,6 +21,7 @@ import numpy as np from gt4py.cartesian import definitions as gt_definitions, gtscript, utils as gt_utils +from gt4py.cartesian.config import build_settings from gt4py.cartesian.frontend import node_util, nodes from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR, UnrollVectorAssignments from gt4py.cartesian.gtc import utils as gtc_utils @@ -1013,11 +1014,20 @@ def visit_Constant( loc=nodes.Location.from_ast_node(node), ) elif isinstance(value, numbers.Number): - value_type = ( - self.dtypes[type(value)] - if self.dtypes and type(value) in self.dtypes.keys() - else np.dtype(type(value)) - ) + if self.dtypes and type(value) in self.dtypes.keys(): + value_type = self.dtypes[type(value)] + else: + if build_settings["literal_floating_point_precision"] is not None: + if isinstance(value, int): + value_type = np.dtype( + f"i{int(int(build_settings['literal_floating_point_precision'])/8)}" + ) + else: + value_type = np.dtype( + f"f{int(int(build_settings['literal_floating_point_precision'])/8)}" + ) + else: + value_type = np.dtype(type(value)) data_type = nodes.DataType.from_dtype(value_type) return nodes.ScalarLiteral(value=value, data_type=data_type) else: diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index b72860525e..a4a709ddc2 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -21,6 +21,8 @@ HorizontalInterval, HorizontalMask, LevelMarker, + LoopOrder, + While, ) from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import ( @@ -33,8 +35,17 @@ HorizontalRestriction, Interval, Literal, + LocalScalar, + MaskStmt, + NativeFuncCall, + ScalarAccess, + ScalarDecl, Stencil, Temporary, + TernaryOp, + UnaryOp, + VerticalLoop, + VerticalLoopSection, ) from gt4py.cartesian.gtc.passes.oir_optimizations.utils import StencilExtentComputer from gt4py.eve import codegen @@ -44,8 +55,44 @@ class DebugCodeGen(codegen.TemplatedGenerator, eve.VisitorWithSymbolTableTrait): def __init__(self) -> None: self.body = utils.text.TextBlock() - def visit_VerticalLoop(self): - pass + def visit_Stencil(self, stencil: Stencil, **_): + self.generate_imports() + + self.generate_run_function(stencil) + + field_extents, block_extents = self.compute_extents(stencil) + self.initial_declarations(stencil, field_extents) + self.generate_stencil_code(stencil, block_extents) + + return self.body.text + + def generate_imports(self): + self.body.append("import numpy as np") + self.body.append("from gt4py.cartesian.gtc import ufuncs") + self.body.append("from gt4py.cartesian.utils import Field") + + @staticmethod + def compute_extents(node: Stencil, **_) -> tuple[dict[str, Extent], dict[int, Extent]]: + ctx: StencilExtentComputer.Context = StencilExtentComputer().visit(node) + return ctx.fields, ctx.blocks + + def initial_declarations(self, stencil: Stencil, field_extents: dict[str, Extent]): + self.body.append("# ===== Domain Description ===== #") + self.body.append("i_0, j_0, k_0 = 0,0,0") + self.body.append("i_size, j_size, k_size = _domain_") + self.body.empty_line() + self.body.append("# ===== Temporary Declaration ===== #") + self.generate_temp_decls(stencil.declarations, field_extents) + self.body.empty_line() + self.body.append("# ===== Field Declaration ===== #") + self.generate_field_decls(stencil.params) + self.body.empty_line() + + def generate_temp_decls( + self, temporary_declarations: list[Temporary], field_extents: dict[str, Extent] + ) -> None: + for declaration in temporary_declarations: + self.body.append(self.visit(declaration, field_extents=field_extents)) def generate_field_decls(self, declarations: list[Decl]) -> None: for declaration in declarations: @@ -55,8 +102,117 @@ def generate_field_decls(self, declarations: list[Decl]) -> None: f"({', '.join([str(x) for x in declaration.dimensions])}))" ) - def visit_FieldAccess(self, field_access: FieldAccess, **_): - full_string = field_access.name + "[" + field_access.offset.to_str() + "]" + def generate_run_function(self, stencil: Stencil): + function_signature = "def run(*" + args = [] + for param in stencil.params: + args.append(self.visit(param)) + function_signature = ",".join([function_signature, *args]) + function_signature += ",_domain_, _origin_):" + self.body.append(function_signature) + self.body.indent() + + def generate_stencil_code(self, stencil: Stencil, block_extents: dict[int, Extent]): + for loop in stencil.vertical_loops: + for section in loop.sections: + with self.create_k_loop_code(section, loop): + for execution in section.horizontal_executions: + with self.generate_ij_loop(block_extents, execution): + self.visit(execution) + + @contextmanager + def create_k_loop_code(self, section: VerticalLoopSection, loop: VerticalLoop) -> Generator: + loop_bounds: str = self.visit(section.interval, var="k", direction=loop.loop_order) + iterator = "1" if loop.loop_order != LoopOrder.BACKWARD else "-1" + loop_code = "for k in range(" + loop_bounds + "," + iterator + "):" + self.body.append(loop_code) + self.body.indent() + yield + self.body.dedent() + + @contextmanager + def generate_ij_loop( + self, block_extents: dict[int, Extent], execution: HorizontalExecution + ) -> Generator: + extents = block_extents[id(execution)] + i_loop = f"for i in range(i_0 + {extents[0][0]} , i_size + {extents[0][1]}):" + self.body.append(i_loop) + self.body.indent() + j_loop = f"for j in range(j_0 + {extents[1][0]} , j_size + {extents[1][1]}):" + self.body.append(j_loop) + self.body.indent() + yield + self.body.dedent() + self.body.dedent() + + def visit_While(self, while_node: While, **_) -> None: + while_condition = self.visit(while_node.cond) + while_code = f"while {while_condition}:" + self.body.append(while_code) + self.body.indent() + for statement in while_node.body: + self.visit(statement) + self.body.dedent() + + def visit_FieldDecl(self, field_decl: FieldDecl, **_) -> str: + return str(field_decl.name) + + def visit_AxisBound(self, axis_bound: AxisBound, **kwargs): + if axis_bound.level == LevelMarker.START: + return f"{kwargs['var']}_0 + {axis_bound.offset}" + if axis_bound.level == LevelMarker.END: + return f"{kwargs['var']}_size + {axis_bound.offset}" + + def visit_Interval(self, interval: Interval, **kwargs): + if kwargs["direction"] == LoopOrder.BACKWARD: + return ",".join( + [ + self.visit(interval.end, **kwargs) + "- 1", + self.visit(interval.start, **kwargs) + "- 1", + ] + ) + else: + return ",".join( + [self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)] + ) + + def visit_Temporary(self, temporary_declaration: Temporary, **kwargs) -> str: + field_extents = kwargs["field_extents"] + local_field_extent = field_extents[temporary_declaration.name] + i_padding: int = local_field_extent[0][1] - local_field_extent[0][0] + j_padding: int = local_field_extent[1][1] - local_field_extent[1][0] + shape: list[str] = [f"i_size + {i_padding}", f"j_size + {j_padding}", "k_size"] + data_dimensions: list[str] = [str(dim) for dim in temporary_declaration.data_dims] + shape = shape + data_dimensions + shape_decl = ", ".join(shape) + dtype: str = self.visit(temporary_declaration.dtype) + field_offset = tuple(-ext[0] for ext in local_field_extent) + offset = [str(off) for off in field_offset] + ["0"] * ( + 1 + len(temporary_declaration.data_dims) + ) + return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}))" + + def visit_DataType(self, data_type: DataType, **_) -> str: + if data_type not in {DataType.BOOL}: + return f"np.{data_type.name.lower()}" + else: + return data_type.name.lower() + + def visit_FieldAccess(self, field_access: FieldAccess, **_) -> str: + if field_access.data_index: + data_index_access = ",".join( + [self.visit(data_index) for data_index in field_access.data_index] + ) + full_string = ( + field_access.name + + "[" + + field_access.offset.to_str() + + "," + + data_index_access + + "]" + ) + else: + full_string = field_access.name + "[" + field_access.offset.to_str() + "]" return full_string def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): @@ -67,11 +223,15 @@ def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): def visit_BinaryOp(self, binary: BinaryOp, **_): return self.visit(binary.left) + str(binary.op) + self.visit(binary.right) - def visit_Literal(self, literal: Literal, **_): - return str(literal.value) + def visit_Literal(self, literal: Literal, **_) -> str: + if literal.dtype.bit_count() != 4: + literal_code = f"{self.visit(literal.dtype)}({literal.value})" + else: + literal_code = str(literal.value) + return literal_code - def visit_Cast(self, cast: Cast, **_): - return self.visit(cast.expr) + def visit_Cast(self, cast: Cast, **_) -> str: + return f"{self.visit(cast.dtype)}({self.visit(cast.expr)})" def visit_HorizontalExecution(self, horizontal_execution: HorizontalExecution, **_): for stmt in horizontal_execution.body: @@ -106,91 +266,33 @@ def visit_HorizontalRestriction(self, horizontal_restriction: HorizontalRestrict self.visit(horizontal_restriction.body) self.body.dedent() - @staticmethod - def compute_extents(node: Stencil, **_) -> tuple[dict[str, Extent], dict[int, Extent]]: - ctx: StencilExtentComputer.Context = StencilExtentComputer().visit(node) - return ctx.fields, ctx.blocks - - def generate_temp_decls( - self, temporary_declarations: list[Temporary], field_extents: dict[str, Extent] - ): - for declaration in temporary_declarations: - self.body.append(self.visit(declaration, field_extents=field_extents)) - - def visit_Temporary(self, temporary_declaration: Temporary, **kwargs): - field_extents = kwargs["field_extents"] - local_field_extent = field_extents[temporary_declaration.name] - i_padding: int = local_field_extent[0][1] - local_field_extent[0][0] - j_padding: int = local_field_extent[1][1] - local_field_extent[1][0] - shape: list[str] = [f"i_size + {i_padding}", f"j_size + {j_padding}", "k_size"] - data_dimensions: list[str] = [str(dim) for dim in temporary_declaration.data_dims] - shape = shape + data_dimensions - shape_decl = ", ".join(shape) - dtype = self.visit(temporary_declaration.dtype) - field_offset = tuple(-ext[0] for ext in local_field_extent) - offset = [str(off) for off in field_offset] + ["0"] * ( - 1 + len(temporary_declaration.data_dims) - ) - return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}))" + def visit_VerticalLoop(self): + pass - def visit_DataType(self, data_type: DataType, **_): - if data_type not in {DataType.BOOL}: - return f"np.{data_type.name.lower()}" - else: - return data_type.name.lower() + def visit_ScalarAccess(self, scalar_access: ScalarAccess, **_): + return scalar_access.name - def visit_Stencil(self, stencil: Stencil, **_): - field_extents, block_extents = self.compute_extents(stencil) - self.body.append("from gt4py.cartesian.utils import Field") - self.body.append("import numpy as np") + def visit_ScalarDecl(self, scalar_declaration: ScalarDecl, **_) -> str: + return scalar_declaration.name - function_signature = "def run(*" - args = [] - for param in stencil.params: - args.append(self.visit(param)) - function_signature = ",".join([function_signature, *args]) - function_signature += ",_domain_, _origin_):" - self.body.append(function_signature) - self.body.indent() - self.body.append("# ===== Domain Description ===== #") - self.body.append("i_0, j_0, k_0 = 0,0,0") - self.body.append("i_size, j_size, k_size = _domain_") - self.body.empty_line() - self.body.append("# ===== Temporary Declaration ===== #") - self.generate_temp_decls(stencil.declarations, field_extents) - self.body.empty_line() - self.body.append("# ===== Field Declaration ===== #") - self.generate_field_decls(stencil.params) - self.body.empty_line() + def visit_NativeFuncCall(self, native_function_call: NativeFuncCall, **_) -> str: + arglist = [self.visit(arg) for arg in native_function_call.args] + arguments = ",".join(arglist) + return f"ufuncs.{native_function_call.func.value}({arguments})" - for loop in stencil.vertical_loops: - for section in loop.sections: - loop_bounds = self.visit(section.interval, var="k") - loop_code = "for k in range(" + loop_bounds + "):" - self.body.append(loop_code) - self.body.indent() - for execution in section.horizontal_executions: - extents = block_extents[id(execution)] - i_loop = f"for i in range(i_0 + {extents[0][0]} , i_size + {extents[0][1]}):" - self.body.append(i_loop) - self.body.indent() - j_loop = f"for j in range(j_0 + {extents[1][0]} , j_size + {extents[1][1]}):" - self.body.append(j_loop) - self.body.indent() - self.visit(execution) - self.body.dedent() - self.body.dedent() - self.body.dedent() - return self.body.text + def visit_UnaryOp(self, unary_operator: UnaryOp, **_) -> str: + return unary_operator.op.value + " " + self.visit(unary_operator.expr) - def visit_FieldDecl(self, field_decl: FieldDecl, **_): - return str(field_decl.name) + def visit_TernaryOp(self, ternary_operator: TernaryOp, **_) -> None: + return f"{self.visit(ternary_operator.true_expr)} if {self.visit(ternary_operator.cond)} else {self.visit(ternary_operator.false_expr)}" - def visit_AxisBound(self, axis_bound: AxisBound, **kwargs): - if axis_bound.level == LevelMarker.START: - return f"{kwargs['var']}_0 + {axis_bound.offset}" - if axis_bound.level == LevelMarker.END: - return f"{kwargs['var']}_size + {axis_bound.offset}" + def visit_LocalScalar(self, local_scalar: LocalScalar, **__) -> None: + raise NotImplementedError( + "This state should not be reached because LocalTemporariesToScalars should not have been called." + ) - def visit_Interval(self, interval: Interval, **kwargs): - return ",".join([self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)]) + def visit_MaskStmt(self, mask_statement: MaskStmt, **_): + self.body.append(f"if {self.visit(mask_statement.mask)}:") + with self.body.indented(): + for statement in mask_statement.body: + self.visit(statement) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py index bab315e782..097ad3b2da 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py @@ -72,3 +72,192 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f np.testing.assert_allclose(field_out.view(np.ndarray)[-1:, :, :], 0) np.testing.assert_allclose(field_out.view(np.ndarray)[:, 0:1, :], 0) np.testing.assert_allclose(field_out.view(np.ndarray)[:, -1:, :], 0) + + +def test_backward_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + with computation(BACKWARD): + with interval(-1, None): + field_in = 2 + field_out = field_in + with interval(0, -1): + field_in = field_in[0, 0, 1] + 1 + field_out = field_in + + stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0], 5) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 4) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 2], 3) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 3], 2) + + +def test_while_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + with computation(PARALLEL): + with interval(...): + while field_in < 10: + field_in += 1 + field_out = field_in + + stencil(field_in, field_out) + + # the inside of the domain is 10 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 10) + + +def test_higher_dim_literal_stencil(): + FLOAT64_NDDIM = (np.float64, (4,)) + + field_in = gt_storage.ones( + dtype=FLOAT64_NDDIM, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_in[:, :, :, 2] = 5 + + @gtscript.stencil(backend="debug") + def stencil( + vec_field: gtscript.Field[FLOAT64_NDDIM], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = vec_field[0, 0, 0][2] + + stencil(field_in, field_out) + + # the inside of the domain is 5 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + + +def test_higher_dim_scalar_stencil(): + FLOAT64_NDDIM = (np.float64, (4,)) + + field_in = gt_storage.ones( + dtype=FLOAT64_NDDIM, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_in[:, :, :, 2] = 5 + + @gtscript.stencil(backend="debug") + def stencil( + vec_field: gtscript.Field[FLOAT64_NDDIM], + out_field: gtscript.Field[np.float64], + scalar_argument: int, + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = vec_field[0, 0, 0][scalar_argument] + + stencil(field_in, field_out, 2) + + # the inside of the domain is 5 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + + +def test_native_function_call_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field[0, 0, 0] + sin(0.848062) + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 1.75) + + +def test_unary_operator_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = -in_field[0, 0, 0] + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], -1) + + +def test_ternary_operator_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[0, 0, 1] = 20 + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field[0, 0, 0] if in_field > 10 else in_field[0, 0, 0] + 1 + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[0, 0, 1], 20) + np.testing.assert_allclose(field_out.view(np.ndarray)[1:, 1:, 1], 2) + + +def test_mask_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[0, 0, 1] = -20 + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + if in_field[0, 0, 0] > 0: + out_field[0, 0, 0] = in_field + else: + out_field[0, 0, 0] = 1 + + test_stencil(field_in, field_out) + + assert np.all(field_out.view(np.ndarray) > 0) From 1913e380fa16b881d64bc4ded10152f3ddbf11c9 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 10:37:19 +0100 Subject: [PATCH 03/19] step3 --- src/gt4py/cartesian/backend/debug_backend.py | 6 + .../cartesian/frontend/gtscript_frontend.py | 40 ++++++ src/gt4py/cartesian/gtc/common.py | 13 +- .../cartesian/gtc/debug/debug_codegen.py | 110 ++++++++++++---- src/gt4py/cartesian/gtc/ufuncs.py | 3 + .../multi_feature_tests/test_debug_backend.py | 120 +++++++++++++++++- 6 files changed, 263 insertions(+), 29 deletions(-) diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index 7c430c9fc4..b0b78296ac 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -20,6 +20,10 @@ from gt4py.cartesian.backend.numpy_backend import ModuleGenerator from gt4py.cartesian.gtc.debug.debug_codegen import DebugCodeGen from gt4py.cartesian.gtc.gtir_to_oir import GTIRToOIR +from gt4py.cartesian.gtc.passes.oir_optimizations.horizontal_execution_merging import ( + HorizontalExecutionMerging, +) +from gt4py.cartesian.gtc.passes.oir_optimizations.temporaries import LocalTemporariesToScalars from gt4py.cartesian.gtc.passes.oir_pipeline import OirPipeline from gt4py.eve.codegen import format_source @@ -60,6 +64,8 @@ def generate_computation(self) -> dict[str, Union[str, dict]]: + ".py" ) oir = GTIRToOIR().visit(self.builder.gtir) + oir = HorizontalExecutionMerging().visit(oir) + oir = LocalTemporariesToScalars().visit(oir) source_code = DebugCodeGen().visit(oir) if self.builder.options.format_source: diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f56e23564c..ed6775b802 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1388,6 +1388,46 @@ def visit_While(self, node: ast.While) -> list: return result + def _absolute_K_index_method(self, node: ast.Call): + # Dev note: we enforce .at(K=..., ddim=[...]) for the POC + # A better version of this code would look through the keywords + # in any order. `ddim` shall remain optional, K mandatory. + assert _filter_absolute_K_index_method(node) + if len(node.keywords) not in [1, 2]: + raise GTScriptSyntaxError( + message="Absolute K index bad syntax. Must be of the form`.at(K=..., ddim=[...])` " + " with the `ddim` argument optional", + loc=nodes.Location.from_ast_node(node), + ) + if node.keywords[0].arg != "K": + raise GTScriptSyntaxError( + message="Absolute K index: bad syntax, first argument must be `K`. " + "Must be of the form`.at(K=...)`", + loc=nodes.Location.from_ast_node(node), + ) + if len(node.keywords) > 1 and node.keywords[1].arg != "ddim": + raise GTScriptSyntaxError( + message="Absolute K index: bad syntax, second argument (optional) must be `ddim`. " + "Must be of the form`.at(K=..., ddim=[...])`", + loc=nodes.Location.from_ast_node(node), + ) + if ( + len(node.keywords) > 1 + and node.keywords[1].arg == "ddim" + and not isinstance(node.keywords[1].value, ast.List) + ): + raise GTScriptSyntaxError( + message="Absolute K index: bad syntax, second argument `ddim` (optional) must be " + "a list of values. Must be of the form`.at(K=..., ddim=[...])`", + loc=nodes.Location.from_ast_node(node), + ) + field: nodes.FieldRef = self.visit(node.func.value) + assert isinstance(field, nodes.FieldRef) + field.offset = nodes.AbsoluteKIndex(k=self.visit(node.keywords[0].value)) + if len(node.keywords) == 2: + field.data_index = [self.visit(value) for value in node.keywords[1].value.elts] + return field + def visit_Call(self, node: ast.Call): native_fcn = nodes.NativeFunction.PYTHON_SYMBOL_TO_IR_OP[node.func.id] diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 9ce0df19bd..2e9afe0bc3 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -317,8 +317,17 @@ def zero(cls) -> CartesianOffset: def to_dict(self) -> Dict[str, int]: return {"i": self.i, "j": self.j, "k": self.k} - def to_str(self) -> str: - return f"i + {self.i}, j + {self.j}, k + {self.k}" + def to_str(self, dimensions: tuple[bool, bool, bool]) -> str: + dimension_strings = [] + + if dimensions[0]: + dimension_strings.append(f"i + {self.i}") + if dimensions[1]: + dimension_strings.append(f"j + {self.j}") + if dimensions[2]: + dimension_strings.append(f"k + {self.k}") + + return ",".join(dimension_strings) class VariableKOffset(eve.GenericNode, Generic[ExprT]): diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index a4a709ddc2..f5fcb01127 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -16,6 +16,8 @@ from gt4py.cartesian import utils from gt4py.cartesian.gtc.common import ( AxisBound, + BuiltInLiteral, + CartesianOffset, DataType, FieldAccess, HorizontalInterval, @@ -28,12 +30,15 @@ from gt4py.cartesian.gtc.oir import ( AssignStmt, BinaryOp, + CacheDesc, Cast, Decl, FieldDecl, HorizontalExecution, HorizontalRestriction, + IJCache, Interval, + KCache, Literal, LocalScalar, MaskStmt, @@ -44,16 +49,20 @@ Temporary, TernaryOp, UnaryOp, + UnboundedInterval, + VariableKOffset, VerticalLoop, VerticalLoopSection, ) from gt4py.cartesian.gtc.passes.oir_optimizations.utils import StencilExtentComputer from gt4py.eve import codegen +from gt4py.eve.concepts import SymbolRef class DebugCodeGen(codegen.TemplatedGenerator, eve.VisitorWithSymbolTableTrait): def __init__(self) -> None: self.body = utils.text.TextBlock() + self.symbol_table: dict[str, FieldDecl] = {} def visit_Stencil(self, stencil: Stencil, **_): self.generate_imports() @@ -61,7 +70,7 @@ def visit_Stencil(self, stencil: Stencil, **_): self.generate_run_function(stencil) field_extents, block_extents = self.compute_extents(stencil) - self.initial_declarations(stencil, field_extents) + self.symbol_table = self.initial_declarations(stencil, field_extents) self.generate_stencil_code(stencil, block_extents) return self.body.text @@ -76,31 +85,40 @@ def compute_extents(node: Stencil, **_) -> tuple[dict[str, Extent], dict[int, Ex ctx: StencilExtentComputer.Context = StencilExtentComputer().visit(node) return ctx.fields, ctx.blocks - def initial_declarations(self, stencil: Stencil, field_extents: dict[str, Extent]): + def initial_declarations( + self, stencil: Stencil, field_extents: dict[str, Extent] + ) -> dict[str, FieldDecl]: self.body.append("# ===== Domain Description ===== #") self.body.append("i_0, j_0, k_0 = 0,0,0") self.body.append("i_size, j_size, k_size = _domain_") self.body.empty_line() self.body.append("# ===== Temporary Declaration ===== #") - self.generate_temp_decls(stencil.declarations, field_extents) + symbol_table = self.generate_temp_decls(stencil.declarations, field_extents) self.body.empty_line() self.body.append("# ===== Field Declaration ===== #") - self.generate_field_decls(stencil.params) + symbol_table |= self.generate_field_decls(stencil.params) self.body.empty_line() + return symbol_table def generate_temp_decls( self, temporary_declarations: list[Temporary], field_extents: dict[str, Extent] - ) -> None: + ) -> dict[str, FieldDecl]: + symbol_table: dict[str, FieldDecl] = {} for declaration in temporary_declarations: self.body.append(self.visit(declaration, field_extents=field_extents)) + symbol_table[str(declaration.name)] = declaration + return symbol_table - def generate_field_decls(self, declarations: list[Decl]) -> None: + def generate_field_decls(self, declarations: list[Decl]) -> dict[str, FieldDecl]: + symbol_table = {} for declaration in declarations: if isinstance(declaration, FieldDecl): self.body.append( f"{declaration.name} = Field({declaration.name}, _origin_['{declaration.name}'], " f"({', '.join([str(x) for x in declaration.dimensions])}))" ) + symbol_table[str(declaration.name)] = declaration + return symbol_table def generate_run_function(self, stencil: Stencil): function_signature = "def run(*" @@ -112,7 +130,11 @@ def generate_run_function(self, stencil: Stencil): self.body.append(function_signature) self.body.indent() - def generate_stencil_code(self, stencil: Stencil, block_extents: dict[int, Extent]): + def generate_stencil_code( + self, + stencil: Stencil, + block_extents: dict[int, Extent], + ): for loop in stencil.vertical_loops: for section in loop.sections: with self.create_k_loop_code(section, loop): @@ -198,21 +220,34 @@ def visit_DataType(self, data_type: DataType, **_) -> str: else: return data_type.name.lower() + def visit_VariableKOffset(self, variable_k_offset: VariableKOffset, **_) -> str: + return f"i,j,k+int({self.visit(variable_k_offset.k)})" + + def visit_CartesianOffset(self, cartesian_offset: CartesianOffset, **kwargs) -> str: + dimensions = kwargs["dimensions"] + return cartesian_offset.to_str(dimensions) + + def visit_SymbolRef(self, symbol_ref: SymbolRef) -> str: + return symbol_ref + def visit_FieldAccess(self, field_access: FieldAccess, **_) -> str: + if str(field_access.name) in self.symbol_table: + dimensions = self.symbol_table[str(field_access.name)].dimensions + else: + dimensions = (True, True, True) + + offset_str = self.visit(field_access.offset, dimensions=dimensions) + if field_access.data_index: data_index_access = ",".join( [self.visit(data_index) for data_index in field_access.data_index] ) - full_string = ( - field_access.name - + "[" - + field_access.offset.to_str() - + "," - + data_index_access - + "]" - ) + if offset_str == "": + full_string = field_access.name + f"[{data_index_access}]" + else: + full_string = field_access.name + "[" + offset_str + "," + data_index_access + "]" else: - full_string = field_access.name + "[" + field_access.offset.to_str() + "]" + full_string = field_access.name + "[" + offset_str + "]" return full_string def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): @@ -220,14 +255,22 @@ def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): self.visit(assignment_statement.left) + "=" + self.visit(assignment_statement.right) ) - def visit_BinaryOp(self, binary: BinaryOp, **_): - return self.visit(binary.left) + str(binary.op) + self.visit(binary.right) + def visit_BinaryOp(self, binary: BinaryOp, **_) -> str: + return f"( {self.visit(binary.left)} {binary.op} {self.visit(binary.right)} )" def visit_Literal(self, literal: Literal, **_) -> str: + if literal.dtype == DataType.BOOL: + if literal.value == BuiltInLiteral.TRUE: + literal_value = "True" + else: + literal_value = "False" + else: + literal_value = str(literal.value) + if literal.dtype.bit_count() != 4: - literal_code = f"{self.visit(literal.dtype)}({literal.value})" + literal_code = f"{self.visit(literal.dtype)}({literal_value})" else: - literal_code = str(literal.value) + literal_code = literal_value return literal_code def visit_Cast(self, cast: Cast, **_) -> str: @@ -283,16 +326,31 @@ def visit_NativeFuncCall(self, native_function_call: NativeFuncCall, **_) -> str def visit_UnaryOp(self, unary_operator: UnaryOp, **_) -> str: return unary_operator.op.value + " " + self.visit(unary_operator.expr) - def visit_TernaryOp(self, ternary_operator: TernaryOp, **_) -> None: + def visit_TernaryOp(self, ternary_operator: TernaryOp, **_) -> str: return f"{self.visit(ternary_operator.true_expr)} if {self.visit(ternary_operator.cond)} else {self.visit(ternary_operator.false_expr)}" - def visit_LocalScalar(self, local_scalar: LocalScalar, **__) -> None: - raise NotImplementedError( - "This state should not be reached because LocalTemporariesToScalars should not have been called." - ) - def visit_MaskStmt(self, mask_statement: MaskStmt, **_): self.body.append(f"if {self.visit(mask_statement.mask)}:") with self.body.indented(): for statement in mask_statement.body: self.visit(statement) + + def visit_LocalScalar(self, local_scalar: LocalScalar, **__) -> None: + raise NotImplementedError( + "This state should not be reached because LocalTemporariesToScalars should not have been called." + ) + + def visit_CacheDesc(self, cache_descriptor: CacheDesc, **_): + raise NotImplementedError("Caches should never be visited in the debug backend") + + def visit_IJCache(self, ij_cache: IJCache, **_): + raise NotImplementedError("Caches should never be visited in the debug backend") + + def visit_KCache(self, k_cache: KCache, **_): + raise NotImplementedError("Caches should never be visited in the debug backend") + + def visit_VerticalLoopSection(self, vertical_loop_section: VerticalLoopSection, **_): + raise NotImplementedError("Vertical Loop section is not in the right place.") + + def visit_UnboundedInterval(self, unbounded_interval: UnboundedInterval, **_) -> None: + raise NotImplementedError("Unbounded Intervals are not supported in the debug backend.") diff --git a/src/gt4py/cartesian/gtc/ufuncs.py b/src/gt4py/cartesian/gtc/ufuncs.py index 88c7534602..c0b09f60f3 100644 --- a/src/gt4py/cartesian/gtc/ufuncs.py +++ b/src/gt4py/cartesian/gtc/ufuncs.py @@ -37,6 +37,8 @@ abs: np.ufunc = np.abs # noqa: A001 [builtin-variable-shadowing] minimum: np.ufunc = np.minimum maximum: np.ufunc = np.maximum +max: np.ufunc = np.maximum +min: np.ufunc = np.minimum remainder: np.ufunc = np.remainder sin: np.ufunc = np.sin cos: np.ufunc = np.cos @@ -52,6 +54,7 @@ arctanh: np.ufunc = np.arctanh sqrt: np.ufunc = np.sqrt power: np.ufunc = np.power +pow: np.ufunc = np.power exp: np.ufunc = np.exp log: np.ufunc = np.log log10: np.ufunc = np.log10 diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py index 097ad3b2da..2c20e7361e 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py @@ -16,7 +16,7 @@ from gt4py import storage as gt_storage from gt4py.cartesian import gtscript -from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, computation, interval +from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, THIS_K, computation, interval, sin def test_simple_stencil(): @@ -261,3 +261,121 @@ def test_stencil( test_stencil(field_in, field_out) assert np.all(field_out.view(np.ndarray) > 0) + + +def test_k_offset_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:, :, 0] *= 10 + offset = -1 + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + scalar_value: int, + ): + with computation(PARALLEL), interval(1, None): + out_field[0, 0, 0] = in_field[0, 0, scalar_value] + + test_stencil(field_in, field_out, offset) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + + +def test_k_offset_field_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_idx = gt_storage.ones(dtype=np.int64, backend="debug", shape=(4, 4), aligned_index=(0, 0)) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:, :, 0] *= 10 + field_idx[:, :] *= -2 + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + idx_field: gtscript.Field[gtscript.IJ, np.int64], + ): + with computation(PARALLEL), interval(1, None): + out_field[0, 0, 0] = in_field[0, 0, idx_field + 1] + + test_stencil(field_in, field_out, field_idx) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + + +def test_absolute_k_stencil(): + field_in = gt_storage.ones( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:, :, 0] *= 10 + field_in[:, :, 1] *= 5 + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field.at(K=0) + in_field.at(K=1) + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 15) + + +def test_k_only_access_stencil(): + field_in = np.ones((4,), dtype=np.float64) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:] = [2, 3, 4, 5] + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[gtscript.K, np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL): + with interval(0, 1): + out_field[0, 0, 0] = in_field[1] + with interval(1, None): + out_field[0, 0, 0] = in_field[-1] + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 2, 3, 4]) + + +def test_table_access_stencil(): + table_view = np.ones((4,), dtype=np.float64) + field_out = gt_storage.zeros( + dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + table_view[:] = [2, 3, 4, 5] + + @gtscript.stencil(backend="debug") + def test_stencil( + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], + ): + with computation(PARALLEL): + with interval(0, 1): + out_field[0, 0, 0] = table_view.A[1] + with interval(1, None): + out_field[0, 0, 0] = table_view.A[2] + + test_stencil(table_view, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 4, 4, 4]) From 62467746b347a8bc1865047f4c857a770661e125 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 15:17:15 +0100 Subject: [PATCH 04/19] simplify to existing features --- src/gt4py/cartesian/backend/debug_backend.py | 8 ++ .../cartesian/frontend/gtscript_frontend.py | 40 -------- src/gt4py/cartesian/gtc/debug/__init__.py | 8 ++ .../cartesian/gtc/debug/debug_codegen.py | 11 +++ src/gt4py/cartesian/gtc/ufuncs.py | 6 +- src/gt4py/cartesian/utils/field.py | 8 ++ .../multi_feature_tests/test_debug_backend.py | 91 ++++++++----------- 7 files changed, 76 insertions(+), 96 deletions(-) diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index b0b78296ac..4a4a4eaba5 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ed6775b802..f56e23564c 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1388,46 +1388,6 @@ def visit_While(self, node: ast.While) -> list: return result - def _absolute_K_index_method(self, node: ast.Call): - # Dev note: we enforce .at(K=..., ddim=[...]) for the POC - # A better version of this code would look through the keywords - # in any order. `ddim` shall remain optional, K mandatory. - assert _filter_absolute_K_index_method(node) - if len(node.keywords) not in [1, 2]: - raise GTScriptSyntaxError( - message="Absolute K index bad syntax. Must be of the form`.at(K=..., ddim=[...])` " - " with the `ddim` argument optional", - loc=nodes.Location.from_ast_node(node), - ) - if node.keywords[0].arg != "K": - raise GTScriptSyntaxError( - message="Absolute K index: bad syntax, first argument must be `K`. " - "Must be of the form`.at(K=...)`", - loc=nodes.Location.from_ast_node(node), - ) - if len(node.keywords) > 1 and node.keywords[1].arg != "ddim": - raise GTScriptSyntaxError( - message="Absolute K index: bad syntax, second argument (optional) must be `ddim`. " - "Must be of the form`.at(K=..., ddim=[...])`", - loc=nodes.Location.from_ast_node(node), - ) - if ( - len(node.keywords) > 1 - and node.keywords[1].arg == "ddim" - and not isinstance(node.keywords[1].value, ast.List) - ): - raise GTScriptSyntaxError( - message="Absolute K index: bad syntax, second argument `ddim` (optional) must be " - "a list of values. Must be of the form`.at(K=..., ddim=[...])`", - loc=nodes.Location.from_ast_node(node), - ) - field: nodes.FieldRef = self.visit(node.func.value) - assert isinstance(field, nodes.FieldRef) - field.offset = nodes.AbsoluteKIndex(k=self.visit(node.keywords[0].value)) - if len(node.keywords) == 2: - field.data_index = [self.visit(value) for value in node.keywords[1].value.elts] - return field - def visit_Call(self, node: ast.Call): native_fcn = nodes.NativeFunction.PYTHON_SYMBOL_TO_IR_OP[node.func.id] diff --git a/src/gt4py/cartesian/gtc/debug/__init__.py b/src/gt4py/cartesian/gtc/debug/__init__.py index 6c43e2f12a..ce7df089c2 100644 --- a/src/gt4py/cartesian/gtc/debug/__init__.py +++ b/src/gt4py/cartesian/gtc/debug/__init__.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index f5fcb01127..c256207454 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich @@ -12,6 +20,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from collections.abc import Generator +from contextlib import contextmanager + from gt4py import eve from gt4py.cartesian import utils from gt4py.cartesian.gtc.common import ( diff --git a/src/gt4py/cartesian/gtc/ufuncs.py b/src/gt4py/cartesian/gtc/ufuncs.py index c0b09f60f3..ebf81df4b8 100644 --- a/src/gt4py/cartesian/gtc/ufuncs.py +++ b/src/gt4py/cartesian/gtc/ufuncs.py @@ -37,8 +37,8 @@ abs: np.ufunc = np.abs # noqa: A001 [builtin-variable-shadowing] minimum: np.ufunc = np.minimum maximum: np.ufunc = np.maximum -max: np.ufunc = np.maximum -min: np.ufunc = np.minimum +max: np.ufunc = np.maximum # noqa: A001 +min: np.ufunc = np.minimum # noqa: A001 remainder: np.ufunc = np.remainder sin: np.ufunc = np.sin cos: np.ufunc = np.cos @@ -54,7 +54,7 @@ arctanh: np.ufunc = np.arctanh sqrt: np.ufunc = np.sqrt power: np.ufunc = np.power -pow: np.ufunc = np.power +pow: np.ufunc = np.power # noqa: A001 exp: np.ufunc = np.exp log: np.ufunc = np.log log10: np.ufunc = np.log10 diff --git a/src/gt4py/cartesian/utils/field.py b/src/gt4py/cartesian/utils/field.py index b2f28decc0..e421250a95 100644 --- a/src/gt4py/cartesian/utils/field.py +++ b/src/gt4py/cartesian/utils/field.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py index 2c20e7361e..04dbebb4aa 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich @@ -16,7 +24,7 @@ from gt4py import storage as gt_storage from gt4py.cartesian import gtscript -from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, THIS_K, computation, interval, sin +from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, computation, interval, sin def test_simple_stencil(): @@ -28,7 +36,7 @@ def test_simple_stencil(): ) @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore with computation(BACKWARD): with interval(-2, -1): # block 1 field_out = field_in @@ -38,7 +46,7 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f with interval(-1, None): # block 3 field_out = 2 * field_in with interval(0, -1): # block 4 - field_out = 3 * field_in + field_out[0, 0, 0] = 3 * field_in stencil(field_in, field_out) @@ -55,13 +63,13 @@ def test_tmp_stencil(): ) @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore with computation(PARALLEL): with interval(...): tmp = field_in + 1 with computation(PARALLEL): with interval(...): - field_out = tmp[-1, 0, 0] + tmp[1, 0, 0] + field_out[0, 0, 0] = tmp[-1, 0, 0] + tmp[1, 0, 0] stencil(field_in, field_out, origin=(1, 1, 0), domain=(4, 4, 6)) @@ -83,14 +91,14 @@ def test_backward_stencil(): ) @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore with computation(BACKWARD): with interval(-1, None): field_in = 2 field_out = field_in with interval(0, -1): field_in = field_in[0, 0, 1] + 1 - field_out = field_in + field_out[0, 0, 0] = field_in stencil(field_in, field_out) @@ -109,12 +117,12 @@ def test_while_stencil(): ) @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore with computation(PARALLEL): with interval(...): while field_in < 10: field_in += 1 - field_out = field_in + field_out[0, 0, 0] = field_in stencil(field_in, field_out) @@ -135,8 +143,8 @@ def test_higher_dim_literal_stencil(): @gtscript.stencil(backend="debug") def stencil( - vec_field: gtscript.Field[FLOAT64_NDDIM], - out_field: gtscript.Field[np.float64], + vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = vec_field[0, 0, 0][2] @@ -160,8 +168,8 @@ def test_higher_dim_scalar_stencil(): @gtscript.stencil(backend="debug") def stencil( - vec_field: gtscript.Field[FLOAT64_NDDIM], - out_field: gtscript.Field[np.float64], + vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore scalar_argument: int, ): with computation(PARALLEL), interval(...): @@ -183,8 +191,8 @@ def test_native_function_call_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = in_field[0, 0, 0] + sin(0.848062) @@ -204,8 +212,8 @@ def test_unary_operator_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = -in_field[0, 0, 0] @@ -226,8 +234,8 @@ def test_ternary_operator_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = in_field[0, 0, 0] if in_field > 10 else in_field[0, 0, 0] + 1 @@ -249,8 +257,8 @@ def test_mask_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): if in_field[0, 0, 0] > 0: @@ -275,8 +283,8 @@ def test_k_offset_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore scalar_value: int, ): with computation(PARALLEL), interval(1, None): @@ -300,9 +308,9 @@ def test_k_offset_field_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], - idx_field: gtscript.Field[gtscript.IJ, np.int64], + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, np.int64], # type: ignore ): with computation(PARALLEL), interval(1, None): out_field[0, 0, 0] = in_field[0, 0, idx_field + 1] @@ -312,29 +320,6 @@ def test_stencil( np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) -def test_absolute_k_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[:, :, 0] *= 10 - field_in[:, :, 1] *= 5 - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = in_field.at(K=0) + in_field.at(K=1) - - test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 15) - - def test_k_only_access_stencil(): field_in = np.ones((4,), dtype=np.float64) field_out = gt_storage.zeros( @@ -344,8 +329,8 @@ def test_k_only_access_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[gtscript.K, np.float64], - out_field: gtscript.Field[np.float64], + in_field: gtscript.Field[gtscript.K, np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL): with interval(0, 1): @@ -367,8 +352,8 @@ def test_table_access_stencil(): @gtscript.stencil(backend="debug") def test_stencil( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], + table_view: gtscript.GlobalTable[(np.float64, (4))], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore ): with computation(PARALLEL): with interval(0, 1): From 37769bddbc529bb69b99763112ddd74ad112e9f3 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 17:56:55 +0100 Subject: [PATCH 05/19] missing mod ufunc --- src/gt4py/cartesian/gtc/ufuncs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/cartesian/gtc/ufuncs.py b/src/gt4py/cartesian/gtc/ufuncs.py index ebf81df4b8..58b2bac5c0 100644 --- a/src/gt4py/cartesian/gtc/ufuncs.py +++ b/src/gt4py/cartesian/gtc/ufuncs.py @@ -39,6 +39,7 @@ maximum: np.ufunc = np.maximum max: np.ufunc = np.maximum # noqa: A001 min: np.ufunc = np.minimum # noqa: A001 +mod: np.ufunc = np.mod remainder: np.ufunc = np.remainder sin: np.ufunc = np.sin cos: np.ufunc = np.cos From 85d73f9a766c826c3b9b683d683a577be3368dc2 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 24 Feb 2025 18:32:48 +0100 Subject: [PATCH 06/19] missing omission from perf-backend names --- tests/cartesian_tests/definitions.py | 4 +++- tests/cartesian_tests/unit_tests/test_cli.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 4d52b9b773..9ffbcfd437 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -48,7 +48,9 @@ def _get_backends_with_storage_info(storage_info_kind: str): GPU_BACKENDS = _get_backends_with_storage_info("gpu") ALL_BACKENDS = CPU_BACKENDS + GPU_BACKENDS -_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] +_PERFORMANCE_BACKEND_NAMES = [ + name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda", "debug") +] PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] DACE_BACKENDS = [ diff --git a/tests/cartesian_tests/unit_tests/test_cli.py b/tests/cartesian_tests/unit_tests/test_cli.py index dbd90a1908..e8587f506b 100644 --- a/tests/cartesian_tests/unit_tests/test_cli.py +++ b/tests/cartesian_tests/unit_tests/test_cli.py @@ -92,6 +92,7 @@ def nocli_backend(scope="module"): "gt:gpu": r"^\s*gt:gpu\s*cuda\s*python\s*Yes", "numpy": r"^\s*numpy\s*python\s*python\s*Yes", "nocli": r"^\s*nocli\s*\?\s*\?\s*No", + "debug": r"^\s*debug\s*python\s*python\s*Yes", } From 7bd7a97c9302bad60c2f9db80e697fa4b54334d0 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Tue, 25 Feb 2025 14:36:25 +0100 Subject: [PATCH 07/19] Update src/gt4py/cartesian/backend/__init__.py Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/backend/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index 3179f116ac..15921f0c4e 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -37,7 +37,6 @@ "GTCpuIfirstBackend", "GTCpuKfirstBackend", "GTGpuBackend", - "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", From 2901d4a16a09b90cf69801e828b95215890808ef Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 26 Feb 2025 09:16:34 +0100 Subject: [PATCH 08/19] reviewer's feedback --- .gitignore | 2 +- src/gt4py/cartesian/backend/debug_backend.py | 13 ------------- src/gt4py/cartesian/gtc/debug/__init__.py | 14 -------------- src/gt4py/cartesian/gtc/debug/debug_codegen.py | 13 ------------- src/gt4py/cartesian/utils/field.py | 13 ------------- 5 files changed, 1 insertion(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index 34992bf9f7..ebbbfaebeb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ _local /src/__init__.py /tests/__init__.py .gt_cache/ -.gt_cache*/ +.gt4py_cache/ .gt_cache_pytest*/ # DaCe diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index 4a4a4eaba5..1144e86462 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -6,19 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Type, Union diff --git a/src/gt4py/cartesian/gtc/debug/__init__.py b/src/gt4py/cartesian/gtc/debug/__init__.py index ce7df089c2..c9075cc89d 100644 --- a/src/gt4py/cartesian/gtc/debug/__init__.py +++ b/src/gt4py/cartesian/gtc/debug/__init__.py @@ -5,17 +5,3 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index c256207454..2c91ea0e07 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -6,19 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later from collections.abc import Generator from contextlib import contextmanager diff --git a/src/gt4py/cartesian/utils/field.py b/src/gt4py/cartesian/utils/field.py index e421250a95..101a3ba28f 100644 --- a/src/gt4py/cartesian/utils/field.py +++ b/src/gt4py/cartesian/utils/field.py @@ -6,19 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later import numbers from typing import Tuple From 628a5a1041f3f5a0d25362c90827c2d0ec1a5b0a Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 28 Feb 2025 10:56:07 +0100 Subject: [PATCH 09/19] Update src/gt4py/cartesian/gtc/debug/debug_codegen.py Co-authored-by: Roman Cattaneo --- src/gt4py/cartesian/gtc/debug/debug_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index 2c91ea0e07..dbcf7a0500 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -87,7 +87,7 @@ def initial_declarations( self, stencil: Stencil, field_extents: dict[str, Extent] ) -> dict[str, FieldDecl]: self.body.append("# ===== Domain Description ===== #") - self.body.append("i_0, j_0, k_0 = 0,0,0") + self.body.append("i_0, j_0, k_0 = 0, 0, 0") self.body.append("i_size, j_size, k_size = _domain_") self.body.empty_line() self.body.append("# ===== Temporary Declaration ===== #") From 2e60003dde7ee9fc1ed507080b8ae89ba95e8701 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 28 Feb 2025 13:27:04 +0100 Subject: [PATCH 10/19] Update src/gt4py/cartesian/gtc/debug/debug_codegen.py Co-authored-by: Roman Cattaneo --- src/gt4py/cartesian/gtc/debug/debug_codegen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index dbcf7a0500..af789870da 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -339,13 +339,13 @@ def visit_LocalScalar(self, local_scalar: LocalScalar, **__) -> None: ) def visit_CacheDesc(self, cache_descriptor: CacheDesc, **_): - raise NotImplementedError("Caches should never be visited in the debug backend") + raise NotImplementedError("CacheDescriptors should never be visited in the debug backends") def visit_IJCache(self, ij_cache: IJCache, **_): - raise NotImplementedError("Caches should never be visited in the debug backend") + raise NotImplementedError("IJCaches should never be visited in the debug backend.") def visit_KCache(self, k_cache: KCache, **_): - raise NotImplementedError("Caches should never be visited in the debug backend") + raise NotImplementedError("KCaches should never be visited in the debug backend.") def visit_VerticalLoopSection(self, vertical_loop_section: VerticalLoopSection, **_): raise NotImplementedError("Vertical Loop section is not in the right place.") From 84b5cc98e16124746fc24b480f9e9adc4a4cc61d Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 28 Feb 2025 13:32:11 +0100 Subject: [PATCH 11/19] reviewer's feedback --- src/gt4py/cartesian/backend/base.py | 9 ++ src/gt4py/cartesian/backend/debug_backend.py | 13 +-- src/gt4py/cartesian/backend/numpy_backend.py | 12 +-- src/gt4py/cartesian/gtc/common.py | 12 --- .../cartesian/gtc/debug/debug_codegen.py | 93 ++++++++++--------- .../multi_feature_tests/test_debug_backend.py | 14 --- 6 files changed, 58 insertions(+), 95 deletions(-) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 9a638ef0ee..9772f46ab0 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -317,6 +317,15 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs source = self.MODULE_GENERATOR_CLASS()(args_data, self.builder, **kwargs) return source + def recursive_write(self, root_path: pathlib.Path, tree: dict[str, Union[str, dict]]): + root_path.mkdir(parents=True, exist_ok=True) + for key, value in tree.items(): + if isinstance(value, dict): + self.recursive_write(root_path / key, value) + else: + src_path = root_path / key + src_path.write_text(value) + class MakeModuleSourceCallable(Protocol): def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ... diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index 1144e86462..45e1c85835 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -27,16 +27,6 @@ from gt4py.cartesian.stencil_object import StencilObject -def recursive_write(root_path: Path, tree: dict[str, Union[str, dict]]): - root_path.mkdir(parents=True, exist_ok=True) - for key, value in tree.items(): - if isinstance(value, dict): - recursive_write(root_path / key, value) - else: - src_path = root_path / key - src_path.write_text(value) - - @register class DebugBackend(BaseBackend, CLIBackendMixin): """Debug backend using plain python loops.""" @@ -44,7 +34,6 @@ class DebugBackend(BaseBackend, CLIBackendMixin): name = "debug" options: ClassVar[dict[str, Any]] = { "oir_pipeline": {"versioning": True, "type": OirPipeline}, - # TODO: Implement this option in source code "ignore_np_errstate": {"versioning": True, "type": bool}, } storage_info = storage.layout.NaiveCPULayout @@ -77,5 +66,5 @@ def generate(self) -> Type["StencilObject"]: src_dir = self.builder.module_path.parent if not self.builder.options._impl_opts.get("disable-code-generation", False): src_dir.mkdir(parents=True, exist_ok=True) - recursive_write(src_dir, self.generate_computation()) + self.recursive_write(src_dir, self.generate_computation()) return self.make_module() diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index 160bd5eaa8..e61bbb69f8 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -57,16 +57,6 @@ def backend(self) -> NumpyBackend: return cast(NumpyBackend, self.builder.backend) -def recursive_write(root_path: pathlib.Path, tree: Dict[str, Union[str, dict]]): - root_path.mkdir(parents=True, exist_ok=True) - for key, value in tree.items(): - if isinstance(value, dict): - recursive_write(root_path / key, value) - else: - src_path = root_path / key - src_path.write_text(value) - - @register class NumpyBackend(BaseBackend, CLIBackendMixin): """NumPy backend using gtc.""" @@ -105,7 +95,7 @@ def generate(self) -> Type[StencilObject]: src_dir = self.builder.module_path.parent if not self.builder.options._impl_opts.get("disable-code-generation", False): src_dir.mkdir(parents=True, exist_ok=True) - recursive_write(src_dir, self.generate_computation()) + self.recursive_write(src_dir, self.generate_computation()) return self.make_module() def _make_npir(self) -> npir.Computation: diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 2e9afe0bc3..60236a3e97 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -317,18 +317,6 @@ def zero(cls) -> CartesianOffset: def to_dict(self) -> Dict[str, int]: return {"i": self.i, "j": self.j, "k": self.k} - def to_str(self, dimensions: tuple[bool, bool, bool]) -> str: - dimension_strings = [] - - if dimensions[0]: - dimension_strings.append(f"i + {self.i}") - if dimensions[1]: - dimension_strings.append(f"j + {self.j}") - if dimensions[2]: - dimension_strings.append(f"k + {self.k}") - - return ",".join(dimension_strings) - class VariableKOffset(eve.GenericNode, Generic[ExprT]): k: ExprT diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index 2c91ea0e07..0c6c2c95eb 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -62,7 +62,7 @@ def __init__(self) -> None: self.body = utils.text.TextBlock() self.symbol_table: dict[str, FieldDecl] = {} - def visit_Stencil(self, stencil: Stencil, **_): + def visit_Stencil(self, stencil: Stencil, **_) -> str: self.generate_imports() self.generate_run_function(stencil) @@ -73,7 +73,7 @@ def visit_Stencil(self, stencil: Stencil, **_): return self.body.text - def generate_imports(self): + def generate_imports(self) -> None: self.body.append("import numpy as np") self.body.append("from gt4py.cartesian.gtc import ufuncs") self.body.append("from gt4py.cartesian.utils import Field") @@ -118,7 +118,7 @@ def generate_field_decls(self, declarations: list[Decl]) -> dict[str, FieldDecl] symbol_table[str(declaration.name)] = declaration return symbol_table - def generate_run_function(self, stencil: Stencil): + def generate_run_function(self, stencil: Stencil) -> None: function_signature = "def run(*" args = [] for param in stencil.params: @@ -132,7 +132,7 @@ def generate_stencil_code( self, stencil: Stencil, block_extents: dict[int, Extent], - ): + ) -> None: for loop in stencil.vertical_loops: for section in loop.sections: with self.create_k_loop_code(section, loop): @@ -143,8 +143,8 @@ def generate_stencil_code( @contextmanager def create_k_loop_code(self, section: VerticalLoopSection, loop: VerticalLoop) -> Generator: loop_bounds: str = self.visit(section.interval, var="k", direction=loop.loop_order) - iterator = "1" if loop.loop_order != LoopOrder.BACKWARD else "-1" - loop_code = "for k in range(" + loop_bounds + "," + iterator + "):" + increment = "-1" if loop.loop_order == LoopOrder.BACKWARD else "1" + loop_code = f"for k in range( {loop_bounds} , {increment}):" self.body.append(loop_code) self.body.indent() yield @@ -177,13 +177,13 @@ def visit_While(self, while_node: While, **_) -> None: def visit_FieldDecl(self, field_decl: FieldDecl, **_) -> str: return str(field_decl.name) - def visit_AxisBound(self, axis_bound: AxisBound, **kwargs): + def visit_AxisBound(self, axis_bound: AxisBound, **kwargs) -> str: if axis_bound.level == LevelMarker.START: return f"{kwargs['var']}_0 + {axis_bound.offset}" if axis_bound.level == LevelMarker.END: return f"{kwargs['var']}_size + {axis_bound.offset}" - def visit_Interval(self, interval: Interval, **kwargs): + def visit_Interval(self, interval: Interval, **kwargs) -> str: if kwargs["direction"] == LoopOrder.BACKWARD: return ",".join( [ @@ -191,10 +191,7 @@ def visit_Interval(self, interval: Interval, **kwargs): self.visit(interval.start, **kwargs) + "- 1", ] ) - else: - return ",".join( - [self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)] - ) + return ",".join([self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)]) def visit_Temporary(self, temporary_declaration: Temporary, **kwargs) -> str: field_extents = kwargs["field_extents"] @@ -213,17 +210,25 @@ def visit_Temporary(self, temporary_declaration: Temporary, **kwargs) -> str: return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}))" def visit_DataType(self, data_type: DataType, **_) -> str: - if data_type not in {DataType.BOOL}: - return f"np.{data_type.name.lower()}" - else: + if data_type in {DataType.BOOL}: return data_type.name.lower() + return f"np.{data_type.name.lower()}" def visit_VariableKOffset(self, variable_k_offset: VariableKOffset, **_) -> str: return f"i,j,k+int({self.visit(variable_k_offset.k)})" def visit_CartesianOffset(self, cartesian_offset: CartesianOffset, **kwargs) -> str: - dimensions = kwargs["dimensions"] - return cartesian_offset.to_str(dimensions) + if "dimensions" in kwargs.keys(): + dimensions = kwargs["dimensions"] + dimension_strings = [] + if dimensions[0]: + dimension_strings.append(f"i + {cartesian_offset.i}") + if dimensions[1]: + dimension_strings.append(f"j + {cartesian_offset.j}") + if dimensions[2]: + dimension_strings.append(f"k + {cartesian_offset.k}") + return ",".join(dimension_strings) + return f"i + {cartesian_offset.i}, j + {cartesian_offset.j}, k + {cartesian_offset.k}" def visit_SymbolRef(self, symbol_ref: SymbolRef) -> str: return symbol_ref @@ -231,26 +236,22 @@ def visit_SymbolRef(self, symbol_ref: SymbolRef) -> str: def visit_FieldAccess(self, field_access: FieldAccess, **_) -> str: if str(field_access.name) in self.symbol_table: dimensions = self.symbol_table[str(field_access.name)].dimensions + offset_str = self.visit(field_access.offset, dimensions=dimensions) else: - dimensions = (True, True, True) - - offset_str = self.visit(field_access.offset, dimensions=dimensions) + offset_str = self.visit(field_access.offset) if field_access.data_index: data_index_access = ",".join( [self.visit(data_index) for data_index in field_access.data_index] ) if offset_str == "": - full_string = field_access.name + f"[{data_index_access}]" - else: - full_string = field_access.name + "[" + offset_str + "," + data_index_access + "]" - else: - full_string = field_access.name + "[" + offset_str + "]" - return full_string + return field_access.name + f"[{data_index_access}]" + return field_access.name + "[" + offset_str + "," + data_index_access + "]" + return field_access.name + "[" + offset_str + "]" - def visit_AssignStmt(self, assignment_statement: AssignStmt, **_): + def visit_AssignStmt(self, assignment_statement: AssignStmt, **_) -> None: self.body.append( - self.visit(assignment_statement.left) + "=" + self.visit(assignment_statement.right) + f"{self.visit(assignment_statement.left)} = {self.visit(assignment_statement.right)}" ) def visit_BinaryOp(self, binary: BinaryOp, **_) -> str: @@ -258,27 +259,23 @@ def visit_BinaryOp(self, binary: BinaryOp, **_) -> str: def visit_Literal(self, literal: Literal, **_) -> str: if literal.dtype == DataType.BOOL: - if literal.value == BuiltInLiteral.TRUE: - literal_value = "True" - else: - literal_value = "False" + literal_value = "True" if literal.value == BuiltInLiteral.TRUE else "False" else: literal_value = str(literal.value) if literal.dtype.bit_count() != 4: - literal_code = f"{self.visit(literal.dtype)}({literal_value})" - else: - literal_code = literal_value - return literal_code + return f"{self.visit(literal.dtype)}({literal_value})" + + return literal_value def visit_Cast(self, cast: Cast, **_) -> str: return f"{self.visit(cast.dtype)}({self.visit(cast.expr)})" - def visit_HorizontalExecution(self, horizontal_execution: HorizontalExecution, **_): - for stmt in horizontal_execution.body: - self.visit(stmt) + def visit_HorizontalExecution(self, horizontal_execution: HorizontalExecution, **_) -> None: + for statement in horizontal_execution.body: + self.visit(statement) - def visit_HorizontalMask(self, horizontal_mask: HorizontalMask, **_): + def visit_HorizontalMask(self, horizontal_mask: HorizontalMask, **_) -> None: i_min, i_max = self.visit(horizontal_mask.i, var="i") j_min, j_max = self.visit(horizontal_mask.j, var="j") conditions = [] @@ -294,23 +291,27 @@ def visit_HorizontalMask(self, horizontal_mask: HorizontalMask, **_): if_code = f"if( {' and '.join(conditions)} ):" self.body.append(if_code) - def visit_HorizontalInterval(self, horizontal_interval: HorizontalInterval, **kwargs): + def visit_HorizontalInterval( + self, horizontal_interval: HorizontalInterval, **kwargs + ) -> tuple[str | None, str | None]: return self.visit( horizontal_interval.start, **kwargs ) if horizontal_interval.start else None, self.visit( horizontal_interval.end, **kwargs ) if horizontal_interval.end else None - def visit_HorizontalRestriction(self, horizontal_restriction: HorizontalRestriction, **_): + def visit_HorizontalRestriction( + self, horizontal_restriction: HorizontalRestriction, **_ + ) -> None: self.visit(horizontal_restriction.mask) self.body.indent() self.visit(horizontal_restriction.body) self.body.dedent() - def visit_VerticalLoop(self): + def visit_VerticalLoop(self) -> None: pass - def visit_ScalarAccess(self, scalar_access: ScalarAccess, **_): + def visit_ScalarAccess(self, scalar_access: ScalarAccess, **_) -> str: return scalar_access.name def visit_ScalarDecl(self, scalar_declaration: ScalarDecl, **_) -> str: @@ -327,7 +328,7 @@ def visit_UnaryOp(self, unary_operator: UnaryOp, **_) -> str: def visit_TernaryOp(self, ternary_operator: TernaryOp, **_) -> str: return f"{self.visit(ternary_operator.true_expr)} if {self.visit(ternary_operator.cond)} else {self.visit(ternary_operator.false_expr)}" - def visit_MaskStmt(self, mask_statement: MaskStmt, **_): + def visit_MaskStmt(self, mask_statement: MaskStmt, **_) -> None: self.body.append(f"if {self.visit(mask_statement.mask)}:") with self.body.indented(): for statement in mask_statement.body: @@ -347,7 +348,7 @@ def visit_IJCache(self, ij_cache: IJCache, **_): def visit_KCache(self, k_cache: KCache, **_): raise NotImplementedError("Caches should never be visited in the debug backend") - def visit_VerticalLoopSection(self, vertical_loop_section: VerticalLoopSection, **_): + def visit_VerticalLoopSection(self, vertical_loop_section: VerticalLoopSection, **_) -> None: raise NotImplementedError("Vertical Loop section is not in the right place.") def visit_UnboundedInterval(self, unbounded_interval: UnboundedInterval, **_) -> None: diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py index 04dbebb4aa..7c5d0509af 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py @@ -6,20 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - import numpy as np from gt4py import storage as gt_storage From 7d9ac0609c7d688d3215a34b689bbaea53871392 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 28 Feb 2025 13:49:48 +0100 Subject: [PATCH 12/19] feedback v2 --- src/gt4py/cartesian/gtc/debug/debug_codegen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index f9beab9d14..76b771c8de 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -323,7 +323,7 @@ def visit_NativeFuncCall(self, native_function_call: NativeFuncCall, **_) -> str return f"ufuncs.{native_function_call.func.value}({arguments})" def visit_UnaryOp(self, unary_operator: UnaryOp, **_) -> str: - return unary_operator.op.value + " " + self.visit(unary_operator.expr) + return f"{unary_operator.op.value} {self.visit(unary_operator.expr)}" def visit_TernaryOp(self, ternary_operator: TernaryOp, **_) -> str: return f"{self.visit(ternary_operator.true_expr)} if {self.visit(ternary_operator.cond)} else {self.visit(ternary_operator.false_expr)}" @@ -339,13 +339,13 @@ def visit_LocalScalar(self, local_scalar: LocalScalar, **__) -> None: "This state should not be reached because LocalTemporariesToScalars should not have been called." ) - def visit_CacheDesc(self, cache_descriptor: CacheDesc, **_): + def visit_CacheDesc(self, cache_descriptor: CacheDesc, **_) -> None: raise NotImplementedError("CacheDescriptors should never be visited in the debug backends") - def visit_IJCache(self, ij_cache: IJCache, **_): + def visit_IJCache(self, ij_cache: IJCache, **_) -> None: raise NotImplementedError("IJCaches should never be visited in the debug backend.") - def visit_KCache(self, k_cache: KCache, **_): + def visit_KCache(self, k_cache: KCache, **_) -> None: raise NotImplementedError("KCaches should never be visited in the debug backend.") def visit_VerticalLoopSection(self, vertical_loop_section: VerticalLoopSection, **_) -> None: From 73aa6b5e766440fe37812c1836a66d9d56cddb5b Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 28 Feb 2025 13:51:26 +0100 Subject: [PATCH 13/19] clean up imports --- tests/cartesian_tests/definitions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 58eb002654..e19748864c 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -14,6 +14,7 @@ cp = None import datetime + import numpy as np import pytest From 64338de6a83f31311fb689d542df838c5743bf67 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 3 Mar 2025 06:52:36 +0100 Subject: [PATCH 14/19] pc hooks --- src/gt4py/cartesian/backend/debug_backend.py | 1 - src/gt4py/cartesian/backend/numpy_backend.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index 45e1c85835..00021c2fa0 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause -from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Type, Union from gt4py import storage diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index e61bbb69f8..55194f5e0d 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -8,7 +8,6 @@ from __future__ import annotations -import pathlib from typing import TYPE_CHECKING, Any, ClassVar, Dict, Type, Union, cast from gt4py import storage as gt_storage From 4c7edffb8bf0fe0c03472d16d369fcc8e8704462 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 3 Mar 2025 14:42:24 +0100 Subject: [PATCH 15/19] extract literal precision feature --- src/gt4py/cartesian/backend/debug_backend.py | 7 +++---- .../cartesian/frontend/gtscript_frontend.py | 19 +++++-------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/src/gt4py/cartesian/backend/debug_backend.py b/src/gt4py/cartesian/backend/debug_backend.py index 00021c2fa0..69dc99610a 100644 --- a/src/gt4py/cartesian/backend/debug_backend.py +++ b/src/gt4py/cartesian/backend/debug_backend.py @@ -41,11 +41,10 @@ class DebugBackend(BaseBackend, CLIBackendMixin): def generate_computation(self) -> dict[str, Union[str, dict]]: computation_name = ( - self.builder.caching.module_prefix - + "computation" - + self.builder.caching.module_postfix - + ".py" + f"{self.builder.caching.module_prefix}" + + f"computation{self.builder.caching.module_postfix}.py" ) + oir = GTIRToOIR().visit(self.builder.gtir) oir = HorizontalExecutionMerging().visit(oir) oir = LocalTemporariesToScalars().visit(oir) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f56e23564c..bb6c7ca392 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1014,20 +1014,11 @@ def visit_Constant( loc=nodes.Location.from_ast_node(node), ) elif isinstance(value, numbers.Number): - if self.dtypes and type(value) in self.dtypes.keys(): - value_type = self.dtypes[type(value)] - else: - if build_settings["literal_floating_point_precision"] is not None: - if isinstance(value, int): - value_type = np.dtype( - f"i{int(int(build_settings['literal_floating_point_precision'])/8)}" - ) - else: - value_type = np.dtype( - f"f{int(int(build_settings['literal_floating_point_precision'])/8)}" - ) - else: - value_type = np.dtype(type(value)) + value_type = ( + self.dtypes[type(value)] + if self.dtypes and type(value) in self.dtypes.keys() + else np.dtype(type(value)) + ) data_type = nodes.DataType.from_dtype(value_type) return nodes.ScalarLiteral(value=value, data_type=data_type) else: From e596c5112c2858d0a374e91da22627f2a60b7eff Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 3 Mar 2025 14:42:37 +0100 Subject: [PATCH 16/19] remove literal precision --- src/gt4py/cartesian/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index 4719fed945..5aa32506b7 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -65,7 +65,6 @@ "extra_link_args": extra_link_args, "parallel_jobs": multiprocessing.cpu_count(), "cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH), - "literal_floating_point_precision": os.environ.get("GT4PY_LITERAL_PRECISION", None), } if GT4PY_USE_HIP: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib") From 12efff930df66333dc4813c8bb1460ae43c4880d Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 3 Mar 2025 14:56:24 +0100 Subject: [PATCH 17/19] generalize debuck backend tests --- .../cartesian/frontend/gtscript_frontend.py | 1 - .../test_code_generation.py | 332 ++++++++++++++++- .../multi_feature_tests/test_debug_backend.py | 352 ------------------ 3 files changed, 327 insertions(+), 358 deletions(-) delete mode 100644 tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index bb6c7ca392..4d8ac98529 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -21,7 +21,6 @@ import numpy as np from gt4py.cartesian import definitions as gt_definitions, gtscript, utils as gt_utils -from gt4py.cartesian.config import build_settings from gt4py.cartesian.frontend import node_util, nodes from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR, UnrollVectorAssignments from gt4py.cartesian.gtc import utils as gtc_utils diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 4e0fa8903c..ca2a19a8ee 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -24,14 +24,11 @@ horizontal, interval, region, + sin, ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ( - ALL_BACKENDS, - CPU_BACKENDS, - get_array_library, -) +from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, @@ -563,6 +560,331 @@ def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): np.testing.assert_allclose(0.0, outp[:, :, -1]) +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_tmp_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend=backend) + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore + with computation(PARALLEL): + with interval(...): + tmp = field_in + 1 + with computation(PARALLEL): + with interval(...): + field_out[0, 0, 0] = tmp[-1, 0, 0] + tmp[1, 0, 0] + + stencil(field_in, field_out, origin=(1, 1, 0), domain=(4, 4, 6)) + + # the inside of the domain is 4 + np.testing.assert_allclose(field_out.view(np.ndarray)[1:-1, 1:-1, :], 4) + # the rest is 0 + np.testing.assert_allclose(field_out.view(np.ndarray)[0:1, :, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[-1:, :, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, 0:1, :], 0) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, -1:, :], 0) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_backward_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend=backend) + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore + with computation(BACKWARD): + with interval(-1, None): + field_in = 2 + field_out = field_in + with interval(0, -1): + field_in = field_in[0, 0, 1] + 1 + field_out[0, 0, 0] = field_in + + stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0], 5) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 4) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 2], 3) + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 3], 2) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_while_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend=backend) + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore + with computation(PARALLEL): + with interval(...): + while field_in < 10: + field_in += 1 + field_out[0, 0, 0] = field_in + + stencil(field_in, field_out) + + # the inside of the domain is 10 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 10) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_higher_dim_literal_stencil(backend): + FLOAT64_NDDIM = (np.float64, (4,)) + + field_in = gt_storage.ones( + dtype=FLOAT64_NDDIM, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_in[:, :, :, 2] = 5 + + @gtscript.stencil(backend=backend) + def stencil( + vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = vec_field[0, 0, 0][2] + + stencil(field_in, field_out) + + # the inside of the domain is 5 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_higher_dim_scalar_stencil(backend): + FLOAT64_NDDIM = (np.float64, (4,)) + + field_in = gt_storage.ones( + dtype=FLOAT64_NDDIM, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(6, 6, 6), aligned_index=(0, 0, 0) + ) + field_in[:, :, :, 2] = 5 + + @gtscript.stencil(backend=backend) + def stencil( + vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + scalar_argument: int, + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = vec_field[0, 0, 0][scalar_argument] + + stencil(field_in, field_out, 2) + + # the inside of the domain is 5 + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_native_function_call_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field[0, 0, 0] + sin(0.848062) + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 1.75) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_unary_operator_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = -in_field[0, 0, 0] + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], -1) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_ternary_operator_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[0, 0, 1] = 20 + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field[0, 0, 0] if in_field > 10 else in_field[0, 0, 0] + 1 + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[0, 0, 1], 20) + np.testing.assert_allclose(field_out.view(np.ndarray)[1:, 1:, 1], 2) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_mask_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[0, 0, 1] = -20 + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL), interval(...): + if in_field[0, 0, 0] > 0: + out_field[0, 0, 0] = in_field + else: + out_field[0, 0, 0] = 1 + + test_stencil(field_in, field_out) + + assert np.all(field_out.view(np.ndarray) > 0) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_k_offset_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:, :, 0] *= 10 + offset = -1 + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + scalar_value: int, + ): + with computation(PARALLEL), interval(1, None): + out_field[0, 0, 0] = in_field[0, 0, scalar_value] + + test_stencil(field_in, field_out, offset) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_k_offset_field_stencil(backend): + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_idx = gt_storage.ones(dtype=np.int64, backend=backend, shape=(4, 4), aligned_index=(0, 0)) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:, :, 0] *= 10 + field_idx[:, :] *= -2 + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, np.int64], # type: ignore + ): + with computation(PARALLEL), interval(1, None): + out_field[0, 0, 0] = in_field[0, 0, idx_field + 1] + + test_stencil(field_in, field_out, field_idx) + + np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_k_only_access_stencil(backend): + field_in = np.ones((4,), dtype=np.float64) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + field_in[:] = [2, 3, 4, 5] + + @gtscript.stencil(backend=backend) + def test_stencil( + in_field: gtscript.Field[gtscript.K, np.float64], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL): + with interval(0, 1): + out_field[0, 0, 0] = in_field[1] + with interval(1, None): + out_field[0, 0, 0] = in_field[-1] + + test_stencil(field_in, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 2, 3, 4]) + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_table_access_stencil(backend): + table_view = np.ones((4,), dtype=np.float64) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(4, 4, 4), aligned_index=(0, 0, 0) + ) + table_view[:] = [2, 3, 4, 5] + + @gtscript.stencil(backend=backend) + def test_stencil( + table_view: gtscript.GlobalTable[(np.float64, (4))], # type: ignore + out_field: gtscript.Field[np.float64], # type: ignore + ): + with computation(PARALLEL): + with interval(0, 1): + out_field[0, 0, 0] = table_view.A[1] + with interval(1, None): + out_field[0, 0, 0] = table_view.A[2] + + test_stencil(table_view, field_out) + + np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 4, 4, 4]) + + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_pruned_args_match(backend): @gtscript.stencil(backend=backend) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py deleted file mode 100644 index 7c5d0509af..0000000000 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_debug_backend.py +++ /dev/null @@ -1,352 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import numpy as np - -from gt4py import storage as gt_storage -from gt4py.cartesian import gtscript -from gt4py.cartesian.gtscript import BACKWARD, PARALLEL, computation, interval, sin - - -def test_simple_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore - with computation(BACKWARD): - with interval(-2, -1): # block 1 - field_out = field_in - with interval(0, -2): # block 2 - field_out = field_in - with computation(BACKWARD): - with interval(-1, None): # block 3 - field_out = 2 * field_in - with interval(0, -1): # block 4 - field_out[0, 0, 0] = 3 * field_in - - stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0:-1], 3) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, -1], 2) - - -def test_tmp_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore - with computation(PARALLEL): - with interval(...): - tmp = field_in + 1 - with computation(PARALLEL): - with interval(...): - field_out[0, 0, 0] = tmp[-1, 0, 0] + tmp[1, 0, 0] - - stencil(field_in, field_out, origin=(1, 1, 0), domain=(4, 4, 6)) - - # the inside of the domain is 4 - np.testing.assert_allclose(field_out.view(np.ndarray)[1:-1, 1:-1, :], 4) - # the rest is 0 - np.testing.assert_allclose(field_out.view(np.ndarray)[0:1, :, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[-1:, :, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, 0:1, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, -1:, :], 0) - - -def test_backward_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore - with computation(BACKWARD): - with interval(-1, None): - field_in = 2 - field_out = field_in - with interval(0, -1): - field_in = field_in[0, 0, 1] + 1 - field_out[0, 0, 0] = field_in - - stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0], 5) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 4) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 2], 3) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 3], 2) - - -def test_while_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): # type: ignore - with computation(PARALLEL): - with interval(...): - while field_in < 10: - field_in += 1 - field_out[0, 0, 0] = field_in - - stencil(field_in, field_out) - - # the inside of the domain is 10 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 10) - - -def test_higher_dim_literal_stencil(): - FLOAT64_NDDIM = (np.float64, (4,)) - - field_in = gt_storage.ones( - dtype=FLOAT64_NDDIM, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_in[:, :, :, 2] = 5 - - @gtscript.stencil(backend="debug") - def stencil( - vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = vec_field[0, 0, 0][2] - - stencil(field_in, field_out) - - # the inside of the domain is 5 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) - - -def test_higher_dim_scalar_stencil(): - FLOAT64_NDDIM = (np.float64, (4,)) - - field_in = gt_storage.ones( - dtype=FLOAT64_NDDIM, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(6, 6, 6), aligned_index=(0, 0, 0) - ) - field_in[:, :, :, 2] = 5 - - @gtscript.stencil(backend="debug") - def stencil( - vec_field: gtscript.Field[FLOAT64_NDDIM], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - scalar_argument: int, - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = vec_field[0, 0, 0][scalar_argument] - - stencil(field_in, field_out, 2) - - # the inside of the domain is 5 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) - - -def test_native_function_call_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = in_field[0, 0, 0] + sin(0.848062) - - test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 1.75) - - -def test_unary_operator_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = -in_field[0, 0, 0] - - test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], -1) - - -def test_ternary_operator_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[0, 0, 1] = 20 - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL), interval(...): - out_field[0, 0, 0] = in_field[0, 0, 0] if in_field > 10 else in_field[0, 0, 0] + 1 - - test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[0, 0, 1], 20) - np.testing.assert_allclose(field_out.view(np.ndarray)[1:, 1:, 1], 2) - - -def test_mask_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[0, 0, 1] = -20 - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL), interval(...): - if in_field[0, 0, 0] > 0: - out_field[0, 0, 0] = in_field - else: - out_field[0, 0, 0] = 1 - - test_stencil(field_in, field_out) - - assert np.all(field_out.view(np.ndarray) > 0) - - -def test_k_offset_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[:, :, 0] *= 10 - offset = -1 - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - scalar_value: int, - ): - with computation(PARALLEL), interval(1, None): - out_field[0, 0, 0] = in_field[0, 0, scalar_value] - - test_stencil(field_in, field_out, offset) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) - - -def test_k_offset_field_stencil(): - field_in = gt_storage.ones( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_idx = gt_storage.ones(dtype=np.int64, backend="debug", shape=(4, 4), aligned_index=(0, 0)) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[:, :, 0] *= 10 - field_idx[:, :] *= -2 - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - idx_field: gtscript.Field[gtscript.IJ, np.int64], # type: ignore - ): - with computation(PARALLEL), interval(1, None): - out_field[0, 0, 0] = in_field[0, 0, idx_field + 1] - - test_stencil(field_in, field_out, field_idx) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) - - -def test_k_only_access_stencil(): - field_in = np.ones((4,), dtype=np.float64) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - field_in[:] = [2, 3, 4, 5] - - @gtscript.stencil(backend="debug") - def test_stencil( - in_field: gtscript.Field[gtscript.K, np.float64], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL): - with interval(0, 1): - out_field[0, 0, 0] = in_field[1] - with interval(1, None): - out_field[0, 0, 0] = in_field[-1] - - test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 2, 3, 4]) - - -def test_table_access_stencil(): - table_view = np.ones((4,), dtype=np.float64) - field_out = gt_storage.zeros( - dtype=np.float64, backend="debug", shape=(4, 4, 4), aligned_index=(0, 0, 0) - ) - table_view[:] = [2, 3, 4, 5] - - @gtscript.stencil(backend="debug") - def test_stencil( - table_view: gtscript.GlobalTable[(np.float64, (4))], # type: ignore - out_field: gtscript.Field[np.float64], # type: ignore - ): - with computation(PARALLEL): - with interval(0, 1): - out_field[0, 0, 0] = table_view.A[1] - with interval(1, None): - out_field[0, 0, 0] = table_view.A[2] - - test_stencil(table_view, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 4, 4, 4]) From 321831ad50bb4f884e5d173e482dd0ceadc7a593 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Mon, 3 Mar 2025 15:54:32 +0100 Subject: [PATCH 18/19] discussion roman --- src/gt4py/cartesian/backend/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 9772f46ab0..22ca6dd44d 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -347,6 +347,10 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]: source = self.make_module_source(ir=self.builder.gtir) return {str(file_name): source} + def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: + """Pure python backends typically will not support bindings.""" + return super().generate_bindings(language_name) + class BasePyExtBackend(BaseBackend): @property From 0570c802177b0a8dfe3e0e2332a705fc911842ae Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 5 Mar 2025 17:49:07 +0100 Subject: [PATCH 19/19] add cpu copies to tests --- .../test_code_generation.py | 57 +++++++++++-------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index ca2a19a8ee..406d6fc04d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -581,12 +581,13 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f stencil(field_in, field_out, origin=(1, 1, 0), domain=(4, 4, 6)) # the inside of the domain is 4 - np.testing.assert_allclose(field_out.view(np.ndarray)[1:-1, 1:-1, :], 4) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[1:-1, 1:-1, :], 4) # the rest is 0 - np.testing.assert_allclose(field_out.view(np.ndarray)[0:1, :, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[-1:, :, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, 0:1, :], 0) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, -1:, :], 0) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[0:1, :, :], 0) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[-1:, :, :], 0) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, 0:1, :], 0) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, -1:, :], 0) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -610,10 +611,11 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f stencil(field_in, field_out) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 0], 5) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 4) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 2], 3) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 3], 2) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 0], 5) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 1], 4) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 2], 3) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 3], 2) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -636,7 +638,8 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f stencil(field_in, field_out) # the inside of the domain is 10 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 10) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, :], 10) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -662,7 +665,8 @@ def stencil( stencil(field_in, field_out) # the inside of the domain is 5 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, :], 5) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -689,7 +693,8 @@ def stencil( stencil(field_in, field_out, 2) # the inside of the domain is 5 - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 5) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, :], 5) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -710,8 +715,8 @@ def test_stencil( out_field[0, 0, 0] = in_field[0, 0, 0] + sin(0.848062) test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], 1.75) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, :], 1.75) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -732,8 +737,8 @@ def test_stencil( out_field[0, 0, 0] = -in_field[0, 0, 0] test_stencil(field_in, field_out) - - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, :], -1) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, :], -1) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -756,8 +761,9 @@ def test_stencil( test_stencil(field_in, field_out) - np.testing.assert_allclose(field_out.view(np.ndarray)[0, 0, 1], 20) - np.testing.assert_allclose(field_out.view(np.ndarray)[1:, 1:, 1], 2) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[0, 0, 1], 20) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[1:, 1:, 1], 2) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -783,7 +789,8 @@ def test_stencil( test_stencil(field_in, field_out) - assert np.all(field_out.view(np.ndarray) > 0) + cpu_output = storage_utils.cpu_copy(field_out) + assert np.all(cpu_output.view(np.ndarray) > 0) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -808,7 +815,8 @@ def test_stencil( test_stencil(field_in, field_out, offset) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 1], 10) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -834,7 +842,8 @@ def test_stencil( test_stencil(field_in, field_out, field_idx) - np.testing.assert_allclose(field_out.view(np.ndarray)[:, :, 1], 10) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[:, :, 1], 10) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -858,7 +867,8 @@ def test_stencil( test_stencil(field_in, field_out) - np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 2, 3, 4]) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[1, 1, :], [3, 2, 3, 4]) @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -882,7 +892,8 @@ def test_stencil( test_stencil(table_view, field_out) - np.testing.assert_allclose(field_out.view(np.ndarray)[1, 1, :], [3, 4, 4, 4]) + cpu_output = storage_utils.cpu_copy(field_out) + np.testing.assert_allclose(cpu_output.view(np.ndarray)[1, 1, :], [3, 4, 4, 4]) @pytest.mark.parametrize("backend", ALL_BACKENDS)