Skip to content

Commit

Permalink
Fix link match disregard of deeper assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
Andre Senna committed Aug 25, 2023
1 parent 3982c05 commit 63b9451
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions das/pattern_matcher/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def __repr__(self):

def get_handle(self, db: DBInterface) -> str:
if not self.handle:
target_handles = [target.get_handle(db) for target in self.targets]
target_handles = [target if type(target) is str else target.get_handle(db) for target in self.targets]
if any(handle is None for handle in target_handles):
return None
self.handle = db.get_link_handle(self.atom_type, target_handles)
Expand Down Expand Up @@ -499,6 +499,29 @@ def _typed_variable_matched(self, db: DBInterface, answer: PatternMatchingAnswer
first_typed_variable = False
return all(target.matched(db, answer) for target in self.targets)

def _apply_assignment(self, assignment: OrderedAssignment, db: DBInterface) -> str:
targets = []
for t in self.targets:
if type(t) is Node:
targets.append(t.get_handle(db))
elif type(t) is Link:
targets.append(t._apply_assignment(assignment, db))
elif type(t) is Variable or type(t) is TypedVariable:
targets.append(assignment.mapping[t.name])
link = Link(self.atom_type, targets, self.ordered)
return link.get_handle(db)

def apply_assignment(self, assignment: OrderedAssignment, db: DBInterface) -> str:
targets = []
for t in self.targets:
if type(t) is Node:
targets.append(t.get_handle(db))
elif type(t) is Link:
targets.append(t._apply_assignment(assignment, db))
elif type(t) is Variable or type(t) is TypedVariable:
targets.append(assignment.mapping[t.name])
return Link(self.atom_type, targets, self.ordered)

def matched(self, db: DBInterface, answer: PatternMatchingAnswer) -> bool:
if DEBUG_LINK: print('link match', self)
if any(isinstance(atom, LinkTemplate) for atom in self.targets):
Expand Down Expand Up @@ -535,7 +558,17 @@ def matched(self, db: DBInterface, answer: PatternMatchingAnswer) -> bool:
return bool(answer.assignments)
else:
if DEBUG_LINK: print('matched()', f'leaving 2 self = {self}')
return db.link_exists(self.atom_type, target_handles) or bool(answer.assignments)
if db.link_exists(self.atom_type, target_handles):
return True
else:
new_assignments = set()
for assignment in answer.assignments:
assert type(assignment) is OrderedAssignment
link = self.apply_assignment(assignment, db)
if db.link_exists(link.atom_type, link.targets):
new_assignments.add(assignment)
answer.assignments = new_assignments
return bool(answer.assignments)

class Variable(Atom):
"""
Expand Down

0 comments on commit 63b9451

Please sign in to comment.