Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: Dace backend: expose control flow #1894

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
7a9b083
WIP: Updated daceir with Condition & WhileLoop
Oct 21, 2024
885537d
WIP: tasklet_codegen typehints & cleanups
Oct 21, 2024
8de07f3
Add CodeBlock to OIR-level
Oct 21, 2024
d09a11e
WIP: Update daceir to avoid unecessary nested SDFG
Oct 21, 2024
79dab2f
sdfg context: add support for condition and while
Oct 21, 2024
c303d36
sdfg builder: add visit_{Condition, WhileLoop}
Oct 21, 2024
c8efbd3
WIP: Fix the obivious issues
Oct 21, 2024
1780f9f
WIP: fix node_ctx issue in nested sdfgs
Oct 21, 2024
24c4a53
WIP: ... aaaand we are back to variable shadowing
Oct 21, 2024
5c86ad4
WIP: we don't (shouldn't) have local scalars
Oct 22, 2024
a06c2d7
Fixed variable shadowing in the newer version
Oct 22, 2024
1ac6fc4
Fix typing issue highlighted by tests
Oct 22, 2024
9883755
WIP: Move condition evaluation tasklet to daceir builder
Oct 22, 2024
759a742
Fix issue with extra map entry nodes
Oct 24, 2024
74c83ae
Fix type issue raised in gt4py tests
Oct 24, 2024
484c23b
No Tasklet without ComputationState in DomainMap
Oct 25, 2024
5844f76
WIP: ignore k offset write book keeping
Oct 28, 2024
33527be
WIP: separate node context for eval Tasklet
Oct 29, 2024
51df134
This seems to fix the if-only k-offset write
Oct 30, 2024
ab5d867
WIP: this fails because of duplicate connectors
Oct 30, 2024
3e8afa1
WIP: working on export / import of local scalars
Oct 30, 2024
5746e31
This seems to fix the column_physics_conditional \o/
Oct 30, 2024
f0b2948
This seems to fix the read after write issue
Oct 31, 2024
feea69f
visit node.data_index for oir.FieldAccess nodes
Nov 1, 2024
12a1247
For testing: let's see if this fixes k offset write
Nov 4, 2024
f399097
Revert "For testing: let's see if this fixes k offset write"
Nov 5, 2024
d902b74
Cleanup that makes pre-commit happy
Nov 5, 2024
a3df2ae
Cleanup: We don't collect targets in Tasklet codegen
Nov 5, 2024
59d8282
Attempt to fix dyn memlets and one of the tests
Nov 27, 2024
cf50065
WIP: before starting to mess with memlet generation
Dec 4, 2024
42e43fb
WIP: comments what to do where for specialized memlets
Dec 4, 2024
3e7aa7d
WIP: First version of targetted memlets (per tasklet)
Jan 6, 2025
bf0c54f
WIP: next version of targetted memlets
Jan 7, 2025
0e5ff06
WIP: bit of cleanup from the last commit
Jan 8, 2025
f49cad7
Add explicit indexing as used in unstable develop
Jan 8, 2025
012bab0
WIP: this seems to fix column-physics-conditional
Jan 8, 2025
58e1f96
Fix: correct subset for horizontal regions
Jan 9, 2025
0aafe74
Temporarily skip `ScalarToSymbolPromotion` in dace
Jan 9, 2025
2f30389
Formatting and linting
Jan 10, 2025
2889a03
WIP: support for control flow inside regions
Jan 10, 2025
0eedbb2
Just move some code
Jan 14, 2025
c389648
Revert: Skipping `ScalarToSymbolPromotion`
Jan 16, 2025
240f8d2
WIP: fix horizontal regions issue and skip InlineThreadLocalTransients
romanc Feb 5, 2025
a785091
Fix dangling connections after expansion
romanc Feb 17, 2025
660c55f
Fix read after write index access in condition tasklet
romanc Feb 19, 2025
e0793fa
Fix typo in comment (added with this PR)
romanc Feb 20, 2025
7783bd5
Read after write: tasklet validator
romanc Feb 24, 2025
3b3115c
Tasklet labels and no local tasklet declarations anymore
romanc Feb 25, 2025
436cd24
Use more specific dcir.VariableKOffset
romanc Mar 3, 2025
3bed806
Re-evaluate `InlineThreadLocalTransients` later
romanc Mar 3, 2025
91cd73d
Undo unrelated changes
romanc Mar 4, 2025
fe97249
Remove now unused indirection
romanc Mar 4, 2025
4497574
DaceIR cleanups
romanc Mar 4, 2025
4800c0d
Variable for Tasklet in/out prefixes
romanc Mar 4, 2025
6d82922
Remove debug print statements
romanc Mar 4, 2025
ec5e041
More cleanup
romanc Mar 4, 2025
3d6f5d9
New-style type annotations
romanc Mar 4, 2025
50f1a80
WIP: cleanups in daceir_builder
romanc Mar 4, 2025
c0f54a9
Break circular import loop
romanc Mar 4, 2025
70109bd
More cleanups in daceir_builder
romanc Mar 4, 2025
2132258
Cleanups in SDFG builder / tasklet codegen
romanc Mar 4, 2025
3d8e717
Fix the newly added assert statement
romanc Mar 4, 2025
eaf9b5c
factor out _explicit_indexing function
romanc Mar 5, 2025
6e3dd9c
Tests for get_tasklet_symbol()
romanc Mar 5, 2025
34b82c2
Basic test cases for unexpanded SDFG
romanc Mar 6, 2025
636eb47
Clarify node_ctx dropping
romanc Mar 7, 2025
30b7eb1
Better comments
romanc Mar 7, 2025
3806720
Cleanup test imports and dace markers
romanc Mar 7, 2025
1d216f7
WIP: More DaCe backend tests
romanc Mar 7, 2025
d25ccca
Add tests for daceir and sdfg builder
romanc Mar 10, 2025
98296e6
reorganize all the prefix strings
romanc Mar 11, 2025
1c6721e
Review: use elif to reduce indent level
romanc Mar 12, 2025
83cb69d
Review: rename Condition.{true,false}_states
romanc Mar 12, 2025
679184c
Review: leverage eve's built-in filter function
romanc Mar 13, 2025
3189955
Review: removed wrong return type, added doc string
romanc Mar 13, 2025
d13799a
Review: make _global_grid_subset() a regular class method
romanc Mar 13, 2025
2e48c20
Review: simplify code
romanc Mar 13, 2025
9a39394
Review: remove extra space in test
romanc Mar 13, 2025
bc9cd41
Review: add comment to clarify usage of is_write
romanc Mar 13, 2025
8534c7c
Review: Move comment to docstring
romanc Mar 13, 2025
69b7a36
Review: list() -> []
romanc Mar 13, 2025
617290d
Review: remove intermediate list
romanc Mar 13, 2025
5ed4568
Review: simplify `defined_symbol` in sdfg_builder
romanc Mar 13, 2025
d3a1211
Review: pull names out of mapped_access_iterator
romanc Mar 13, 2025
e8b229f
Revert "Review: pull names out of mapped_access_iterator"
romanc Mar 13, 2025
c928df6
Review: pull names out of mapped_access_iterator (2)
romanc Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder
from gt4py.cartesian.gtc.dace.transformations import (
InlineThreadLocalTransients,
NoEmptyEdgeTrivialMapElimination,
nest_sequential_map_scopes,
)
Expand Down Expand Up @@ -173,7 +172,8 @@ def _post_expand_transformations(sdfg: dace.SDFG):
if node.schedule == dace.ScheduleType.CPU_Multicore and len(node.range) <= 1:
node.schedule = dace.ScheduleType.Sequential

sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False)
# To be re-evaluated with https://github.com/GridTools/gt4py/issues/1896
# sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) # noqa: ERA001
sdfg.simplify(validate=False)
nest_sequential_map_scopes(sdfg)
for sd in sdfg.all_sdfgs_recursive():
Expand Down
140 changes: 128 additions & 12 deletions src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.common import LocNode
from gt4py.cartesian.gtc.dace import prefix
from gt4py.cartesian.gtc.dace.symbol_utils import (
get_axis_bound_dace_symbol,
get_axis_bound_diff_str,
Expand Down Expand Up @@ -525,10 +526,6 @@ class FieldAccessInfo(eve.Node):
dynamic_access: bool = False
variable_offset_axes: List[Axis] = eve.field(default_factory=list)

@property
def is_dynamic(self) -> bool:
return self.dynamic_access or len(self.variable_offset_axes) > 0

def axes(self):
yield from self.grid_subset.axes()

Expand Down Expand Up @@ -713,7 +710,7 @@ def axes(self):

@property
def is_dynamic(self) -> bool:
return self.access_info.is_dynamic
return self.access_info.dynamic_access

def with_set_access_info(self, access_info: FieldAccessInfo) -> FieldDecl:
return FieldDecl(
Expand All @@ -730,7 +727,8 @@ class Literal(common.Literal, Expr):


class ScalarAccess(common.ScalarAccess, Expr):
pass
is_target: bool
original_name: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this optional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use original_name as part of handling "local"1 temporaries that are potentially written in one Tasklet and later read in another one. Not every scalar access is a local temporary. For example, scalar stencil parameters "globally" (as in through the stencil) available and don't need/have an original_name.

Footnotes

  1. "local" as in local within the horizontal execution, which might be split (with this PR) in more than one Tasklet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope for the PR: this kinds of plead for more information on the access. Especially how and where they apply. Of course IR are suppose to be atomic information, but I wonder if a type defining the context of the access: temporary, parameter...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this should be revisited in a follow-up issue. It's more of a duck-tape solution than nice engineering.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Log it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a note (linking here) in #1898.



class VariableKOffset(common.VariableKOffset[Expr]):
Expand All @@ -744,7 +742,12 @@ def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Exp


class IndexAccess(common.FieldAccess, Expr):
offset: Optional[Union[common.CartesianOffset, VariableKOffset]]
# ScalarAccess used for indirect addressing
offset: Optional[common.CartesianOffset | Literal | ScalarAccess | VariableKOffset]
is_target: bool

explicit_indices: Optional[list[Literal | ScalarAccess | VariableKOffset]] = None
"""Used to access as a full field with explicit indices"""


class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt):
Expand Down Expand Up @@ -842,33 +845,146 @@ class IterationNode(eve.Node):
grid_subset: GridSubset


class Condition(eve.Node):
condition: Tasklet
true_states: list[ComputationState | Condition | WhileLoop]

# Currently unused due to how `if` statements are parsed in `gtir_to_oir`, see
# https://github.com/GridTools/gt4py/issues/1898
false_states: list[ComputationState | Condition | WhileLoop] = eve.field(default_factory=list)

@datamodels.validator("condition")
def condition_has_boolean_expression(
self, attribute: datamodels.Attribute, tasklet: Tasklet
) -> None:
assert isinstance(tasklet, Tasklet)
assert len(tasklet.stmts) == 1
assert isinstance(tasklet.stmts[0], AssignStmt)
assert isinstance(tasklet.stmts[0].left, ScalarAccess)
if tasklet.stmts[0].left.original_name is None:
raise ValueError(
f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error."
)
assert isinstance(tasklet.stmts[0].right, Expr)
if tasklet.stmts[0].right.dtype != common.DataType.BOOL:
raise ValueError("Condition must be a boolean expression.")


class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait):
decls: List[LocalScalarDecl]
label: str
stmts: List[Stmt]
grid_subset: GridSubset = GridSubset.single_gridpoint()

