Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

One more refactoring in operator requirement fixer #1763

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions src/beanmachine/ppl/compiler/fix_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,11 @@ def _node_meets_requirement(self, node: bn.BMGNode, r: bt.Requirement) -> bool:
)
return self._type_meets_requirement(lattice_type, r)

def _meet_constant_requirement(
def _try_to_meet_constant_requirement(
self,
node: bn.ConstantNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
) -> Optional[bn.BMGNode]:
# We have a constant node that either (1) is untyped, and therefore
# needs to be replaced by an equivalent typed node, or (2) is typed
# but is of the wrong type, and needs to be replaced by an equivalent
Expand Down Expand Up @@ -162,13 +160,26 @@ def _meet_constant_requirement(
result = self.bmg.add_constant_of_type(node.value, required_type)
assert self._node_meets_requirement(result, requirement)
return result
return None

def _meet_constant_requirement(
self,
node: bn.ConstantNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:

result = self._try_to_meet_constant_requirement(node, requirement)
if result is not None:
return result

# We cannot convert this node to any type that meets the requirement.
# Add an error.
self.errors.add_error(
Violation(
node,
it,
self._typer[node],
requirement,
consumer,
edge,
Expand Down Expand Up @@ -415,13 +426,13 @@ def _try_to_force_to_neg_real(self, node, requirement) -> Optional[bn.BMGNode]:

return self.bmg.add_to_negative_real(node)

def _meet_operator_requirement(
def _try_to_meet_operator_requirement(
self,
node: bn.OperatorNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
) -> Optional[bn.BMGNode]:
# We should not have called this function if the input node already meets
# the requirement on the edge.

Expand Down Expand Up @@ -470,14 +481,30 @@ def _meet_operator_requirement(
if result is not None:
return result

node_type = self._typer[node]

result = self._try_to_force_to_neg_real(node, requirement)
if result is not None:
return result

# Those are the only techniques we have to make an operator meet a requirement.
# We have no way to make the conversion we need, so add an error.
# We couldn't meet the requirement.

return None

def _meet_operator_requirement(
self,
node: bn.OperatorNode,
requirement: bt.Requirement,
consumer: bn.BMGNode,
edge: str,
) -> bn.BMGNode:
assert not self._node_meets_requirement(node, requirement)
result = self._try_to_meet_operator_requirement(
node, requirement, consumer, edge
)
if result is not None:
return result

# We were unable to meet a requirement; add an error.
node_type = self._typer[node]
self.errors.add_error(
Violation(
node,
Expand Down