-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor[cartesian]: Dace backend: expose control flow #1894
refactor[cartesian]: Dace backend: expose control flow #1894
Conversation
31f1890
to
947c1a1
Compare
## Description In preparation for PR #1894, pull out some refactors and cleanups. Notable in this PR are the changes to `src/gt4py/cartesian/gtc/dace/oir_to_dace.py` - visit `stencil.vertical_loops` directly instead of calling `generic_visit` (simplification since there's nothing else to visit) - rename library nodes from `f"{sdfg_name}_computation_{id(node)}"` to `f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"`. This adds a bit more information (because `sdfg_name` is the same for all library nodes) and thus simplifies debugging workflows. Related issue: GEOS-ESM/NDSL#53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. covered by existing test suite - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
We don't need to replace visit_MaskStmt and visit_While because we won't have `if` statements or `while` loops in tasklet code anymore. These constructs are represented directly in DaCe (allowing DaCe to do more of its magic).
This will be used in the gt4py/dace bridge only (for now). We'll have to move this concept up and have all the backends work with it.
Add support for adding conditions and while loops to the `SDFGContext`, a helper class managing state transitions when building the SDFG from the dace IR.
Add support for generating sdfg nodes from dcir.Condition and dcir.WhileLoop nodes.
Some of them might be fixed "the wrong way" ...
we should fix this at another level in the future
This seems to fix the variable shadowing problem in the newer version (without all the nexted sdfgs). This also seems to add support for variables declared inside if/else statements.
This still seems to fail consistently for while loops because they try to connect to the (duplicated) entry map. No clue where this is comming from (again now ...).
While loops can contain other while loops. This was forgotten and tests complained about it.
not needed anymore since we now properly separate read and write memlets not working either. needs fixes to how we setup the "conditional evaluation tasklet", i.e. we don't go through the "normal" ComputeState transformations and thus lack the updated node-ctx information. This is why it tries to connect to outer maps.
This still ends in a segfault - for whatever reason ... This seems to work now ... let's just see in ndsl.
if input and output connectors are called the same, e.g. in code like ```python level = 1 condition = True while condition: condition = level < 10 level = level + 1 ``` which translates to ```none Tasklet1: level = 1 condtion = True While condition: Tasklet2: condition = level < 10 level = level + 1 ``` then the sdfg complains because DaCe can't handle reading level and writing to it in the same tasklet. This creates and input connector for level and an output connector for level, which is ambiguous.
the ideas this time is to have different connector names and the "same name" outside in the things that are passed around.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - will require a re-read with the documentation you are writing for a go-ahead
@@ -730,7 +727,8 @@ class Literal(common.Literal, Expr): | |||
|
|||
|
|||
class ScalarAccess(common.ScalarAccess, Expr): | |||
pass | |||
is_target: bool | |||
original_name: Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of scope for the PR: this kinds of plead for more information on the access. Especially how and where they apply. Of course IR are suppose to be atomic information, but I wonder if a type
defining the context of the access: temporary, parameter...
- rename constants -> prefix - add dace's passthrough prefixes - update import style
Two things for reviewers
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some general comment or suggestions.
But I have some suggestions regarding DaCe, especially the usage of newer API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My 2 cents review.
assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) | ||
condition_name = node.condition.stmts[0].left.original_name | ||
|
||
after_state = self.sdfg.add_state("while_after") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just an idea, but please do as you prefer. Already on dace 1.0 it is possible to instantiate LoopRegion
constructs, although it is not possible to apply SDFG transformations nor codegen from it (you need dace from main branch, for that). However, after the SDFG is built, it is possible to call dace.sdfg.utils.inline_loop_blocks(sdfg)
that will turn the LoopRegion
nodes into the equivalent state machines. In this way, you can prepare the SDFG for the upgrade to dace main.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, Eduardo. That is valuable feedback. I left a note in #1898 such that we don't forget it in the future.
inside_horizontal_region: bool = False, | ||
**kwargs: Any, | ||
) -> dcir.MaskStmt | dcir.Condition: | ||
if inside_horizontal_region: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a question, not a comment. I do not know this code base, I've only worked on gt4py-next. I've learned from my mistakes that lowering if-statements inside tasklets does not ensure exclusive branch execution: all inputs of a tasklet are evaluated, no matter the value of the if-condition inside the tasklet node. Randomly, this can result in correct code or not.
I guess the gt4py user code mostly uses dynamic K-offsets, which limits the vertical domain of intermediate fields. On the horizontal region, fields are always defined on the full domain. Is this the motivation behind treating the vertical region differently?
If what I wrote so far makes sense, my question is the following. Is it theoretically possible to write cartesian programs that use horizontal dynamic offsets, and would also require dcir.Condition
in the horizontal region?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've learned from my mistakes that lowering if-statements inside tasklets does not ensure exclusive branch execution: all inputs of a tasklet are evaluated, no matter the value of the if-condition inside the tasklet node. Randomly, this can result in correct code or not.
I don't quite follow. Let's assume we have the following stencil
def small_conditional(in_field: FloatField, out_field: FloatField):
with computation(PARALLEL), interval(...):
if in_field < 4:
tmp = in_field
else:
tmp = in_field + 1
Are you saying both branches, if
and else
, are evaluated? That is not what we are observing and looking at the generated code, this translates to real if
statements in cpp code that work as I would expect them to. Can you elaborate on this?
Background: Before this PR, gt4py-cartesian would would have all conditionals inside Tasklets. I'm surprised to learn that this might pose a problem.
I guess the gt4py user code mostly uses dynamic K-offsets, which limits the vertical domain of intermediate fields. On the horizontal region, fields are always defined on the full domain. Is this the motivation behind treating the vertical region differently?
@FlorianDeconinck / @twicki can you make sense of that? I'm just reading words here without understanding ...
If what I wrote so far makes sense, my question is the following. Is it theoretically possible to write cartesian programs that use horizontal dynamic offsets, and would also require
dcir.Condition
in the horizontal region?
What I can say is that it is not possible to have dcir.Condition
s (or dcir.WhileLoop
s) inside horizontal regions. There might be dcir.MaskStmt
s (or dcir.While
loops) inside horizontal regions. In that case, we won't translate them to dcir.Condition
s / dcir.WhileLoop
s and codegen if
/while
statements inside the Tasklet (without exposing them to DaCe) just as we'd do before this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, only one of the if-banchesis is executed. What I meant is the following, consider:
def small_conditional(in_field1: FloatField, in_field2: FloatField, k: IntField):
with computation(PARALLEL), interval(...):
if k > 10:
tmp = in_field1
else:
tmp = in_field2
Both inputs to the tasklet (in_field1
and in_field2
) have to be defined for all k
values [0:N]
. That because the two dataflows that compute in_field1
and in_field2
are executed before evaluating the if-expression.
However, this case maybe is handled differently in cartesian.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
K-domain validation will indeed take care of that.
Thanks for all the feedback. I have addressed and/or answered all review comments. Please have a second look if you have strong opinions. @FlorianDeconinck you wanted to do a re-read once the docs are done. The docs are here. |
@edopao / @philip-paul-mueller : Thank you for engaging here despite this part of the code being further away than your current focus. Very much appreciated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - one ticket to log
@@ -730,7 +727,8 @@ class Literal(common.Literal, Expr): | |||
|
|||
|
|||
class ScalarAccess(common.ScalarAccess, Expr): | |||
pass | |||
is_target: bool | |||
original_name: Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Log it
Description
This PR refactors the GT4Py/DaCe bridge to expose control flow elements (
if
statements andwhile
loops) to DaCe. Previously, the whole contents of a vertical loop was put in one big Tasklet. With this PR, that Tasklet is broken apart in case control flow is found such that control flow is visible in the SDFG. This allows DaCe to better analyze code and will be crucial in future (within the current milestone) performance optimization work.The main ideas in this PR are the following
oir.CodeBlock
to recursively break downoir.HorizontalExecution
s into smaller pieces that are either code flow or evaluated in (smaller) Tasklets.dcir.Condition
anddcir.WhileLoop
to represent if statements and while loops that are translated into SDFG states. We keep the currentdcir.MaskStmt
/dcir.While
for if statements / while loops inside horizontal regions, which aren't yet exposed to DaCe (see cartesian: expose HorizontalRegions to DaCe #1900).if
statements andwhile
loops in the state machine ofsdfg_builder.py
TaskletAccessInfoCollector
does this work for us, duplicating some logic inAccessInfoCollector
. A refactor task has been logged to fix/re-evaluate this later.This PR depends on the following (downstream) DaCe fixes
StateFusion
misses read-write conflict due to early return spcl/dace#1954which have been merged by now.
Follow-up issues
InlineThreadLocalTransients
in gt4py/dace bridge #1896Related issue: GEOS-ESM/NDSL#53
Requirements
Added new tests and increased coverage of horizontal regions with PRs test[cartesian]: Increased coverage for horizontal regions #1807 and tests[cartesian]: Increase horizontal region test coverage #1851.
Docs are in our knowledge base for now. Will be ported.