From 4d4bb1ee5484a749fa09a4ad0754aec53ad03c08 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Fri, 28 Feb 2025 16:15:31 +0000 Subject: [PATCH] core: Make IRDLOperation results typed Now, defining a result with `resname = result_def(T)` will make `resname` have the type `SSAValue[T]`. This removes a lot of `cast` and `isa`/`isattr` in our codebase stack-info: PR: https://github.com/xdslproject/xdsl/pull/3991, branch: math-fehr/stack/6 --- docs/Toy/toy/dialects/toy.py | 4 +-- docs/Toy/toy/rewrites/lower_toy_affine.py | 2 +- docs/Toy/toy/rewrites/optimise_toy.py | 5 +-- xdsl/backend/csl/print_csl.py | 8 ++--- .../riscv/lowering/convert_memref_to_riscv.py | 2 +- xdsl/dialects/csl/csl.py | 19 +++-------- xdsl/dialects/memref.py | 6 ---- xdsl/dialects/riscv.py | 4 --- xdsl/dialects/stablehlo.py | 2 +- xdsl/dialects/stencil.py | 10 +++--- xdsl/dialects/tensor.py | 10 +++--- xdsl/interpreters/memref.py | 3 +- xdsl/interpreters/riscv.py | 3 -- xdsl/interpreters/tensor.py | 2 -- xdsl/irdl/operations.py | 26 ++++++++------ .../canonicalization_patterns/riscv.py | 34 ++++++------------- .../canonicalization_patterns/stencil.py | 6 ++-- .../convert_stencil_to_csl_stencil.py | 2 +- .../hls_convert_stencil_to_ll_mlir.py | 2 +- .../stencil_tensorize_z_dimension.py | 7 ++-- xdsl/transforms/lower_riscv_func.py | 3 +- xdsl/transforms/memref_to_dsd.py | 3 +- .../shape_inference_patterns/dmp.py | 2 -- .../shape_inference_patterns/stencil.py | 18 +++------- xdsl/transforms/stencil_inlining.py | 4 +-- 25 files changed, 64 insertions(+), 123 deletions(-) diff --git a/docs/Toy/toy/dialects/toy.py b/docs/Toy/toy/dialects/toy.py index 779ac6e7c2..588ce98f7b 100644 --- a/docs/Toy/toy/dialects/toy.py +++ b/docs/Toy/toy/dialects/toy.py @@ -133,7 +133,7 @@ def infer_shape(cls, op: Operation) -> None: if isinstance(op_res_type := op.res.type, TensorType): assert op_lhs_type.get_shape() == op_res_type.get_shape() else: - op.res.type = op.lhs.type + op.res.type = op_lhs_type @irdl_op_definition @@ -312,7 +312,7 @@ def infer_shape(cls, op: Operation) -> None: if isinstance(op_res_type := op.res.type, TensorType): assert op_lhs_type.get_shape() == op_res_type.get_shape() else: - op.res.type = op.lhs.type + op.res.type = op_lhs_type @irdl_op_definition diff --git a/docs/Toy/toy/rewrites/lower_toy_affine.py b/docs/Toy/toy/rewrites/lower_toy_affine.py index 0906a73e48..1a59d0b2e1 100644 --- a/docs/Toy/toy/rewrites/lower_toy_affine.py +++ b/docs/Toy/toy/rewrites/lower_toy_affine.py @@ -356,7 +356,7 @@ def match_and_rewrite(self, op: toy.ConstantOp, rewriter: PatternRewriter): # When lowering the constant operation, we allocate and assign the constant # values to a corresponding memref allocation. - tensor_type = cast(toy.TensorTypeF64, op.res.type) + tensor_type = op.res.type memref_type = convert_tensor_to_memref(tensor_type) alloc = insert_alloc_and_dealloc(memref_type, op, rewriter) diff --git a/docs/Toy/toy/rewrites/optimise_toy.py b/docs/Toy/toy/rewrites/optimise_toy.py index 7bfbdb18e8..27fda75322 100644 --- a/docs/Toy/toy/rewrites/optimise_toy.py +++ b/docs/Toy/toy/rewrites/optimise_toy.py @@ -1,5 +1,3 @@ -from typing import cast - from xdsl.dialects.builtin import ( DenseIntOrFPElementsAttr, ) @@ -51,8 +49,7 @@ def match_and_rewrite(self, op: ReshapeOp, rewriter: PatternRewriter): # Input defined by another transpose? If not, no match. return - t = cast(TensorTypeF64, op.res.type) - new_op = ReshapeOp.from_input_and_type(reshape_input_op.arg, t) + new_op = ReshapeOp.from_input_and_type(reshape_input_op.arg, op.res.type) rewriter.replace_matched_op(new_op) diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index 814dac30cb..6b613441d1 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from contextlib import contextmanager from dataclasses import dataclass, field -from typing import IO, Literal, cast +from typing import IO, Literal from xdsl.dialects import arith, csl, memref, scf from xdsl.dialects.builtin import ( @@ -682,13 +682,11 @@ def print_block(self, body: Block): self.variables[res] = f"({arr_name}[{idx_args}])" case csl.AddressOfOp(value=val, res=res): val_name = self._get_variable_name_for(val) - ty = cast(csl.PtrType, res.type) - use = self._var_use(res, ty.constness.data.value) + use = self._var_use(res, res.type.constness.data.value) self.print(f"{use} = &{val_name};") case csl.AddressOfFnOp(fn_name=name, res=res): - ty = cast(csl.PtrType, res.type) - use = self._var_use(res, ty.constness.data.value) + use = self._var_use(res, res.type.constness.data.value) self.print(f"{use} = &{name.string_value()};") case csl.DirectionOp(dir=d, res=res): self._print_or_promote_to_inline_expr(res, str.upper(d.data)) diff --git a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py index 379c5b7a10..cfad661704 100644 --- a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py +++ b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py @@ -341,7 +341,7 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter): source_type = source.type assert isinstance(source_type, MemRefType) source_type = cast(MemRefType[Attribute], source_type) - result_type = cast(MemRefType[Attribute], result.type) + result_type = result.type result_layout_attr = result_type.layout if isinstance(result_layout_attr, NoneAttr): diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index e8a81296a2..defb2d9cd8 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -1147,8 +1147,6 @@ class GetMemDsdOp(_GetDsdOp): ) def verify_(self) -> None: - if not isinstance(self.result.type, DsdType): - raise VerifyException("DSD type is not DsdType") if self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]: raise VerifyException("DSD type must be memory DSD") if self.result.type.data == DsdKind.mem1d_dsd and len(self.sizes) != 1: @@ -1190,8 +1188,6 @@ class GetFabDsdOp(_GetDsdOp): wavelet_index_offset = opt_prop_def(BoolAttr) def verify_(self) -> None: - if not isinstance(self.result.type, DsdType): - raise VerifyException("DSD type is not DsdType") if self.result.type.data not in [DsdKind.fabin_dsd, DsdKind.fabout_dsd]: raise VerifyException("DSD type must be fabric DSD") if len(self.sizes) != 1: @@ -1226,8 +1222,7 @@ class SetDsdBaseAddrOp(IRDLOperation): def verify_(self) -> None: if ( - not isinstance(self.result.type, DsdType) - or not isinstance(self.op.type, DsdType) + not isinstance(self.op.type, DsdType) or self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd] or self.op.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd] ): @@ -1266,8 +1261,7 @@ class IncrementDsdOffsetOp(IRDLOperation): def verify_(self) -> None: if ( - not isinstance(self.result.type, DsdType) - or not isinstance(self.op.type, DsdType) + not isinstance(self.op.type, DsdType) or self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd] or self.op.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd] ): @@ -1293,8 +1287,7 @@ class SetDsdLengthOp(IRDLOperation): def verify_(self) -> None: if ( - not isinstance(self.result.type, DsdType) - or not isinstance(self.op.type, DsdType) + not isinstance(self.op.type, DsdType) or self.result.type.data == DsdKind.mem4d_dsd ): raise VerifyException( @@ -1321,8 +1314,7 @@ class SetDsdStrideOp(IRDLOperation): def verify_(self) -> None: if ( - not isinstance(self.result.type, DsdType) - or not isinstance(self.op.type, DsdType) + not isinstance(self.op.type, DsdType) or self.result.type.data != DsdKind.mem1d_dsd ): raise VerifyException(f"{self.name} can only operate on mem1d_dsd type") @@ -1926,9 +1918,6 @@ def _verify_memref_addr(self, val_ty: MemRefType[Attribute], res_ty: PtrType): ) def verify_(self) -> None: - if not isinstance(self.res.type, PtrType): - raise VerifyException("Result type must be a pointer") - val_ty = self.value.type res_ty = self.res.type if isa(val_ty, MemRefType[Attribute]): diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 20b02b3cb5..34cf068693 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -206,9 +206,6 @@ def get( def verify_(self) -> None: memref_type = self.memref.type - if not isinstance(memref_type, MemRefType): - raise VerifyException("expected result to be a memref") - memref_type = cast(MemRefType[Attribute], memref_type) dyn_dims = [x for x in memref_type.shape.data if x.data == -1] if len(dyn_dims) != len(self.dynamic_sizes): @@ -345,9 +342,6 @@ def get( def verify_(self) -> None: memref_type = self.memref.type - if not isinstance(memref_type, MemRefType): - raise VerifyException("expected result to be a memref") - memref_type = cast(MemRefType[Attribute], memref_type) dyn_dims = [x for x in memref_type.shape.data if x.data == -1] if len(dyn_dims) != len(self.dynamic_sizes): diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index e65d44f869..e492b3d9ea 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -1184,8 +1184,6 @@ def __init__( def verify_(self) -> None: if not self.writeonly: return - if not isinstance(self.rd.type, IntRegisterType): - return if self.rd.type.is_allocated and self.rd.type != Registers.ZERO: raise VerifyException( "When in 'writeonly' mode, destination must be register x0 (a.k.a. 'zero'), " @@ -1330,8 +1328,6 @@ def __init__( def verify_(self) -> None: if self.writeonly is None: return - if not isinstance(self.rd.type, IntRegisterType): - return if self.rd.type.is_allocated and self.rd.type != Registers.ZERO: raise VerifyException( "When in 'writeonly' mode, destination must be register x0 (a.k.a. 'zero'), " diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index b123df544c..c6482b826f 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -454,7 +454,7 @@ def get_permutation(self) -> tuple[int, ...]: def verify_(self) -> None: # Operand and result types are checked before the custom `verify_` o_type = cast(TensorType[Attribute], self.operand.type) - r_type = cast(TensorType[Attribute], self.result.type) + r_type = self.result.type o_shape = o_type.get_shape() r_shape = r_type.get_shape() diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 22af55589f..e1aac6b8a0 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -615,9 +615,9 @@ def verify_(self) -> None: "buffer-semantic destination operands." ) if len(self.res) > 0: - res_type = cast(TempType[Attribute], self.res[0].type) + res_type = self.res[0].type for other in self.res[1:]: - other = cast(TempType[Attribute], other.type) + other = other.type if res_type.bounds != other.bounds: raise VerifyException( "Expected all output types bounds to be equals." @@ -674,7 +674,7 @@ def get_bounds(self): return self.bounds else: assert self.res - res_type = cast(TempType[Attribute], self.res[0].type) + res_type = self.res[0].type return res_type.bounds @@ -1506,9 +1506,7 @@ def verify_(self) -> None: types = [ot.elem if isinstance(ot, ResultType) else ot for ot in self.arg.types] apply = cast(ApplyOp, self.parent_op()) if len(apply.res) > 0: - res_types = [ - cast(TempType[Attribute], r.type).element_type for r in apply.res - ] + res_types = [r.type.element_type for r in apply.res] else: res_types = [ cast(FieldType[Attribute], o.type).element_type for o in apply.dest diff --git a/xdsl/dialects/tensor.py b/xdsl/dialects/tensor.py index 61aa793272..4c010c7d55 100644 --- a/xdsl/dialects/tensor.py +++ b/xdsl/dialects/tensor.py @@ -257,18 +257,16 @@ def parse(cls, parser: Parser) -> Self: return reshape def verify_(self) -> None: - if ( - not isinstance(source_type := self.source.type, TensorType) - or not isinstance(shape_type := self.shape.type, TensorType) - or not isinstance(res_type := self.result.type, TensorType) - ): + if not isinstance( + source_type := self.source.type, TensorType + ) or not isinstance(shape_type := self.shape.type, TensorType): raise ValueError( "tensor elementwise operation operands and result must be of type TensorType" ) source_type = cast(TensorType[Attribute], source_type) shape_type = cast(TensorType[Attribute], shape_type) - res_type = cast(TensorType[Attribute], res_type) + res_type = self.result.type if source_type.element_type != res_type.element_type: raise VerifyException( diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py index 745a356bc3..d7ccc2107e 100644 --- a/xdsl/interpreters/memref.py +++ b/xdsl/interpreters/memref.py @@ -12,7 +12,6 @@ from xdsl.interpreters.builtin import xtype_for_el_type from xdsl.interpreters.shaped_array import ShapedArray from xdsl.interpreters.utils.ptr import TypedPtr -from xdsl.ir import Attribute from xdsl.traits import SymbolTable @@ -22,7 +21,7 @@ class MemRefFunctions(InterpreterFunctions): def run_alloc( self, interpreter: Interpreter, op: memref.AllocOp, args: PythonValues ) -> PythonValues: - memref_type = cast(memref.MemRefType[Attribute], op.memref.type) + memref_type = op.memref.type shape = memref_type.get_shape() size = prod(shape) diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py index c940f341da..ca89a1e5d2 100644 --- a/xdsl/interpreters/riscv.py +++ b/xdsl/interpreters/riscv.py @@ -584,9 +584,6 @@ def run_get_register( ) -> PythonValues: attr = op.res.type - if not isinstance(attr, riscv.RISCVRegisterType): - raise InterpretationError(f"Unexpected type {attr}, expected register type") - if not attr.is_allocated: raise InterpretationError( f"Cannot get value for unallocated register {attr}" diff --git a/xdsl/interpreters/tensor.py b/xdsl/interpreters/tensor.py index 58bbcfd7ca..a5d14a5f37 100644 --- a/xdsl/interpreters/tensor.py +++ b/xdsl/interpreters/tensor.py @@ -12,7 +12,6 @@ from xdsl.interpreters.builtin import xtype_for_el_type from xdsl.interpreters.shaped_array import ShapedArray from xdsl.interpreters.utils.ptr import TypedPtr -from xdsl.ir import Attribute from xdsl.utils.exceptions import InterpretationError @@ -24,7 +23,6 @@ def run_empty( ) -> tuple[Any, ...]: result_type = op.tensor.type assert isinstance(result_type, TensorType) - result_type = cast(TensorType[Attribute], result_type) result_shape = list(result_type.get_shape()) xtype = xtype_for_el_type(result_type.element_type, interpreter.index_bitwidth) return ( diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index be8c193919..d9fba3bbd9 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -400,7 +400,7 @@ def __init__( self.constr = range_constr_coercion(attr) -class VarOpResult(tuple[OpResult, ...]): +class VarOpResult(Generic[AttributeInvT], tuple[OpResult[AttributeInvT], ...]): @property def types(self): return tuple(r.type for r in self) @@ -411,7 +411,7 @@ class OptResultDef(VarResultDef, OptionalDef): """An IRDL optional result definition.""" -OptOpResult: TypeAlias = OpResult | None +OptOpResult: TypeAlias = OpResult[AttributeInvT] | None @dataclass(init=True) @@ -596,42 +596,46 @@ class _SuccessorFieldDef(_OpDefField[SuccessorDef]): def result_def( - constraint: IRDLAttrConstraint = Attribute, + constraint: IRDLGenericAttrConstraint[AttributeInvT] = Attribute, *, default: None = None, resolver: None = None, init: Literal[False] = False, -) -> OpResult: +) -> OpResult[AttributeInvT]: """ Defines a result of an operation. """ - return cast(OpResult, _ResultFieldDef(ResultDef, constraint)) + return cast(OpResult[AttributeInvT], _ResultFieldDef(ResultDef, constraint)) def var_result_def( - constraint: RangeConstraint | IRDLAttrConstraint = Attribute, + constraint: ( + GenericRangeConstraint[AttributeInvT] | IRDLGenericAttrConstraint[AttributeInvT] + ) = Attribute, *, default: None = None, resolver: None = None, init: Literal[False] = False, -) -> VarOpResult: +) -> VarOpResult[AttributeInvT]: """ Defines a variadic result of an operation. """ - return cast(VarOpResult, _ResultFieldDef(VarResultDef, constraint)) + return cast(VarOpResult[AttributeInvT], _ResultFieldDef(VarResultDef, constraint)) def opt_result_def( - constraint: RangeConstraint | IRDLAttrConstraint = Attribute, + constraint: ( + GenericRangeConstraint[AttributeInvT] | IRDLGenericAttrConstraint[AttributeInvT] + ) = Attribute, *, default: None = None, resolver: None = None, init: Literal[False] = False, -) -> OptOpResult: +) -> OptOpResult[AttributeInvT]: """ Defines an optional result of an operation. """ - return cast(OptOpResult, _ResultFieldDef(OptResultDef, constraint)) + return cast(OptOpResult[AttributeInvT], _ResultFieldDef(OptResultDef, constraint)) def prop_def( diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 519c5b0813..52cd216456 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -14,33 +14,21 @@ class RemoveRedundantMv(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv.MVOp, rewriter: PatternRewriter) -> None: - if ( - op.rd.type == op.rs.type - and isinstance(op.rd.type, riscv.RISCVRegisterType) - and op.rd.type.is_allocated - ): + if op.rd.type == op.rs.type and op.rd.type.is_allocated: rewriter.replace_matched_op([], [op.rs]) class RemoveRedundantFMv(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv.FMVOp, rewriter: PatternRewriter) -> None: - if ( - op.rd.type == op.rs.type - and isinstance(op.rd.type, riscv.RISCVRegisterType) - and op.rd.type.is_allocated - ): + if op.rd.type == op.rs.type and op.rd.type.is_allocated: rewriter.replace_matched_op([], [op.rs]) class RemoveRedundantFMvD(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv.FMvDOp, rewriter: PatternRewriter) -> None: - if ( - op.rd.type == op.rs.type - and isinstance(op.rd.type, riscv.RISCVRegisterType) - and op.rd.type.is_allocated - ): + if op.rd.type == op.rs.type and op.rd.type.is_allocated: rewriter.replace_matched_op([], [op.rs]) @@ -114,7 +102,7 @@ class AddImmediateZero(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv.AddiOp, rewriter: PatternRewriter) -> None: if isinstance(op.immediate, IntegerAttr) and op.immediate.value.data == 0: - rd = cast(riscv.IntRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op(riscv.MVOp(op.rs1, rd=rd)) @@ -124,7 +112,7 @@ def match_and_rewrite(self, op: riscv.AddiOp, rewriter: PatternRewriter) -> None if (rs1 := get_constant_value(op.rs1)) is not None and isinstance( op.immediate, IntegerAttr ): - rd = cast(riscv.IntRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.LiOp( rs1.value.data + op.immediate.value.data, @@ -196,7 +184,7 @@ def match_and_rewrite(self, op: riscv.SlliOp, rewriter: PatternRewriter) -> None and isinstance(op.rs1.op.immediate, IntegerAttr) and isinstance(op.immediate, IntegerAttr) ): - rd = cast(riscv.IntRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.LiOp( op.rs1.op.immediate.value.data << op.immediate.value.data, rd=rd @@ -213,7 +201,7 @@ def match_and_rewrite(self, op: riscv.LwOp, rewriter: PatternRewriter) -> None: and isinstance(op.rs1.op.immediate, IntegerAttr) and isinstance(op.immediate, IntegerAttr) ): - rd = cast(riscv.IntRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.LwOp( op.rs1.op.rs1, @@ -251,7 +239,7 @@ def match_and_rewrite(self, op: riscv.FLwOp, rewriter: PatternRewriter) -> None: and isinstance(op.rs1.op.immediate, IntegerAttr) and isinstance(op.immediate, IntegerAttr) ): - rd = cast(riscv.FloatRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.FLwOp( op.rs1.op.rs1, @@ -289,7 +277,7 @@ def match_and_rewrite(self, op: riscv.FLdOp, rewriter: PatternRewriter) -> None: and isinstance(op.rs1.op.immediate, IntegerAttr) and isinstance(op.immediate, IntegerAttr) ): - rd = cast(riscv.FloatRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.FLdOp( op.rs1.op.rs1, @@ -368,7 +356,7 @@ def match_and_rewrite(self, op: riscv.FAddDOp, rewriter: PatternRewriter) -> Non else: return - rd = cast(riscv.FloatRegisterType, op.rd.type) + rd = op.rd.type rewriter.replace_matched_op( riscv.FMAddDOp( mul.rs1, @@ -477,7 +465,7 @@ def match_and_rewrite(self, op: riscv.LiOp, rewriter: PatternRewriter) -> None: if not (isinstance(op.immediate, IntegerAttr) and op.immediate.value.data == 0): return - rd = cast(riscv.IntRegisterType, op.rd.type) + rd = op.rd.type if rd == riscv.Registers.ZERO: rewriter.replace_matched_op(riscv.GetRegisterOp(riscv.Registers.ZERO)) else: diff --git a/xdsl/transforms/canonicalization_patterns/stencil.py b/xdsl/transforms/canonicalization_patterns/stencil.py index 58fe67f0a5..4fc9677ebe 100644 --- a/xdsl/transforms/canonicalization_patterns/stencil.py +++ b/xdsl/transforms/canonicalization_patterns/stencil.py @@ -1,7 +1,7 @@ from typing import cast from xdsl.dialects import stencil -from xdsl.ir import Attribute, Block, Region, SSAValue +from xdsl.ir import Block, Region, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, @@ -65,7 +65,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N new = stencil.ApplyOp.get( operands, block := Block(arg_types=bbargs_type), - [cast(stencil.TempType[Attribute], r.type) for r in op.res], + [r.type for r in op.res], ) rewriter.inline_block(op.region.block, InsertPoint.at_start(block), block.args) @@ -98,7 +98,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N new = stencil.ApplyOp.build( operands=[op.args, op.dest], regions=[Region(block)], - result_types=[[cast(stencil.TempType[Attribute], r.type) for r in results]], + result_types=[[r.type for r in results]], properties=op.properties.copy(), attributes=op.attributes.copy(), ) diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index b852036db5..6eb02bfc59 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -312,7 +312,7 @@ def split_ops( for op in ops: if op in a: recv_chunk_ops.append(op) - if op in cnst_exports and isinstance(op, arith.ConstantOp): + if op in cnst_exports: # create a copy of the constant in the second region done_exch_ops.append(cln := op.clone()) # rewire ops of the second region to use the copied constant diff --git a/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py b/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py index 9ca88eaecb..05f41919ea 100644 --- a/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py +++ b/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py @@ -665,7 +665,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): self.module.body.block.add_op(get_number_chunks) self.module.body.block.add_op(get_chunk_size) - ndims: int = typing.cast(TempType[Attribute], op.res[0].type).get_num_dims() + ndims: int = op.res[0].type.get_num_dims() rewriter.erase_op(return_op) for new_return_component in new_return_component_lst: diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index e4bd4a7633..b0c660e7dc 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import TypeGuard, cast +from typing import TypeGuard from xdsl.context import Context from xdsl.dialects import builtin, varith @@ -245,10 +245,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): ApplyOp.get( op.args, body, - [ - stencil_temp_to_tensor(cast(TempType[Attribute], r.type)) - for r in op.res - ], + [stencil_temp_to_tensor(r.type) for r in op.res], ) ) diff --git a/xdsl/transforms/lower_riscv_func.py b/xdsl/transforms/lower_riscv_func.py index a02c16d2ec..85082bd78c 100644 --- a/xdsl/transforms/lower_riscv_func.py +++ b/xdsl/transforms/lower_riscv_func.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import cast from xdsl.context import Context from xdsl.dialects import riscv, riscv_func @@ -55,7 +54,7 @@ def match_and_rewrite(self, op: riscv_func.SyscallOp, rewriter: PatternRewriter) ops.append(gr) res = gr.res - mv = riscv.MVOp(res, rd=cast(riscv.IntRegisterType, op.result.type)) + mv = riscv.MVOp(res, rd=op.result.type) ops.append(mv) new_results = mv.results diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index a95e099c4c..6de8b7d496 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -371,8 +371,7 @@ class CslVarUpdate(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl.VariableOp, rewriter: PatternRewriter, /): if ( - not isinstance(op.res.type, csl.VarType) - or not isa(elem_t := op.res.type.get_element_type(), MemRefType[Attribute]) + not isa(elem_t := op.res.type.get_element_type(), MemRefType[Attribute]) or op.default ): return diff --git a/xdsl/transforms/shape_inference_patterns/dmp.py b/xdsl/transforms/shape_inference_patterns/dmp.py index 86d204db48..2863c9ff7e 100644 --- a/xdsl/transforms/shape_inference_patterns/dmp.py +++ b/xdsl/transforms/shape_inference_patterns/dmp.py @@ -19,8 +19,6 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter): if not op.swapped_values: return swap_t = op.swapped_values.type - if not isinstance(swap_t, stencil.TempType): - return if op.input_stencil.type != swap_t: op.input_stencil.type = swap_t rewrite.handle_operation_modification(op) diff --git a/xdsl/transforms/shape_inference_patterns/stencil.py b/xdsl/transforms/shape_inference_patterns/stencil.py index 410ac643c0..814fd104cc 100644 --- a/xdsl/transforms/shape_inference_patterns/stencil.py +++ b/xdsl/transforms/shape_inference_patterns/stencil.py @@ -35,7 +35,7 @@ def update_result_size( """ if isinstance(value.owner, ApplyOp): apply = value.owner - res_types = (cast(TempType[Attribute], r.type) for r in apply.res) + res_types = (r.type for r in apply.res) newsize = reduce( StencilBoundsAttr.union, ( @@ -48,9 +48,7 @@ def update_result_size( ), ) for res in apply.res: - newtype = TempType( - newsize, cast(TempType[Attribute], res.type).element_type - ) + newtype = TempType(newsize, res.type.element_type) if newtype != res.type: rewriter.modify_value_type(res, newtype) for use in res.uses: @@ -70,15 +68,9 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter, /): lowerext_res = op.results_[len(op.lower) : len(op.lower) + len(op.lowerext)] upperext_res = op.results_[len(op.lower) + len(op.lowerext) :] - combined_bounds = [ - cast(TempType[Attribute], r.type).bounds for r in combined_res - ] - lowerext_bounds = [ - cast(TempType[Attribute], r.type).bounds for r in lowerext_res - ] - upperext_bounds = [ - cast(TempType[Attribute], r.type).bounds for r in upperext_res - ] + combined_bounds = [r.type.bounds for r in combined_res] + lowerext_bounds = [r.type.bounds for r in lowerext_res] + upperext_bounds = [r.type.bounds for r in upperext_res] lower_bounds = list[StencilBoundsAttr | None]() upper_bounds = list[StencilBoundsAttr | None]() diff --git a/xdsl/transforms/stencil_inlining.py b/xdsl/transforms/stencil_inlining.py index 3f45c1ee66..1192944545 100644 --- a/xdsl/transforms/stencil_inlining.py +++ b/xdsl/transforms/stencil_inlining.py @@ -173,8 +173,8 @@ def redirect_store( ) # Update the bounds if needed - producer_bounds = cast(TempType[Attribute], producer.res[0].type).bounds - consumer_bounds = cast(TempType[Attribute], consumer.res[0].type).bounds + producer_bounds = producer.res[0].type.bounds + consumer_bounds = consumer.res[0].type.bounds if isinstance(producer_bounds, StencilBoundsAttr): new_bounds = producer_bounds | consumer_bounds elif isinstance(consumer_bounds, StencilBoundsAttr):