Skip to content

Commit c610561

Browse files
authored
refactor[cartesian]: gt4py/dace bridge cleanup (#1895)
## Description In preparation for PR #1894, pull out some refactors and cleanups. Notable in this PR are the changes to `src/gt4py/cartesian/gtc/dace/oir_to_dace.py` - visit `stencil.vertical_loops` directly instead of calling `generic_visit` (simplification since there's nothing else to visit) - rename library nodes from `f"{sdfg_name}_computation_{id(node)}"` to `f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"`. This adds a bit more information (because `sdfg_name` is the same for all library nodes) and thus simplifies debugging workflows. Related issue: GEOS-ESM/NDSL#53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. covered by existing test suite - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <[email protected]>
1 parent 28eba10 commit c610561

File tree

5 files changed

+31
-30
lines changed

5 files changed

+31
-30
lines changed

src/gt4py/cartesian/backend/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]:
172172
Returns
173173
-------
174174
Dict[str, str | Dict] of source file names / directories -> contents:
175-
If a key's value is a string it is interpreted as a file name and the value as the
176-
source code of that file
177-
If a key's value is a Dict, it is interpreted as a directory name and it's
175+
If a key's value is a string, it is interpreted as a file name and its value as the
176+
source code of that file.
177+
If a key's value is a Dict, it is interpreted as a directory name and its
178178
value as a nested file hierarchy to which the same rules are applied recursively.
179179
The root path is relative to the build directory.
180180
@@ -222,7 +222,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
222222
223223
Returns
224224
-------
225-
Analog to :py:meth:`generate_computation` but containing bindings source code, The
225+
Analog to :py:meth:`generate_computation` but containing bindings source code. The
226226
dictionary contains a tree of directories with leaves being a mapping from filename to
227227
source code pairs, relative to the build directory.
228228

src/gt4py/cartesian/backend/dace_backend.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir
3333
from gt4py.cartesian.gtc import common, gtir
34+
from gt4py.cartesian.gtc.dace import daceir as dcir
3435
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
3536
from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder
3637
from gt4py.cartesian.gtc.dace.transformations import (
@@ -119,8 +120,6 @@ def _set_expansion_orders(sdfg: dace.SDFG):
119120

120121

121122
def _set_tile_sizes(sdfg: dace.SDFG):
122-
import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import
123-
124123
for node, _ in filter(
125124
lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive()
126125
):

src/gt4py/cartesian/gtc/dace/daceir.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def tile_symbol(self) -> eve.SymbolRef:
5151
return eve.SymbolRef("__tile_" + self.lower())
5252

5353
@staticmethod
54-
def dims_3d() -> Generator["Axis", None, None]:
54+
def dims_3d() -> Generator[Axis, None, None]:
5555
yield from [Axis.I, Axis.J, Axis.K]
5656

5757
@staticmethod
58-
def dims_horizontal() -> Generator["Axis", None, None]:
58+
def dims_horizontal() -> Generator[Axis, None, None]:
5959
yield from [Axis.I, Axis.J]
6060

6161
def to_idx(self) -> int:
@@ -357,7 +357,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]:
357357

358358

359359
class GridSubset(eve.Node):
360-
intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]]
360+
intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]]
361361

362362
def __iter__(self):
363363
for axis in Axis.dims_3d():
@@ -429,10 +429,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent):
429429
@classmethod
430430
def from_interval(
431431
cls,
432-
interval: Union[oir.Interval, TileInterval, DomainInterval, IndexWithExtent],
432+
interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval],
433433
axis: Axis,
434434
):
435-
res_interval: Union[IndexWithExtent, TileInterval, DomainInterval]
435+
res_interval: Union[DomainInterval, IndexWithExtent, TileInterval]
436436
if isinstance(interval, (DomainInterval, oir.Interval)):
437437
res_interval = DomainInterval(
438438
start=AxisBound(
@@ -441,7 +441,7 @@ def from_interval(
441441
end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K),
442442
)
443443
else:
444-
assert isinstance(interval, (TileInterval, IndexWithExtent))
444+
assert isinstance(interval, (IndexWithExtent, TileInterval))
445445
res_interval = interval
446446

447447
return cls(intervals={axis: res_interval})
@@ -464,7 +464,7 @@ def full_domain(cls, axes=None):
464464
return GridSubset(intervals=res_subsets)
465465

466466
def tile(self, tile_sizes: Dict[Axis, int]):
467-
res_intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] = {}
467+
res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {}
468468
for axis, interval in self.intervals.items():
469469
if isinstance(interval, DomainInterval) and axis in tile_sizes:
470470
if axis == Axis.K:
@@ -505,15 +505,15 @@ def union(self, other):
505505
intervals[axis] = interval1.union(interval2)
506506
else:
507507
assert (
508-
isinstance(interval2, (TileInterval, DomainInterval))
509-
and isinstance(interval1, (IndexWithExtent, DomainInterval))
508+
isinstance(interval2, (DomainInterval, TileInterval))
509+
and isinstance(interval1, (DomainInterval, IndexWithExtent))
510510
) or (
511-
isinstance(interval1, (TileInterval, DomainInterval))
511+
isinstance(interval1, (DomainInterval, TileInterval))
512512
and isinstance(interval2, IndexWithExtent)
513513
)
514514
intervals[axis] = (
515515
interval1
516-
if isinstance(interval1, (TileInterval, DomainInterval))
516+
if isinstance(interval1, (DomainInterval, TileInterval))
517517
else interval2
518518
)
519519
return GridSubset(intervals=intervals)
@@ -747,7 +747,7 @@ class IndexAccess(common.FieldAccess, Expr):
747747
offset: Optional[Union[common.CartesianOffset, VariableKOffset]]
748748