@datamodels.validator("stmts")
def non_empty_list(self, attribute: datamodels.Attribute, v: list[Stmt]) -> None:
if len(v) < 1:
raise ValueError("Tasklet must contain at least one statement.")

@datamodels.validator("stmts")
def read_after_write(self, attribute: datamodels.Attribute, statements: list[Stmt]) -> None:
def _remove_prefix(name: eve.SymbolRef) -> str:
return name.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN)

class ReadAfterWriteChecker(eve.NodeVisitor):
def visit_IndexAccess(self, node: IndexAccess, writes: set[str]) -> None:
if node.is_target:
# Keep track of writes
writes.add(_remove_prefix(node.name))
return

# Check reads
if (
node.name.startswith(prefix.TASKLET_OUT)
and _remove_prefix(node.name) not in writes
):
raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.")

if _remove_prefix(node.name) in writes and not node.name.startswith(
prefix.TASKLET_OUT
):
raise ValueError(
f"Read after write of '{node.name}' not connected to out connector. DaCe IR error."
)

def visit_ScalarAccess(self, node: ScalarAccess, writes: set[str]) -> None:
# Handle stencil parameters differently because they are always available
if not node.name.startswith(prefix.TASKLET_IN) and not node.name.startswith(
prefix.TASKLET_OUT
):
return

