Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialize method and bencode #8

Merged
merged 8 commits into from
Nov 22, 2023
204 changes: 203 additions & 1 deletion scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import base64
import code
import dataclasses
import enum
import json
import logging
Expand Down Expand Up @@ -345,33 +346,63 @@ def parse(tokens: typing.List[str], p: float = 0) -> "Object":

@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Object:
    """Common base class for the dataclasses in this file.

    Provides a generic ``serialize`` that works for subclasses whose fields
    all hold ``Object`` instances, plus a ``_serialize`` helper for
    subclasses that need to serialize other kinds of fields by hand.
    """

    def serialize(self) -> dict[bytes, object]:
        """Return a dict describing this object and its dataclass fields.

        The result always contains ``b"type"`` (the class name, UTF-8
        encoded) and one entry per dataclass field, keyed by the UTF-8
        encoded field name and holding that field's own ``serialize()``
        result.

        Raises:
            NotImplementedError: if any field holds a value that is not an
                ``Object``; such subclasses must override ``serialize``.
        """
        result: dict[bytes, object] = {b"type": type(self).__name__.encode("utf-8")}
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            # Check the runtime value rather than the declared annotation:
            # annotations such as typing.List[Object], Union[...] or string
            # forward references are not classes and make issubclass() raise
            # TypeError instead of the intended NotImplementedError.
            if not isinstance(value, Object):
                raise NotImplementedError("serializing non-Object fields; write your own serialize function")
            result[field.name.encode("utf-8")] = value.serialize()
        return result

    def _serialize(self, **kwargs: object) -> dict[bytes, object]:
        """Build a serialized dict from already-serialized field values.

        Each keyword becomes a UTF-8 encoded key mapped to the given value
        unchanged, after a leading ``b"type"`` entry holding the class name.
        """
        return {
            b"type": type(self).__name__.encode("utf-8"),
            **{key.encode("utf-8"): value for key, value in kwargs.items()},
        }


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Int(Object):
    """Object wrapping a Python ``int``."""

    value: int

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the raw integer under ``b"value"``."""
        serialized: dict[bytes, object] = self._serialize(value=self.value)
        return serialized


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class String(Object):
    """Object wrapping a Python ``str``."""

    value: str

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the UTF-8 encoded text under ``b"value"``."""
        encoded_text = self.value.encode("utf-8")
        return {b"type": b"String", b"value": encoded_text}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bytes(Object):
    """Object wrapping a Python ``bytes`` value."""

    value: bytes

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the raw bytes under ``b"value"``."""
        payload = self.value
        return {b"type": b"Bytes", b"value": payload}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Var(Object):
    """Object holding a variable name."""

    name: str

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the UTF-8 encoded name under ``b"name"``."""
        encoded_name = self.name.encode("utf-8")
        return {b"type": b"Var", b"name": encoded_name}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bool(Object):
    """Object wrapping a Python ``bool``."""

    value: bool

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the raw boolean under ``b"value"``."""
        flag = self.value
        return {b"type": b"Bool", b"value": flag}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Hole(Object):
Expand Down Expand Up @@ -429,11 +460,22 @@ class Binop(Object):
left: Object
right: Object

def serialize(self) -> dict[bytes, object]:
    """Return a dict with the operator name and both serialized operands.

    The operator is stored as the UTF-8 encoded ``name`` of ``self.op``;
    ``left`` and ``right`` are stored as their own ``serialize()`` results.
    """
    op_name = self.op.name.encode("utf-8")
    lhs = self.left.serialize()
    rhs = self.right.serialize()
    return {b"type": b"Binop", b"op": op_name, b"left": lhs, b"right": rhs}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class List(Object):
    """Object holding a sequence of other objects."""

    items: typing.List[Object]

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus a list of each item's ``serialize()`` result."""
        serialized_items = []
        for element in self.items:
            serialized_items.append(element.serialize())
        return {b"type": b"List", b"items": serialized_items}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Assign(Object):
Expand Down Expand Up @@ -471,10 +513,17 @@ class Assert(Object):
cond: Object


def serialize_env(env: Env) -> dict[bytes, object]:
    """Serialize every binding in *env*.

    Returns a dict mapping each UTF-8 encoded key of *env* to the
    ``serialize()`` result of the corresponding value.
    """
    encoded_env: dict[bytes, object] = {}
    for name, bound in env.items():
        encoded_name = name.encode("utf-8")
        encoded_env[encoded_name] = bound.serialize()
    return encoded_env


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class EnvObject(Object):
    """Object wrapping an environment."""

    env: Env

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the serialized environment under ``b"value"``."""
        serialized_bindings = serialize_env(self.env)
        return self._serialize(value=serialized_bindings)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class MatchCase(Object):
Expand All @@ -486,6 +535,9 @@ class MatchCase(Object):
class MatchFunction(Object):
cases: typing.List[MatchCase]

def serialize(self) -> dict[bytes, object]:
    """Return the type tag plus a list of each case's ``serialize()`` result."""
    serialized_cases = []
    for match_case in self.cases:
        serialized_cases.append(match_case.serialize())
    return self._serialize(cases=serialized_cases)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class NativeFunction(Object):
Expand All @@ -497,11 +549,17 @@ class Closure(Object):
env: Env
func: Union[Function, MatchFunction]

def serialize(self) -> dict[bytes, object]:
    """Return the type tag plus the serialized environment and function.

    The environment goes through ``serialize_env``; the function is stored
    as its own ``serialize()`` result.
    """
    serialized_bindings = serialize_env(self.env)
    serialized_func = self.func.serialize()
    return self._serialize(env=serialized_bindings, func=serialized_func)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Record(Object):
    """Object mapping string keys to other objects."""

    data: Dict[str, Object]

    def serialize(self) -> dict[bytes, object]:
        """Return the type tag plus the serialized entries under ``b"data"``.

        Each key is UTF-8 encoded and mapped to its value's ``serialize()``
        result.
        """
        encoded_entries: dict[bytes, object] = {}
        for entry_name, entry_value in self.data.items():
            encoded_name = entry_name.encode("utf-8")
            encoded_entries[encoded_name] = entry_value.serialize()
        return self._serialize(data=encoded_entries)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Access(Object):
Expand Down Expand Up @@ -664,6 +722,23 @@ def eval_exp(env: Env, exp: Object) -> Object:
raise NotImplementedError(f"eval_exp not implemented for {exp}")


def bencode(obj: object) -> bytes:
    """Encode *obj* as bencode.

    Supported inputs:
      * ``int`` (including ``bool``, which is an ``int`` subclass) ->
        ``i<decimal>e``
      * ``bytes`` -> ``<length>:<data>``
      * ``list`` -> ``l<items>e``
      * ``dict`` with ``bytes`` keys -> ``d<key><value>...e`` with entries
        emitted in sorted key order

    Raises:
        NotImplementedError: for any other type, or for a dict containing a
            key that is not ``bytes``.
    """
    if isinstance(obj, int):
        # int(obj) normalizes bools so True/False encode as i1e/i0e.
        return b"i" + str(int(obj)).encode("ascii") + b"e"
    if isinstance(obj, bytes):
        return str(len(obj)).encode("ascii") + b":" + obj
    if isinstance(obj, list):
        return b"l" + b"".join(bencode(x) for x in obj) + b"e"
    if isinstance(obj, dict):
        # Bencode dictionaries may only be keyed by byte strings. Without
        # this check an int key would be emitted as an integer (invalid
        # output) and mixed key types would fail inside sorted().
        for key in obj:
            if not isinstance(key, bytes):
                raise NotImplementedError(f"bencode not implemented for dict key of {type(key)}")
        sorted_items = sorted(obj.items(), key=lambda x: x[0])
        return b"d" + b"".join(bencode(k) + bencode(v) for k, v in sorted_items) + b"e"
    raise NotImplementedError(f"bencode not implemented for {type(obj)}")


def serialize(obj: Object) -> bytes:
    """Return the bencoded form of ``obj.serialize()``."""
    as_dict = obj.serialize()
    return bencode(as_dict)


class TokenizerTests(unittest.TestCase):
def test_tokenize_digit(self) -> None:
self.assertEqual(tokenize("1"), ["1"])
Expand Down Expand Up @@ -1749,6 +1824,132 @@ def test_stdlib_quote_pipe(self) -> None:
def test_stdlib_quote_reverse_pipe(self) -> None:
    # `$$quote <| 3 + 4` is expected to produce the Binop for 3 + 4.
    actual = self._run("$$quote <| 3 + 4")
    self.assertEqual(actual, Binop(BinopKind.ADD, Int(3), Int(4)))

def test_stdlib_serialize(self) -> None:
    # `$$serialize 3` is expected to produce the bencoded Int wrapped in Bytes.
    actual = self._run("$$serialize 3", STDLIB)
    self.assertEqual(actual, Bytes(value=b"d4:type3:Int5:valuei3ee"))

def test_stdlib_serialize_expr(self) -> None:
    # Quoting `1+2` and then serializing it is expected to give the bencoded Binop.
    actual = self._run("(1+2) |> $$quote |> $$serialize", STDLIB)
    expected = Bytes(value=b"d4:leftd4:type3:Int5:valuei1ee2:op3:ADD5:rightd4:type3:Int5:valuei2ee4:type5:Binope")
    self.assertEqual(actual, expected)


class BencodeTests(unittest.TestCase):
    # Unit tests for bencode() on ints, bools, bytes, lists and dicts.

    def test_bencode_int(self) -> None:
        encoded = bencode(123)
        self.assertEqual(encoded, b"i123e")

    def test_bencode_bool(self) -> None:
        # True is encoded the same way as the integer 1.
        encoded = bencode(True)
        self.assertEqual(encoded, b"i1e")

    def test_bencode_negative_int(self) -> None:
        encoded = bencode(-123)
        self.assertEqual(encoded, b"i-123e")

    def test_serialize_bytes(self) -> None:
        encoded = bencode(b"abc")
        self.assertEqual(encoded, b"3:abc")

    def test_bencode_empty_list(self) -> None:
        encoded = bencode([])
        self.assertEqual(encoded, b"le")

    def test_bencode_list_of_ints(self) -> None:
        encoded = bencode([1, 2, 3])
        self.assertEqual(encoded, b"li1ei2ei3ee")

    def test_bencode_list_of_lists(self) -> None:
        encoded = bencode([[1, 2], [3, 4]])
        self.assertEqual(encoded, b"lli1ei2eeli3ei4eee")

    def test_bencode_dict_sorts_keys(self) -> None:
        unsorted = {b"b": 1, b"a": 2}
        # The dict itself iterates in insertion order: b"b" first.
        self.assertEqual(list(unsorted), [b"b", b"a"])
        # The encoded output lists the keys in sorted order instead.
        self.assertEqual(bencode(unsorted), b"d1:ai2e1:bi1ee")


class ObjectSerializeTests(unittest.TestCase):
    # Unit tests for the dict produced by each object's serialize() method.

    def test_serialize_int(self) -> None:
        expected = {b"type": b"Int", b"value": 123}
        self.assertEqual(Int(123).serialize(), expected)

    def test_serialize_negative_int(self) -> None:
        expected = {b"type": b"Int", b"value": -123}
        self.assertEqual(Int(-123).serialize(), expected)

    def test_serialize_str(self) -> None:
        expected = {b"type": b"String", b"value": b"abc"}
        self.assertEqual(String("abc").serialize(), expected)

    def test_serialize_bytes(self) -> None:
        expected = {b"type": b"Bytes", b"value": b"abc"}
        self.assertEqual(Bytes(b"abc").serialize(), expected)

    def test_serialize_var(self) -> None:
        expected = {b"type": b"Var", b"name": b"abc"}
        self.assertEqual(Var("abc").serialize(), expected)

    def test_serialize_bool(self) -> None:
        expected = {b"type": b"Bool", b"value": True}
        self.assertEqual(Bool(True).serialize(), expected)

    def test_serialize_binary_add(self) -> None:
        node = Binop(BinopKind.ADD, Int(123), Int(456))
        expected = {
            b"left": {b"type": b"Int", b"value": 123},
            b"op": b"ADD",
            b"right": {b"type": b"Int", b"value": 456},
            b"type": b"Binop",
        }
        self.assertEqual(node.serialize(), expected)

    def test_serialize_list(self) -> None:
        node = List([Int(1), Int(2)])
        expected = {
            b"type": b"List",
            b"items": [{b"type": b"Int", b"value": 1}, {b"type": b"Int", b"value": 2}],
        }
        self.assertEqual(node.serialize(), expected)

    def test_serialize_assign(self) -> None:
        node = Assign(Var("x"), Int(2))
        expected = {
            b"type": b"Assign",
            b"name": {b"name": b"x", b"type": b"Var"},
            b"value": {b"type": b"Int", b"value": 2},
        }
        self.assertEqual(node.serialize(), expected)

    def test_serialize_record(self) -> None:
        node = Record({"x": Int(1)})
        expected = {b"data": {b"x": {b"type": b"Int", b"value": 1}}, b"type": b"Record"}
        self.assertEqual(node.serialize(), expected)


class SerializeTests(unittest.TestCase):
    # Unit tests for the bytes produced by the module-level serialize().

    def test_serialize_int(self) -> None:
        encoded = serialize(Int(3))
        self.assertEqual(encoded, b"d4:type3:Int5:valuei3ee")

    def test_serialize_str(self) -> None:
        encoded = serialize(String("abc"))
        self.assertEqual(encoded, b"d4:type6:String5:value3:abce")

    def test_serialize_bytes(self) -> None:
        encoded = serialize(Bytes(b"abc"))
        self.assertEqual(encoded, b"d4:type5:Bytes5:value3:abce")

    def test_serialize_var(self) -> None:
        encoded = serialize(Var("abc"))
        self.assertEqual(encoded, b"d4:name3:abc4:type3:Vare")

    def test_serialize_bool(self) -> None:
        encoded = serialize(Bool(True))
        self.assertEqual(encoded, b"d4:type4:Bool5:valuei1ee")

    def test_serialize_function(self) -> None:
        func = Function(Var("x"), Binop(BinopKind.ADD, Int(1), Var("x")))
        expected = b"d3:argd4:name1:x4:type3:Vare4:bodyd4:leftd4:type3:Int5:valuei1ee2:op3:ADD5:rightd4:name1:x4:type3:Vare4:type5:Binope4:type8:Functione"
        self.assertEqual(serialize(func), expected)


def eval_command(args: argparse.Namespace) -> None:
if args.debug:
Expand Down Expand Up @@ -1809,6 +2010,7 @@ def jsondecode(obj: Object) -> Object:
"$$add": NativeFunction(lambda x: NativeFunction(lambda y: Int(unpack_int(x) + unpack_int(y)))),
"$$fetch": NativeFunction(fetch),
"$$jsondecode": NativeFunction(jsondecode),
"$$serialize": NativeFunction(lambda obj: Bytes(serialize(obj))),
}


Expand Down
Loading