Skip to content

Commit 77cad7c

Browse files
feat[dace][next]: Fixing strides in optimization (#1782)
Added functionality to properly handle changes of strides. During the implementation of the scan we found that the strides were not handled properly. Most importantly, a change on one level was not propagated into the next levels, i.e. they were still using the old strides. This PR solves most of the problems, but some issues remain unsolved: - Views are not adjusted yet (fixed in [PR@1784](#1784)). - It is not properly checked whether the symbols of the propagated strides are safe to introduce into the nested SDFG. The initial functionality of this PR was done by Edoardo Paone (@edopao). --------- Co-authored-by: edopao <[email protected]>
1 parent 06b398a commit 77cad7c

File tree

6 files changed

+1238
-26
lines changed

6 files changed

+1238
-26
lines changed

src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@
3535
gt_simplify,
3636
gt_substitute_compiletime_symbols,
3737
)
38-
from .strides import gt_change_transient_strides
38+
from .strides import (
39+
gt_change_transient_strides,
40+
gt_map_strides_to_dst_nested_sdfg,
41+
gt_map_strides_to_src_nested_sdfg,
42+
gt_propagate_strides_from_access_node,
43+
gt_propagate_strides_of,
44+
)
3945
from .util import gt_find_constant_arguments, gt_make_transients_persistent
4046

4147

@@ -59,6 +65,10 @@
5965
"gt_gpu_transformation",
6066
"gt_inline_nested_sdfg",
6167
"gt_make_transients_persistent",
68+
"gt_map_strides_to_dst_nested_sdfg",
69+
"gt_map_strides_to_src_nested_sdfg",
70+
"gt_propagate_strides_from_access_node",
71+
"gt_propagate_strides_of",
6272
"gt_reduce_distributed_buffering",
6373
"gt_set_gpu_blocksize",
6474
"gt_set_iteration_order",

src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def gt_gpu_transformation(
9595

9696
if try_removing_trivial_maps:
9797
# In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on
98-
# GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So
98+
# GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So
9999
# we might end up with lots of these trivial Maps, each requiring a separate
100100
# kernel launch. To prevent this we will combine these trivial maps, if
101101
# possible, with their downstream maps.

src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def _perform_pointwise_test(
950950

951951
def apply(
952952
self,
953-
graph: dace.SDFGState | dace.SDFG,
953+
graph: dace.SDFGState,
954954
sdfg: dace.SDFG,
955955
) -> None:
956956
# Removal
@@ -971,6 +971,9 @@ def apply(
971971
tmp_out_subset = dace_subsets.Range.from_array(tmp_desc)
972972
assert glob_in_subset is not None
973973

974+
# Recursively visit the nested SDFGs for mapping of strides from inner to outer array
975+
gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac)
976+
974977
# We now remove the `tmp` node, and create a new connection between
975978
# the global node and the map exit.
976979
new_map_to_glob_edge = graph.add_edge(

0 commit comments

Comments
 (0)