# Keep track of writes
if node.is_target:
writes.add(_remove_prefix(node.name))
return

# Make sure we don't read uninitialized memory
if (
node.name.startswith(prefix.TASKLET_OUT)
and _remove_prefix(node.name) not in writes
):
raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.")

if _remove_prefix(node.name) in writes and not node.name.startswith(
prefix.TASKLET_OUT
):
raise ValueError(
f"Read after write of '{node.name}' not connected to out connector. DaCe IR error."
)

def visit_AssignStmt(self, node: AssignStmt, writes: Set[eve.SymbolRef]) -> None:
# Visiting order matters because `writes` must not contain the symbols from the left visit
self.visit(node.right, writes=writes)
self.visit(node.left, writes=writes)

writes: set[str] = set()
checker = ReadAfterWriteChecker()
for statement in statements:
checker.visit(statement, writes=writes)


class DomainMap(ComputationNode, IterationNode):
index_ranges: List[Range]
schedule: MapSchedule
computations: List[Union[DomainMap, NestedSDFG, Tasklet]]
computations: List[Union[Tasklet, DomainMap, NestedSDFG]]


class ComputationState(IterationNode):
computations: List[Union[DomainMap, Tasklet]]
computations: List[Union[Tasklet, DomainMap]]


class DomainLoop(ComputationNode, IterationNode):
axis: Axis
index_range: Range
loop_states: List[Union[ComputationState, DomainLoop]]
loop_states: list[ComputationState | Condition | DomainLoop | WhileLoop]


class WhileLoop(eve.Node):
condition: Tasklet
body: list[ComputationState | Condition | WhileLoop]

@datamodels.validator("condition")
def condition_has_boolean_expression(
self, attribute: datamodels.Attribute, tasklet: Tasklet
) -> None:
assert isinstance(tasklet, Tasklet)
assert len(tasklet.stmts) == 1
assert isinstance(tasklet.stmts[0], AssignStmt)
assert isinstance(tasklet.stmts[0].left, ScalarAccess)
if tasklet.stmts[0].left.original_name is None:
raise ValueError(
f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error."
)
assert isinstance(tasklet.stmts[0].right, Expr)
if tasklet.stmts[0].right.dtype != common.DataType.BOOL:
raise ValueError("Condition must be a boolean expression.")


class NestedSDFG(ComputationNode, eve.SymbolTableTrait):
label: eve.Coerced[eve.SymbolRef]
field_decls: List[FieldDecl]
symbol_decls: List[SymbolDecl]
states: List[Union[ComputationState, DomainLoop]]
states: list[ComputationState | Condition | DomainLoop | WhileLoop]


# There are circular type references with string placeholders. These statements let datamodels resolve those.
Expand Down
Loading