Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: Dace backend: expose control flow #1894

Merged

Conversation

romanc
Copy link
Contributor

@romanc romanc commented Mar 3, 2025

Description

This PR refactors the GT4Py/DaCe bridge to expose control flow elements (if statements and while loops) to DaCe. Previously, the whole contents of a vertical loop was put in one big Tasklet. With this PR, that Tasklet is broken apart in case control flow is found such that control flow is visible in the SDFG. This allows DaCe to better analyze code and will be crucial in future (within the current milestone) performance optimization work.

The main ideas in this PR are the following

  1. Introduce oir.CodeBlock to recursively break down oir.HorizontalExecutions into smaller pieces that are either code flow or evaluated in (smaller) Tasklets.
  2. Introduce dcir.Conditionand dcir.WhileLoop to represent if statements and while loops that are translated into SDFG states. We keep the current dcir.MaskStmt / dcir.While for if statements / while loops inside horizontal regions, which aren't yet exposed to DaCe (see cartesian: expose HorizontalRegions to DaCe #1900).
  3. Add support for if statements and while loops in the state machine of sdfg_builder.py
  4. We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). We thus create output connectors for all scalar writes in a Tasklet and input connectors for all reads (unless previously written in the same Tasklet).
  5. Memlets can't be generated per horizontal execution anymore and need to be more fine grained. TaskletAccessInfoCollector does this work for us, duplicating some logic in AccessInfoCollector. A refactor task has been logged to fix/re-evaluate this later.

This PR depends on the following (downstream) DaCe fixes

which have been merged by now.

Follow-up issues

Related issue: GEOS-ESM/NDSL#53

Requirements

Sorry, something went wrong.

@romanc romanc force-pushed the romanc/bridge-explicit-indexing-with-linting branch from 31f1890 to 947c1a1 Compare March 4, 2025 07:55
romanc added a commit that referenced this pull request Mar 4, 2025
## Description

In preparation for PR #1894, pull
out some refactors and cleanups. Notable in this PR are the changes to
`src/gt4py/cartesian/gtc/dace/oir_to_dace.py`

- visit `stencil.vertical_loops` directly instead of calling
`generic_visit` (simplification since there's nothing else to visit)
- rename library nodes from `f"{sdfg_name}_computation_{id(node)}"` to
`f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"`. This adds
a bit more information (because `sdfg_name` is the same for all library
nodes) and thus simplifies debugging workflows.

Related issue: GEOS-ESM/NDSL#53

## Requirements

- [x] All fixes and/or new features come with corresponding tests.
  covered by existing test suite
- [ ] Important design decisions have been documented in the appropriate
ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md)
folder.
  N/A

---------

Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
Roman Cattaneo added 25 commits March 4, 2025 17:22
We don't need to replace visit_MaskStmt and visit_While because we won't
have `if` statements or `while` loops in tasklet code anymore. These
constructs are represented directly in DaCe (allowing DaCe to do more of
its magic).
This will be used in the gt4py/dace bridge only (for now). We'll have to
move this concept up and have all the backends work with it.
Add support for adding conditions and while loops to the `SDFGContext`,
a helper class managing state transitions when building the SDFG from
the dace IR.
Add support for generating sdfg nodes from dcir.Condition and
dcir.WhileLoop nodes.
Some of them might be fixed "the wrong way" ...
we should fix this at another level in the future
This seems to fix the variable shadowing problem in the newer version
(without all the nexted sdfgs). This also seems to add support for
variables declared inside if/else statements.
This still seems to fail consistently for while loops because they try
to connect to the (duplicated) entry map. No clue where this is comming
from (again now ...).
While loops can contain other while loops. This was forgotten and tests
complained about it.
not needed anymore since we now properly separate read and write memlets

not working either. needs fixes to how we setup the "conditional
evaluation tasklet", i.e. we don't go through the "normal" ComputeState
transformations and thus lack the updated node-ctx information. This is
why it tries to connect to outer maps.
This still ends in a segfault - for whatever reason ...

This seems to work now ... let's just see in ndsl.
if input and output connectors are called the same, e.g. in code like

```python
level = 1
condition = True
while condition:
	condition = level < 10
	level = level + 1
```

which translates to

```none
Tasklet1:
	level = 1
	condtion = True

While condition:
	Tasklet2:
		condition = level < 10
		level = level + 1
```

then the sdfg complains because DaCe can't handle reading level and
writing to it in the same tasklet. This creates and input connector for
level and an output connector for level, which is ambiguous.
the ideas this time is to have different connector names and the "same
name" outside in the things that are passed around.
@romanc romanc requested a review from FlorianDeconinck March 10, 2025 10:23
Copy link
Contributor

@FlorianDeconinck FlorianDeconinck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - will require a re-read with the documentation you are writing for a go-ahead

@@ -730,7 +727,8 @@ class Literal(common.Literal, Expr):


class ScalarAccess(common.ScalarAccess, Expr):
pass
is_target: bool
original_name: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope for the PR: this kinds of plead for more information on the access. Especially how and where they apply. Of course IR are suppose to be atomic information, but I wonder if a type defining the context of the access: temporary, parameter...

