Skip to content

Commit

Permalink
Add support for floats (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorybchris authored Jan 3, 2024
1 parent d7873b7 commit 05d2e00
Showing 1 changed file with 139 additions and 28 deletions.
167 changes: 139 additions & 28 deletions scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class IntLit(Token):
value: int


@dataclass(eq=True)
class FloatLit(Token):
value: float


@dataclass(eq=True)
class StringLit(Token):
value: str
Expand Down Expand Up @@ -174,7 +179,7 @@ def read_one(self) -> Token:
return self.read_bytes()
raise ParseError(f"unexpected token {c!r}")
if c.isdigit():
return self.read_integer(c)
return self.read_number(c)
if c in "()[]{}":
custom = {
"(": LeftParen,
Expand Down Expand Up @@ -205,11 +210,23 @@ def read_comment(self) -> None:
while self.has_input() and self.read_char() != "\n":
pass

def read_integer(self, first_digit: str) -> Token:
def read_number(self, first_digit: str) -> Token:
# TODO: Support floating point numbers with no integer part
buf = first_digit
while self.has_input() and (c := self.peek_char()).isdigit():
has_decimal = False
while self.has_input():
c = self.peek_char()
if c == ".":
if has_decimal:
raise ParseError(f"unexpected token {c!r}")
has_decimal = True
elif not c.isdigit():
break
self.read_char()
buf += c

if has_decimal:
return self.make_token(FloatLit, float(buf))
return self.make_token(IntLit, int(buf))

def _starts_operator(self, buf: str) -> bool:
Expand Down Expand Up @@ -350,8 +367,9 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object":
token = tokens.pop(0)
l: Object
if isinstance(token, IntLit):
# TODO: Handle float literals
l = Int(token.value)
elif isinstance(token, FloatLit):
l = Float(token.value)
elif isinstance(token, Name):
# TODO: Handle kebab case vars
l = Var(token.value)
Expand Down Expand Up @@ -554,6 +572,21 @@ def __str__(self) -> str:
return str(self.value)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Float(Object):
value: float

def serialize(self) -> Dict[bytes, object]:
raise NotImplementedError("serialization for Float is not supported")

@staticmethod
def deserialize(msg: Dict[str, object]) -> "Float":
raise NotImplementedError("serialization for Float is not supported")

def __str__(self) -> str:
return str(self.value)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class String(Object):
value: str
Expand Down Expand Up @@ -971,15 +1004,15 @@ def __str__(self) -> str:
return f"#{self.value}"


def unpack_int(obj: Object) -> int:
if not isinstance(obj, Int):
raise TypeError(f"expected Int, got {type(obj).__name__}")
def unpack_number(obj: Object) -> Union[int, float]:
if not isinstance(obj, (Int, Float)):
raise TypeError(f"expected Int or Float, got {type(obj).__name__}")
return obj.value


def eval_int(env: Env, exp: Object) -> int:
def eval_number(env: Env, exp: Object) -> Union[int, float]:
result = eval_exp(env, exp)
return unpack_int(result)
return unpack_number(result)


def eval_str(env: Env, exp: Object) -> str:
Expand Down Expand Up @@ -1007,20 +1040,30 @@ def make_bool(x: bool) -> Object:
return Symbol("true" if x else "false")


def wrap_inferred_number_type(x: Union[int, float]) -> Object:
# TODO: Since this is intended to be a reference implementation
# we should avoid relying heavily on Python's implementation of
# arithmetic operations, type inference, and multiple dispatch.
# Update this to make the interpreter more language agnostic.
if isinstance(x, int):
return Int(x)
return Float(x)


BINOP_HANDLERS: Dict[BinopKind, Callable[[Env, Object, Object], Object]] = {
BinopKind.ADD: lambda env, x, y: Int(eval_int(env, x) + eval_int(env, y)),
BinopKind.SUB: lambda env, x, y: Int(eval_int(env, x) - eval_int(env, y)),
BinopKind.MUL: lambda env, x, y: Int(eval_int(env, x) * eval_int(env, y)),
BinopKind.DIV: lambda env, x, y: Int(eval_int(env, x) // eval_int(env, y)),
BinopKind.FLOOR_DIV: lambda env, x, y: Int(eval_int(env, x) // eval_int(env, y)),
BinopKind.EXP: lambda env, x, y: Int(eval_int(env, x) ** eval_int(env, y)),
BinopKind.MOD: lambda env, x, y: Int(eval_int(env, x) % eval_int(env, y)),
BinopKind.ADD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) + eval_number(env, y)),
BinopKind.SUB: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) - eval_number(env, y)),
BinopKind.MUL: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) * eval_number(env, y)),
BinopKind.DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) / eval_number(env, y)),
BinopKind.FLOOR_DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) // eval_number(env, y)),
BinopKind.EXP: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) ** eval_number(env, y)),
BinopKind.MOD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) % eval_number(env, y)),
BinopKind.EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) == eval_exp(env, y)),
BinopKind.NOT_EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) != eval_exp(env, y)),
BinopKind.LESS: lambda env, x, y: make_bool(eval_int(env, x) < eval_int(env, y)),
BinopKind.GREATER: lambda env, x, y: make_bool(eval_int(env, x) > eval_int(env, y)),
BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_int(env, x) <= eval_int(env, y)),
BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_int(env, x) >= eval_int(env, y)),
BinopKind.LESS: lambda env, x, y: make_bool(eval_number(env, x) < eval_number(env, y)),
BinopKind.GREATER: lambda env, x, y: make_bool(eval_number(env, x) > eval_number(env, y)),
BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) <= eval_number(env, y)),
BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) >= eval_number(env, y)),
BinopKind.BOOL_AND: lambda env, x, y: make_bool(eval_bool(env, x) and eval_bool(env, y)),
BinopKind.BOOL_OR: lambda env, x, y: make_bool(eval_bool(env, x) or eval_bool(env, y)),
BinopKind.STRING_CONCAT: lambda env, x, y: String(eval_str(env, x) + eval_str(env, y)),
Expand All @@ -1037,6 +1080,8 @@ class MatchError(Exception):
def match(obj: Object, pattern: Object) -> Optional[Env]:
if isinstance(pattern, Int):
return {} if isinstance(obj, Int) and obj.value == pattern.value else None
if isinstance(pattern, Float):
raise MatchError("pattern matching is not supported for Floats")
if isinstance(pattern, String):
return {} if isinstance(obj, String) and obj.value == pattern.value else None
if isinstance(pattern, Var):
Expand Down Expand Up @@ -1092,7 +1137,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
# pylint: disable=redefined-builtin
def eval_exp(env: Env, exp: Object) -> Object:
logger.debug(exp)
if isinstance(exp, (Int, String, Bytes, Hole, Closure, NativeFunction, Symbol)):
if isinstance(exp, (Int, Float, String, Bytes, Hole, Closure, NativeFunction, Symbol)):
return exp
if isinstance(exp, Var):
value = env.get(exp.name)
Expand Down Expand Up @@ -1285,6 +1330,23 @@ def test_tokenize_multiple_digits(self) -> None:
def test_tokenize_negative_int(self) -> None:
self.assertEqual(tokenize("-123"), [Operator("-"), IntLit(123)])

def test_tokenize_float(self) -> None:
self.assertEqual(tokenize("3.14"), [FloatLit(3.14)])

def test_tokenize_negative_float(self) -> None:
self.assertEqual(tokenize("-3.14"), [Operator("-"), FloatLit(3.14)])

@unittest.skip("TODO: support floats with no integer part")
def test_tokenize_float_with_no_integer_part(self) -> None:
self.assertEqual(tokenize(".14"), [FloatLit(0.14)])

def test_tokenize_float_with_no_decimal_part(self) -> None:
self.assertEqual(tokenize("10."), [FloatLit(10.0)])

def test_tokenize_float_with_multiple_decimal_points_raises_parse_error(self) -> None:
with self.assertRaisesRegex(ParseError, re.escape("unexpected token '.'")):
tokenize("1.0.1")

def test_tokenize_binop(self) -> None:
self.assertEqual(tokenize("1 + 2"), [IntLit(1), Operator("+"), IntLit(2)])

Expand Down Expand Up @@ -1708,6 +1770,12 @@ def test_parse_negative_int_binds_tighter_than_apply(self) -> None:
Apply(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")),
)

def test_parse_decimal_returns_float(self) -> None:
self.assertEqual(parse([FloatLit(3.14)]), Float(3.14))

def test_parse_negative_float_returns_binary_sub_float(self) -> None:
self.assertEqual(parse([Operator("-"), FloatLit(3.14)]), Binop(BinopKind.SUB, Int(0), Float(3.14)))

def test_parse_var_returns_var(self) -> None:
self.assertEqual(parse([Name("abc_123")]), Var("abc_123"))

Expand Down Expand Up @@ -2122,6 +2190,18 @@ def test_match_with_inequal_ints_returns_none(self) -> None:
def test_match_int_with_non_int_returns_none(self) -> None:
self.assertEqual(match(String("abc"), pattern=Int(1)), None)

def test_match_with_equal_floats_raises_match_error(self) -> None:
with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")):
match(Float(1), pattern=Float(1))

def test_match_with_inequal_floats_raises_match_error(self) -> None:
with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")):
match(Float(2), pattern=Float(1))

def test_match_float_with_non_float_raises_match_error(self) -> None:
with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")):
match(String("abc"), pattern=Float(1))

def test_match_with_equal_strings_returns_empty_dict(self) -> None:
self.assertEqual(match(String("a"), pattern=String("a")), {})

Expand Down Expand Up @@ -2379,6 +2459,10 @@ def test_eval_int_returns_int(self) -> None:
exp = Int(5)
self.assertEqual(eval_exp({}, exp), Int(5))

def test_eval_float_returns_float(self) -> None:
exp = Float(3.14)
self.assertEqual(eval_exp({}, exp), Float(3.14))

def test_eval_str_returns_str(self) -> None:
exp = String("xyz")
self.assertEqual(eval_exp({}, exp), String("xyz"))
Expand Down Expand Up @@ -2410,7 +2494,7 @@ def test_eval_with_binop_add_with_int_string_raises_type_error(self) -> None:
exp = Binop(BinopKind.ADD, Int(1), String("hello"))
with self.assertRaises(TypeError) as ctx:
eval_exp({}, exp)
self.assertEqual(ctx.exception.args[0], "expected Int, got String")
self.assertEqual(ctx.exception.args[0], "expected Int or Float, got String")

def test_eval_with_binop_sub(self) -> None:
exp = Binop(BinopKind.SUB, Int(1), Int(2))
Expand All @@ -2421,8 +2505,8 @@ def test_eval_with_binop_mul(self) -> None:
self.assertEqual(eval_exp({}, exp), Int(6))

def test_eval_with_binop_div(self) -> None:
exp = Binop(BinopKind.DIV, Int(2), Int(3))
self.assertEqual(eval_exp({}, exp), Int(0))
exp = Binop(BinopKind.DIV, Int(3), Int(10))
self.assertEqual(eval_exp({}, exp), Float(0.3))

def test_eval_with_binop_floor_div(self) -> None:
exp = Binop(BinopKind.FLOOR_DIV, Int(2), Int(3))
Expand Down Expand Up @@ -2707,7 +2791,7 @@ def test_eval_less_returns_bool(self) -> None:

def test_eval_less_on_non_bool_raises_type_error(self) -> None:
ast = Binop(BinopKind.LESS, String("xyz"), Int(4))
with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")):
with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
eval_exp({}, ast)

def test_eval_less_equal_returns_bool(self) -> None:
Expand All @@ -2716,7 +2800,7 @@ def test_eval_less_equal_returns_bool(self) -> None:

def test_eval_less_equal_on_non_bool_raises_type_error(self) -> None:
ast = Binop(BinopKind.LESS_EQUAL, String("xyz"), Int(4))
with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")):
with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
eval_exp({}, ast)

def test_eval_greater_returns_bool(self) -> None:
Expand All @@ -2725,7 +2809,7 @@ def test_eval_greater_returns_bool(self) -> None:

def test_eval_greater_on_non_bool_raises_type_error(self) -> None:
ast = Binop(BinopKind.GREATER, String("xyz"), Int(4))
with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")):
with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
eval_exp({}, ast)

def test_eval_greater_equal_returns_bool(self) -> None:
Expand All @@ -2734,7 +2818,7 @@ def test_eval_greater_equal_returns_bool(self) -> None:

def test_eval_greater_equal_on_non_bool_raises_type_error(self) -> None:
ast = Binop(BinopKind.GREATER_EQUAL, String("xyz"), Int(4))
with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")):
with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
eval_exp({}, ast)

def test_boolean_and_evaluates_args(self) -> None:
Expand Down Expand Up @@ -2791,6 +2875,21 @@ def test_eval_record_with_spread_fails(self) -> None:
def test_eval_symbol_returns_symbol(self) -> None:
self.assertEqual(eval_exp({}, Symbol("abc")), Symbol("abc"))

def test_eval_float_and_float_addition_returns_float(self) -> None:
self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Float(1.0), Float(2.0))), Float(3.0))

def test_eval_int_and_float_addition_returns_float(self) -> None:
self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Int(1), Float(2.0))), Float(3.0))

def test_eval_int_and_float_division_returns_float(self) -> None:
self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Float(2.0))), Float(0.5))

def test_eval_float_and_int_division_returns_float(self) -> None:
self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Float(1.0), Int(2))), Float(0.5))

def test_eval_int_and_int_division_returns_float(self) -> None:
self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Int(2))), Float(0.5))


class EndToEndTestsBase(unittest.TestCase):
def _run(self, text: str, env: Optional[Env] = None) -> Object:
Expand All @@ -2805,6 +2904,9 @@ class EndToEndTests(EndToEndTestsBase):
def test_int_returns_int(self) -> None:
self.assertEqual(self._run("1"), Int(1))

def test_float_returns_float(self) -> None:
self.assertEqual(self._run("3.14"), Float(3.14))

def test_bytes_returns_bytes(self) -> None:
self.assertEqual(self._run("~~QUJD"), Bytes(b"ABC"))

Expand Down Expand Up @@ -3604,6 +3706,11 @@ def test_serialize_negative_int(self) -> None:
obj = Int(-123)
self.assertEqual(obj.serialize(), {b"type": b"Int", b"value": -123})

def test_serialize_float_raises_not_implemented_error(self) -> None:
obj = Float(3.14)
with self.assertRaisesRegex(NotImplementedError, re.escape("serialization for Float is not supported")):
obj.serialize()

def test_serialize_str(self) -> None:
obj = String("abc")
self.assertEqual(obj.serialize(), {b"type": b"String", b"value": b"abc"})
Expand Down Expand Up @@ -3816,6 +3923,10 @@ def test_pretty_print_int(self) -> None:
obj = Int(1)
self.assertEqual(str(obj), "1")

def test_pretty_print_float(self) -> None:
obj = Float(3.14)
self.assertEqual(str(obj), "3.14")

def test_pretty_print_string(self) -> None:
obj = String("hello")
self.assertEqual(str(obj), '"hello"')
Expand Down

0 comments on commit 05d2e00

Please sign in to comment.