749749

750-
class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt):
750+
class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt):
751751
_dtype_validation = common.assign_stmt_dtype_validation(strict=True)
752752

753753

@@ -851,14 +851,14 @@ class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait):
851851
class DomainMap(ComputationNode, IterationNode):
852852
index_ranges: List[Range]
853853
schedule: MapSchedule
854-
computations: List[Union[Tasklet, DomainMap, NestedSDFG]]
854+
computations: List[Union[DomainMap, NestedSDFG, Tasklet]]
855855

856856

857857
class ComputationState(IterationNode):
858-
computations: List[Union[Tasklet, DomainMap]]
858+
computations: List[Union[DomainMap, Tasklet]]
859859

860860

861-
class DomainLoop(IterationNode, ComputationNode):
861+
class DomainLoop(ComputationNode, IterationNode):
862862
axis: Axis
863863
index_range: Range
864864
loop_states: List[Union[ComputationState, DomainLoop]]
@@ -868,7 +868,7 @@ class NestedSDFG(ComputationNode, eve.SymbolTableTrait):
868868
label: eve.Coerced[eve.SymbolRef]
869869
field_decls: List[FieldDecl]
870870
symbol_decls: List[SymbolDecl]
871-
states: List[Union[DomainLoop, ComputationState]]
871+
states: List[Union[ComputationState, DomainLoop]]
872872

873873

874874
# There are circular type references with string placeholders. These statements let datamodels resolve those.

src/gt4py/cartesian/gtc/dace/oir_to_dace.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class SDFGContext:
4141
decls: Dict[str, oir.Decl]
4242
block_extents: Dict[int, Extent]
4343
access_infos: Dict[str, dcir.FieldAccessInfo]
44+
loop_counter: int = 0
4445

4546
def __init__(self, stencil: oir.Stencil):
4647
self.sdfg = dace.SDFG(stencil.name)
@@ -98,14 +99,21 @@ def _make_dace_subset(self, local_access_info, field):
9899
global_access_info, local_access_info, self.decls[field].data_dims
99100
)
100101

102+
def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str:
103+
sdfg_name = ctx.sdfg.name
104+
counter = ctx.loop_counter
105+
ctx.loop_counter += 1
106+
107+
return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"
108+
101109
def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext):
102110
declarations = {
103111
acc.name: ctx.decls[acc.name]
104112
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
105113
if acc.name in ctx.decls
106114
}
107115
library_node = StencilComputation(
108-
name=f"{ctx.sdfg.name}_computation_{id(node)}",
116+
name=self._vloop_name(node, ctx),
109117
extents=ctx.block_extents,
110118
declarations=declarations,
111119
oir_node=node,
@@ -174,6 +182,6 @@ def visit_Stencil(self, node: oir.Stencil):
174182
lifetime=dace.AllocationLifetime.Persistent,
175183
debuginfo=get_dace_debuginfo(decl),
176184
)
177-
self.generic_visit(node, ctx=ctx)
185+
self.visit(node.vertical_loops, ctx=ctx)
178186
ctx.sdfg.validate()
179187
return ctx.sdfg

tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py

-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# SPDX-License-Identifier: BSD-3-Clause
88

99
import numpy as np
10-
import pytest
1110

1211
from gt4py.cartesian import gtscript, testing as gt_testing
1312
from gt4py.cartesian.gtscript import (
@@ -25,7 +24,6 @@
2524
from .stencil_definitions import optional_field, two_optional_fields
2625

2726

28-
# ---- Identity stencil ----
2927
class TestIdentity(gt_testing.StencilTestSuite):
3028
"""Identity stencil."""
3129

@@ -43,7 +41,6 @@ def validation(field_a, domain=None, origin=None):
4341
pass
4442

4543

46-
# ---- Copy stencil ----
4744
class TestCopy(gt_testing.StencilTestSuite):
4845
"""Copy stencil."""
4946

@@ -86,7 +83,6 @@ def validation(field_a, field_b, domain=None, origin=None):
8683
field_b[...] = (field_b[...] - 1.0) / 2.0
8784

8885

89-
# ---- Scale stencil ----
9086
class TestGlobalScale(gt_testing.StencilTestSuite):
9187
"""Scale stencil using a global global_name."""
9288

@@ -108,7 +104,6 @@ def validation(field_a, domain, origin, **kwargs):
108104
field_a[...] = SCALE_FACTOR * field_a # noqa: F821 [undefined-name]
109105

110106

111-
# ---- Parametric scale stencil -----
112107
class TestParametricScale(gt_testing.StencilTestSuite):
113108
"""Scale stencil using a parameter."""
114109

@@ -128,7 +123,6 @@ def validation(field_a, *, scale, domain, origin, **kwargs):
128123
field_a[...] = scale * field_a
129124

130125

131-
# --- Parametric-mix stencil ----
132126
class TestParametricMix(gt_testing.StencilTestSuite):
133127
"""Linear combination of input fields using several parameters."""
134128

0 commit comments

Comments
 (0)