Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List pattern matching #14

Merged
merged 1 commit into from
Nov 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def parse(tokens: typing.List[str], p: float = 0) -> "Object":
else:
l.items.append(parse(tokens, 2))
while tokens.pop(0) != "]":
# TODO: Implement .. and ... operators
l.items.append(parse(tokens, 2))
elif token == "{":
l = Record({})
Expand Down Expand Up @@ -630,6 +631,21 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
assert isinstance(result, dict) # for .update()
result.update(part)
return result
if isinstance(pattern, List):
if not isinstance(obj, List):
return None
if len(pattern.items) != len(obj.items):
gregorybchris marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Remove this check when implementing ... operator
return None
result: Env = {} # type: ignore
for i, pattern_item in enumerate(pattern.items):
obj_item = obj.items[i]
part = match(obj_item, pattern_item)
if part is None:
return None
assert isinstance(result, dict) # for .update()
result.update(part)
return result
raise NotImplementedError(f"match not implemented for {type(pattern).__name__}")


Expand Down Expand Up @@ -1265,6 +1281,60 @@ def test_match_record_with_non_matching_const_returns_none(self) -> None:
None,
)

def test_match_list_with_non_list_returns_none(self) -> None:
self.assertEqual(
match(
Int(2),
pattern=List([Var("x"), Var("y")]),
),
None,
)

def test_match_list_with_more_fields_in_pattern_returns_none(self) -> None:
self.assertEqual(
match(
List([Int(1), Int(2)]),
pattern=List([Var("x"), Var("y"), Var("z")]),
),
None,
)

def test_match_list_with_fewer_fields_in_pattern_returns_none(self) -> None:
self.assertEqual(
match(
List([Int(1), Int(2)]),
pattern=List([Var("x")]),
),
None,
)

def test_match_list_with_vars_returns_dict_with_keys(self) -> None:
self.assertEqual(
match(
List([Int(1), Int(2)]),
pattern=List([Var("x"), Var("y")]),
),
{"x": Int(1), "y": Int(2)},
)

def test_match_list_with_matching_const_returns_dict_with_other_keys(self) -> None:
self.assertEqual(
match(
List([Int(1), Int(2)]),
pattern=List([Int(1), Var("y")]),
),
{"y": Int(2)},
)

def test_match_list_with_non_matching_const_returns_none(self) -> None:
self.assertEqual(
match(
List([Int(1), Int(2)]),
pattern=List([Int(3), Var("y")]),
),
None,
)


class EvalTests(unittest.TestCase):
def test_eval_int_returns_int(self) -> None:
Expand Down Expand Up @@ -1799,6 +1869,70 @@ def test_match_record_doubly_binds_vars(self) -> None:
Int(3),
)

def test_match_list_binds_vars(self) -> None:
self.assertEqual(
self._run(
"""
mult xs
. xs = [1, 2, 3, 4, 5]
. mult =
| [1, x, 3, y, 5] -> x * y
"""
),
Int(8),
)

def test_match_list_incorrect_length_does_not_match(self) -> None:
with self.assertRaises(MatchError):
self._run(
"""
mult xs
. xs = [1, 2, 3]
. mult =
| [1, 2] -> 1
| [1, 2, 3, 4] -> 1
| [1, 3] -> 1
"""
)

def test_match_list_with_constant(self) -> None:
self.assertEqual(
self._run(
"""
middle xs
. xs = [4, 5, 6]
. middle =
| [1, x, 3] -> x
| [4, x, 6] -> x
| [7, x, 9] -> x
"""
),
Int(5),
)

def test_match_list_with_non_list_fails(self) -> None:
with self.assertRaises(MatchError):
self._run(
"""
get_x 3
. get_x =
| [2, x] -> x
"""
)

def test_match_list_doubly_binds_vars(self) -> None:
self.assertEqual(
self._run(
"""
mult xs
. xs = [1, 2, 3, 2, 1]
. mult =
| [1, x, 3, x, 1] -> x
"""
),
Int(2),
)

def test_pipe(self) -> None:
self.assertEqual(self._run("1 |> (a -> a + 2)"), Int(3))

Expand Down