diff --git a/scrapscript.py b/scrapscript.py index ee1b5dfc..0dc712c8 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -542,11 +542,11 @@ class MatchError(Exception): pass -def match(obj: Object, pattern: Object) -> bool | Env: +def match(obj: Object, pattern: Object) -> Optional[Env]: if isinstance(pattern, Int): - return isinstance(obj, Int) and obj.value == pattern.value + return {} if isinstance(obj, Int) and obj.value == pattern.value else None if isinstance(pattern, String): - return isinstance(obj, String) and obj.value == pattern.value + return {} if isinstance(obj, String) and obj.value == pattern.value else None if isinstance(pattern, Var): return {pattern.name: obj} raise NotImplementedError("TODO: match") @@ -607,12 +607,8 @@ def eval(env: Env, exp: Object) -> Object: arg = eval(env, exp.arg) for case in callee.func.cases: m = match(arg, case.pattern) - if isinstance(m, bool): - if m: - return eval(env, case.body) - else: - continue - assert isinstance(m, dict) + if m is None: + continue return eval({**env, **m}, case.body) raise MatchError("no matching cases") else: @@ -1077,23 +1073,27 @@ def test_parse_double_compose(self) -> None: class MatchTests(unittest.TestCase): - def test_match_with_equal_ints_returns_true(self) -> None: - self.assertTrue(match(Int(1), pattern=Int(1))) + def test_match_with_equal_ints_returns_empty_dict(self) -> None: + self.assertEqual(match(Int(1), pattern=Int(1)), {}) - def test_match_with_inequal_ints_returns_false(self) -> None: - self.assertFalse(match(Int(2), pattern=Int(1))) + def test_match_with_inequal_ints_returns_none(self) -> None: + self.assertEqual(match(Int(2), pattern=Int(1)), None) - def test_match_int_with_non_int_returns_false(self) -> None: - self.assertFalse(match(String("abc"), pattern=Int(1))) + def test_match_int_with_non_int_returns_none(self) -> None: + self.assertEqual(match(String("abc"), pattern=Int(1)), None) - def test_match_with_equal_strings_returns_true(self) -> None: - self.assertTrue(match(String("a"), pattern=String("a"))) + def test_match_with_equal_strings_returns_empty_dict(self) -> None: + self.assertEqual(match(String("a"), pattern=String("a")), {}) - def test_match_with_inequal_strings_returns_false(self) -> None: - self.assertFalse(match(String("b"), pattern=String("a"))) + def test_match_with_inequal_strings_returns_none(self) -> None: + self.assertEqual(match(String("b"), pattern=String("a")), None) + + def test_match_string_with_non_string_returns_none(self) -> None: + self.assertEqual(match(Int(1), pattern=String("abc")), None) + + def test_match_var_returns_dict_with_var_name(self) -> None: + self.assertEqual(match(String("abc"), pattern=Var("a")), {"a": String("abc")}) - def test_match_string_with_non_string_returns_false(self) -> None: - self.assertFalse(match(Int(1), pattern=String("abc"))) class EvalTests(unittest.TestCase):