Skip to content

Commit

Permalink
core: Move SSAValue definitions around (NFC) (#3986)
Browse files Browse the repository at this point in the history
Move the definition of SSAValue after the definition of Attributes.
  • Loading branch information
math-fehr authored Feb 28, 2025
1 parent cca5ed5 commit 63c9c62
Showing 1 changed file with 183 additions and 183 deletions.
366 changes: 183 additions & 183 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,189 +78,6 @@ def split_name(name: str) -> tuple[str, str]:
raise ValueError(f"Invalid operation or attribute name {name}.") from e


@dataclass(frozen=True)
class Use:
"""The use of a SSA value."""

operation: Operation
"""The operation using the value."""

index: int
"""The index of the operand using the value in the operation."""


@dataclass(eq=False)
class IRWithUses(ABC):
"""IRNode which stores a list of its uses."""

uses: set[Use] = field(init=False, default_factory=set, repr=False)
"""All uses of the value."""

def add_use(self, use: Use):
"""Add a new use of the value."""
self.uses.add(use)

def remove_use(self, use: Use):
"""Remove a use of the value."""
assert use in self.uses, "use to be removed was not in use list"
self.uses.remove(use)


@dataclass(eq=False)
class SSAValue(IRWithUses, ABC):
"""
A reference to an SSA variable.
An SSA variable is either an operation result, or a basic block argument.
"""

type: Attribute
"""Each SSA variable is associated to a type."""

_name: str | None = field(init=False, default=None)

_name_regex: ClassVar[re.Pattern[str]] = re.compile(r"([A-Za-z_$.-][\w$.-]*)")

@property
@abstractmethod
def owner(self) -> Operation | Block:
"""
An SSA variable is either an operation result, or a basic block argument.
This property returns the Operation or Block that currently defines a specific value.
"""
pass

@property
def name_hint(self) -> str | None:
return self._name

@name_hint.setter
def name_hint(self, name: str | None):
# only allow valid names
if SSAValue.is_valid_name(name):
# Remove `_` followed by numbers at the end of the name
if name is not None:
r1 = re.compile(r"(_\d+)+$")
if match := r1.search(name):
name = name[: match.start()]
self._name = name
else:
raise ValueError(
"Invalid SSA Value name format!",
r"Make sure names contain only characters of [A-Za-z0-9_$.-] and don't start with a number!",
)

@classmethod
def is_valid_name(cls, name: str | None):
return name is None or cls._name_regex.fullmatch(name)

@staticmethod
def get(arg: SSAValue | Operation) -> SSAValue:
"Get a new SSAValue from either a SSAValue, or an operation with a single result."
match arg:
case SSAValue():
return arg
case Operation():
if len(arg.results) == 1:
return arg.results[0]
raise ValueError(
"SSAValue.build: expected operation with a single result."
)

def replace_by(self, value: SSAValue) -> None:
"""Replace the value by another value in all its uses."""
for use in self.uses.copy():
use.operation.operands[use.index] = value
# carry over name if possible
if value.name_hint is None:
value.name_hint = self.name_hint
assert not self.uses, "unexpected error in xdsl"

def replace_by_if(self, value: SSAValue, test: Callable[[Use], bool]):
"""
Replace the value by another value in all its uses that pass the given test
function.
"""
for use in self.uses.copy():
if test(use):
use.operation.operands[use.index] = value
# carry over name if possible
if value.name_hint is None:
value.name_hint = self.name_hint

def erase(self, safe_erase: bool = True) -> None:
"""
Erase the value.
If safe_erase is True, then check that no operations use the value anymore.
If safe_erase is False, then replace its uses by an ErasedSSAValue.
"""
if safe_erase and len(self.uses) != 0:
raise Exception(
"Attempting to delete SSA value that still has uses of result "
f"of operation:\n{self.owner}"
)
self.replace_by(ErasedSSAValue(self.type, self))

def __hash__(self):
"""
Make SSAValue hashable. Two SSA Values are never the same, therefore
the use of `id` is allowed here.
"""
return id(self)

def __eq__(self, other: object) -> bool:
return self is other


@dataclass(eq=False)
class OpResult(SSAValue):
"""A reference to an SSA variable defined by an operation result."""

op: Operation
"""The operation defining the variable."""

index: int
"""The index of the result in the defining operation."""

@property
def owner(self) -> Operation:
return self.op

def __repr__(self) -> str:
return f"<{self.__class__.__name__}[{self.type}] index: {self.index}, operation: {self.op.name}, uses: {len(self.uses)}>"


@dataclass(eq=False)
class BlockArgument(SSAValue):
"""A reference to an SSA variable defined by a basic block argument."""

block: Block
"""The block defining the variable."""

index: int
"""The index of the variable in the block arguments."""

@property
def owner(self) -> Block:
return self.block

def __repr__(self) -> str:
return f"<{self.__class__.__name__}[{self.type}] index: {self.index}, uses: {len(self.uses)}>"


@dataclass(eq=False)
class ErasedSSAValue(SSAValue):
"""
An erased SSA variable.
This is used during transformations when a SSA variable is destroyed but still used.
"""

old_value: SSAValue

@property
def owner(self) -> Operation | Block:
return self.old_value.owner


A = TypeVar("A", bound="Attribute")


Expand Down Expand Up @@ -632,6 +449,189 @@ def parse_with_type(
def print_without_type(self, printer: Printer): ...


@dataclass(frozen=True)
class Use:
"""The use of a SSA value."""

operation: Operation
"""The operation using the value."""

index: int
"""The index of the operand using the value in the operation."""


@dataclass(eq=False)
class IRWithUses(ABC):
"""IRNode which stores a list of its uses."""

uses: set[Use] = field(init=False, default_factory=set, repr=False)
"""All uses of the value."""

def add_use(self, use: Use):
"""Add a new use of the value."""
self.uses.add(use)

def remove_use(self, use: Use):
"""Remove a use of the value."""
assert use in self.uses, "use to be removed was not in use list"
self.uses.remove(use)


@dataclass(eq=False)
class SSAValue(IRWithUses, ABC):
"""
A reference to an SSA variable.
An SSA variable is either an operation result, or a basic block argument.
"""

type: Attribute
"""Each SSA variable is associated to a type."""

_name: str | None = field(init=False, default=None)

_name_regex: ClassVar[re.Pattern[str]] = re.compile(r"([A-Za-z_$.-][\w$.-]*)")

@property
@abstractmethod
def owner(self) -> Operation | Block:
"""
An SSA variable is either an operation result, or a basic block argument.
This property returns the Operation or Block that currently defines a specific value.
"""
pass

@property
def name_hint(self) -> str | None:
return self._name

@name_hint.setter
def name_hint(self, name: str | None):
# only allow valid names
if SSAValue.is_valid_name(name):
# Remove `_` followed by numbers at the end of the name
if name is not None:
r1 = re.compile(r"(_\d+)+$")
if match := r1.search(name):
name = name[: match.start()]
self._name = name
else:
raise ValueError(
"Invalid SSA Value name format!",
r"Make sure names contain only characters of [A-Za-z0-9_$.-] and don't start with a number!",
)

@classmethod
def is_valid_name(cls, name: str | None):
return name is None or cls._name_regex.fullmatch(name)

@staticmethod
def get(arg: SSAValue | Operation) -> SSAValue:
"Get a new SSAValue from either a SSAValue, or an operation with a single result."
match arg:
case SSAValue():
return arg
case Operation():
if len(arg.results) == 1:
return arg.results[0]
raise ValueError(
"SSAValue.build: expected operation with a single result."
)

def replace_by(self, value: SSAValue) -> None:
"""Replace the value by another value in all its uses."""
for use in self.uses.copy():
use.operation.operands[use.index] = value
# carry over name if possible
if value.name_hint is None:
value.name_hint = self.name_hint
assert not self.uses, "unexpected error in xdsl"

def replace_by_if(self, value: SSAValue, test: Callable[[Use], bool]):
"""
Replace the value by another value in all its uses that pass the given test
function.
"""
for use in self.uses.copy():
if test(use):
use.operation.operands[use.index] = value
# carry over name if possible
if value.name_hint is None:
value.name_hint = self.name_hint

def erase(self, safe_erase: bool = True) -> None:
"""
Erase the value.
If safe_erase is True, then check that no operations use the value anymore.
If safe_erase is False, then replace its uses by an ErasedSSAValue.
"""
if safe_erase and len(self.uses) != 0:
raise Exception(
"Attempting to delete SSA value that still has uses of result "
f"of operation:\n{self.owner}"
)
self.replace_by(ErasedSSAValue(self.type, self))

def __hash__(self):
"""
Make SSAValue hashable. Two SSA Values are never the same, therefore
the use of `id` is allowed here.
"""
return id(self)

def __eq__(self, other: object) -> bool:
return self is other


@dataclass(eq=False)
class OpResult(SSAValue):
"""A reference to an SSA variable defined by an operation result."""

op: Operation
"""The operation defining the variable."""

index: int
"""The index of the result in the defining operation."""

@property
def owner(self) -> Operation:
return self.op

def __repr__(self) -> str:
return f"<{self.__class__.__name__}[{self.type}] index: {self.index}, operation: {self.op.name}, uses: {len(self.uses)}>"


@dataclass(eq=False)
class BlockArgument(SSAValue):
"""A reference to an SSA variable defined by a basic block argument."""

block: Block
"""The block defining the variable."""

index: int
"""The index of the variable in the block arguments."""

@property
def owner(self) -> Block:
return self.block

def __repr__(self) -> str:
return f"<{self.__class__.__name__}[{self.type}] index: {self.index}, uses: {len(self.uses)}>"


@dataclass(eq=False)
class ErasedSSAValue(SSAValue):
"""
An erased SSA variable.
This is used during transformations when a SSA variable is destroyed but still used.
"""

old_value: SSAValue

@property
def owner(self) -> Operation | Block:
return self.old_value.owner


@dataclass(init=False)
class IRNode(ABC):
def is_ancestor(self, op: IRNode) -> bool:
Expand Down

0 comments on commit 63c9c62

Please sign in to comment.