- rename constants -> prefix
- add dace's passthrough prefixes
- update import style
@romanc
Copy link
Contributor Author

romanc commented Mar 11, 2025

Two things for reviewers

  1. There's technical docs available in https://geos-esm.github.io/SMT-Nebulae/technical/backend/dace-bridge/. Feedback welcome if you have any.
  2. I pushed a small PR today to re-organize all the prefix strings that we use in the bridge.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some general comment or suggestions.
But I have some suggestions regarding DaCe, especially the usage of newer API.

Copy link
Contributor

@edopao edopao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2 cents review.

assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess)
condition_name = node.condition.stmts[0].left.original_name

after_state = self.sdfg.add_state("while_after")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just an idea, but please do as you prefer. Already on dace 1.0 it is possible to instantiate LoopRegion constructs, although it is not possible to apply SDFG transformations nor codegen from it (you need dace from main branch, for that). However, after the SDFG is built, it is possible to call dace.sdfg.utils.inline_loop_blocks(sdfg) that will turn the LoopRegion nodes into the equivalent state machines. In this way, you can prepare the SDFG for the upgrade to dace main.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Eduardo. That is valuable feedback. I left a note in #1898 such that we don't forget it in the future.

inside_horizontal_region: bool = False,
**kwargs: Any,
) -> dcir.MaskStmt | dcir.Condition:
if inside_horizontal_region:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a question, not a comment. I do not know this code base, I've only worked on gt4py-next. I've learned from my mistakes that lowering if-statements inside tasklets does not ensure exclusive branch execution: all inputs of a tasklet are evaluated, no matter the value of the if-condition inside the tasklet node. Randomly, this can result in correct code or not.
I guess the gt4py user code mostly uses dynamic K-offsets, which limits the vertical domain of intermediate fields. On the horizontal region, fields are always defined on the full domain. Is this the motivation behind treating the vertical region differently?
If what I wrote so far makes sense, my question is the following. Is it theoretically possible to write cartesian programs that use horizontal dynamic offsets, and would also require dcir.Condition in the horizontal region?

Copy link
Contributor Author

@romanc romanc Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've learned from my mistakes that lowering if-statements inside tasklets does not ensure exclusive branch execution: all inputs of a tasklet are evaluated, no matter the value of the if-condition inside the tasklet node. Randomly, this can result in correct code or not.

I don't quite follow. Let's assume we have the following stencil

def small_conditional(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        if in_field < 4:
            tmp = in_field
        else:
            tmp = in_field + 1

Are you saying both branches, if and else, are evaluated? That is not what we are observing and looking at the generated code, this translates to real if statements in cpp code that work as I would expect them to. Can you elaborate on this?

Background: Before this PR, gt4py-cartesian would would have all conditionals inside Tasklets. I'm surprised to learn that this might pose a problem.

I guess the gt4py user code mostly uses dynamic K-offsets, which limits the vertical domain of intermediate fields. On the horizontal region, fields are always defined on the full domain. Is this the motivation behind treating the vertical region differently?

@FlorianDeconinck / @twicki can you make sense of that? I'm just reading words here without understanding ...

If what I wrote so far makes sense, my question is the following. Is it theoretically possible to write cartesian programs that use horizontal dynamic offsets, and would also require dcir.Condition in the horizontal region?

What I can say is that it is not possible to have dcir.Conditions (or dcir.WhileLoops) inside horizontal regions. There might be dcir.MaskStmts (or dcir.While loops) inside horizontal regions. In that case, we won't translate them to dcir.Conditions / dcir.WhileLoops and codegen if/while statements inside the Tasklet (without exposing them to DaCe) just as we'd do before this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, only one of the if-banchesis is executed. What I meant is the following, consider:

def small_conditional(in_field1: FloatField, in_field2: FloatField, k: IntField):
    with computation(PARALLEL), interval(...):
        if k > 10:
            tmp = in_field1
        else:
            tmp = in_field2

Both inputs to the tasklet (in_field1 and in_field2) have to be defined for all k values [0:N]. That because the two dataflows that compute in_field1 and in_field2 are executed before evaluating the if-expression.

However, this case maybe is handled differently in cartesian.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

K-domain validation will indeed take care of that.

@romanc
Copy link
Contributor Author

romanc commented Mar 13, 2025

Thanks for all the feedback. I have addressed and/or answered all review comments. Please have a second look if you have strong opinions.

@FlorianDeconinck you wanted to do a re-read once the docs are done. The docs are here.

@FlorianDeconinck
Copy link
Contributor

@edopao / @philip-paul-mueller : Thank you for engaging here despite this part of the code being further away than your current focus. Very much appreciated.

Copy link
Contributor

@FlorianDeconinck FlorianDeconinck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - one ticket to log

@@ -730,7 +727,8 @@ class Literal(common.Literal, Expr):


class ScalarAccess(common.ScalarAccess, Expr):
pass
is_target: bool
original_name: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Log it

@romanc romanc merged commit e6b9398 into GridTools:main Mar 18, 2025
25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
gt4py.cartesian Issues concerning the current version with support only for cartesian grids.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants