@@ -51,11 +51,11 @@ def tile_symbol(self) -> eve.SymbolRef:
51
51
return eve .SymbolRef ("__tile_" + self .lower ())
52
52
53
53
@staticmethod
54
- def dims_3d () -> Generator [" Axis" , None , None ]:
54
+ def dims_3d () -> Generator [Axis , None , None ]:
55
55
yield from [Axis .I , Axis .J , Axis .K ]
56
56
57
57
@staticmethod
58
- def dims_horizontal () -> Generator [" Axis" , None , None ]:
58
+ def dims_horizontal () -> Generator [Axis , None , None ]:
59
59
yield from [Axis .I , Axis .J ]
60
60
61
61
def to_idx (self ) -> int :
@@ -357,7 +357,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]:
357
357
358
358
359
359
class GridSubset (eve .Node ):
360
- intervals : Dict [Axis , Union [DomainInterval , TileInterval , IndexWithExtent ]]
360
+ intervals : Dict [Axis , Union [DomainInterval , IndexWithExtent , TileInterval ]]
361
361
362
362
def __iter__ (self ):
363
363
for axis in Axis .dims_3d ():
@@ -429,10 +429,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent):
429
429
@classmethod
430
430
def from_interval (
431
431
cls ,
432
- interval : Union [oir . Interval , TileInterval , DomainInterval , IndexWithExtent ],
432
+ interval : Union [DomainInterval , IndexWithExtent , oir . Interval , TileInterval ],
433
433
axis : Axis ,
434
434
):
435
- res_interval : Union [IndexWithExtent , TileInterval , DomainInterval ]
435
+ res_interval : Union [DomainInterval , IndexWithExtent , TileInterval ]
436
436
if isinstance (interval , (DomainInterval , oir .Interval )):
437
437
res_interval = DomainInterval (
438
438
start = AxisBound (
@@ -441,7 +441,7 @@ def from_interval(
441
441
end = AxisBound (level = interval .end .level , offset = interval .end .offset , axis = Axis .K ),
442
442
)
443
443
else :
444
- assert isinstance (interval , (TileInterval , IndexWithExtent ))
444
+ assert isinstance (interval , (IndexWithExtent , TileInterval ))
445
445
res_interval = interval
446
446
447
447
return cls (intervals = {axis : res_interval })
@@ -464,7 +464,7 @@ def full_domain(cls, axes=None):
464
464
return GridSubset (intervals = res_subsets )
465
465
466
466
def tile (self , tile_sizes : Dict [Axis , int ]):
467
- res_intervals : Dict [Axis , Union [DomainInterval , TileInterval , IndexWithExtent ]] = {}
467
+ res_intervals : Dict [Axis , Union [DomainInterval , IndexWithExtent , TileInterval ]] = {}
468
468
for axis , interval in self .intervals .items ():
469
469
if isinstance (interval , DomainInterval ) and axis in tile_sizes :
470
470
if axis == Axis .K :
@@ -505,15 +505,15 @@ def union(self, other):
505
505
intervals [axis ] = interval1 .union (interval2 )
506
506
else :
507
507
assert (
508
- isinstance (interval2 , (TileInterval , DomainInterval ))
509
- and isinstance (interval1 , (IndexWithExtent , DomainInterval ))
508
+ isinstance (interval2 , (DomainInterval , TileInterval ))
509
+ and isinstance (interval1 , (DomainInterval , IndexWithExtent ))
510
510
) or (
511
- isinstance (interval1 , (TileInterval , DomainInterval ))
511
+ isinstance (interval1 , (DomainInterval , TileInterval ))
512
512
and isinstance (interval2 , IndexWithExtent )
513
513
)
514
514
intervals [axis ] = (
515
515
interval1
516
- if isinstance (interval1 , (TileInterval , DomainInterval ))
516
+ if isinstance (interval1 , (DomainInterval , TileInterval ))
517
517
else interval2
518
518
)
519
519
return GridSubset (intervals = intervals )
@@ -747,7 +747,7 @@ class IndexAccess(common.FieldAccess, Expr):
747
747
offset : Optional [Union [common .CartesianOffset , VariableKOffset ]]
748
748
749
749
750
- class AssignStmt (common .AssignStmt [Union [ScalarAccess , IndexAccess ], Expr ], Stmt ):
750
+ class AssignStmt (common .AssignStmt [Union [IndexAccess , ScalarAccess ], Expr ], Stmt ):
751
751
_dtype_validation = common .assign_stmt_dtype_validation (strict = True )
752
752
753
753
@@ -851,14 +851,14 @@ class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait):
851
851
class DomainMap (ComputationNode , IterationNode ):
852
852
index_ranges : List [Range ]
853
853
schedule : MapSchedule
854
- computations : List [Union [Tasklet , DomainMap , NestedSDFG ]]
854
+ computations : List [Union [DomainMap , NestedSDFG , Tasklet ]]
855
855
856
856
857
857
class ComputationState (IterationNode ):
858
- computations : List [Union [Tasklet , DomainMap ]]
858
+ computations : List [Union [DomainMap , Tasklet ]]
859
859
860
860
861
- class DomainLoop (IterationNode , ComputationNode ):
861
+ class DomainLoop (ComputationNode , IterationNode ):
862
862
axis : Axis
863
863
index_range : Range
864
864
loop_states : List [Union [ComputationState , DomainLoop ]]
@@ -868,7 +868,7 @@ class NestedSDFG(ComputationNode, eve.SymbolTableTrait):
868
868
label : eve .Coerced [eve .SymbolRef ]
869
869
field_decls : List [FieldDecl ]
870
870
symbol_decls : List [SymbolDecl ]
871
- states : List [Union [DomainLoop , ComputationState ]]
871
+ states : List [Union [ComputationState , DomainLoop ]]
872
872
873
873
874
874
# There are circular type references with string placeholders. These statements let datamodels resolve those.
0 commit comments