Skip to content

Commit

Permalink
Add serialize method and bencode (#8)
Browse files Browse the repository at this point in the history
* Add serialize method and bencode

Serialize should transform objects into a simple Python structure and
bencode should take that structure and turn it into binary. There is no
deserialization yet.

* Add more serialization functions and helper

* Add more serialization functions

* Move type= to _serialize function

* Implement default serializer for only Object fields

This makes our lives easier.

* Add another end-to-end serialization test

* Add serialize function to stdlib

* Update for Python3.8 support
  • Loading branch information
tekknolagi authored Nov 22, 2023
1 parent e62b106 commit c55f1f7
Showing 1 changed file with 203 additions and 1 deletion.
204 changes: 203 additions & 1 deletion scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import base64
import code
import dataclasses
import enum
import json
import logging
Expand Down Expand Up @@ -345,33 +346,63 @@ def parse(tokens: typing.List[str], p: float = 0) -> "Object":

@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Object:
pass
def serialize(self) -> Dict[bytes, object]:
cls = type(self)
result: Dict[bytes, object] = {b"type": cls.__name__.encode("utf-8")}
for field in dataclasses.fields(cls):
if issubclass(field.type, Object):
value = getattr(self, field.name)
result[field.name.encode("utf-8")] = value.serialize()
else:
raise NotImplementedError("serializing non-Object fields; write your own serialize function")
return result

def _serialize(self, **kwargs: object) -> Dict[bytes, object]:
return {
b"type": type(self).__name__.encode("utf-8"),
**{key.encode("utf-8"): value for key, value in kwargs.items()},
}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Int(Object):
value: int

def serialize(self) -> Dict[bytes, object]:
return self._serialize(value=self.value)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class String(Object):
value: str

def serialize(self) -> Dict[bytes, object]:
return {b"type": b"String", b"value": self.value.encode("utf-8")}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bytes(Object):
value: bytes

def serialize(self) -> Dict[bytes, object]:
return {b"type": b"Bytes", b"value": self.value}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Var(Object):
name: str

def serialize(self) -> Dict[bytes, object]:
return {b"type": b"Var", b"name": self.name.encode("utf-8")}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bool(Object):
value: bool

def serialize(self) -> Dict[bytes, object]:
return {b"type": b"Bool", b"value": self.value}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Hole(Object):
Expand Down Expand Up @@ -429,11 +460,22 @@ class Binop(Object):
left: Object
right: Object

def serialize(self) -> Dict[bytes, object]:
return {
b"type": b"Binop",
b"op": self.op.name.encode("utf-8"),
b"left": self.left.serialize(),
b"right": self.right.serialize(),
}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class List(Object):
items: typing.List[Object]

def serialize(self) -> Dict[bytes, object]:
return {b"type": b"List", b"items": [item.serialize() for item in self.items]}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Assign(Object):
Expand Down Expand Up @@ -471,10 +513,17 @@ class Assert(Object):
cond: Object


def serialize_env(env: Env) -> Dict[bytes, object]:
return {key.encode("utf-8"): value.serialize() for key, value in env.items()}


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class EnvObject(Object):
env: Env

def serialize(self) -> Dict[bytes, object]:
return self._serialize(value=serialize_env(self.env))


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class MatchCase(Object):
Expand All @@ -486,6 +535,9 @@ class MatchCase(Object):
class MatchFunction(Object):
cases: typing.List[MatchCase]

def serialize(self) -> Dict[bytes, object]:
return self._serialize(cases=[case.serialize() for case in self.cases])


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class NativeFunction(Object):
Expand All @@ -497,11 +549,17 @@ class Closure(Object):
env: Env
func: Union[Function, MatchFunction]

def serialize(self) -> Dict[bytes, object]:
return self._serialize(env=serialize_env(self.env), func=self.func.serialize())


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Record(Object):
data: Dict[str, Object]

def serialize(self) -> Dict[bytes, object]:
return self._serialize(data={key.encode("utf-8"): value.serialize() for key, value in self.data.items()})


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Access(Object):
Expand Down Expand Up @@ -664,6 +722,23 @@ def eval_exp(env: Env, exp: Object) -> Object:
raise NotImplementedError(f"eval_exp not implemented for {exp}")


def bencode(obj: object) -> bytes:
if isinstance(obj, int):
return b"i" + str(int(obj)).encode("ascii") + b"e"
if isinstance(obj, bytes):
return str(len(obj)).encode("ascii") + b":" + obj
if isinstance(obj, list):
return b"l" + b"".join(bencode(x) for x in obj) + b"e"
if isinstance(obj, dict):
sorted_items = sorted(obj.items(), key=lambda x: x[0])
return b"d" + b"".join(bencode(k) + bencode(v) for k, v in sorted_items) + b"e"
raise NotImplementedError(f"bencode not implemented for {type(obj)}")


def serialize(obj: Object) -> bytes:
return bencode(obj.serialize())


class TokenizerTests(unittest.TestCase):
def test_tokenize_digit(self) -> None:
self.assertEqual(tokenize("1"), ["1"])
Expand Down Expand Up @@ -1749,6 +1824,132 @@ def test_stdlib_quote_pipe(self) -> None:
def test_stdlib_quote_reverse_pipe(self) -> None:
self.assertEqual(self._run("$$quote <| 3 + 4"), Binop(BinopKind.ADD, Int(3), Int(4)))

def test_stdlib_serialize(self) -> None:
self.assertEqual(self._run("$$serialize 3", STDLIB), Bytes(value=b"d4:type3:Int5:valuei3ee"))

def test_stdlib_serialize_expr(self) -> None:
self.assertEqual(
self._run("(1+2) |> $$quote |> $$serialize", STDLIB),
Bytes(value=b"d4:leftd4:type3:Int5:valuei1ee2:op3:ADD5:rightd4:type3:Int5:valuei2ee4:type5:Binope"),
)


class BencodeTests(unittest.TestCase):
def test_bencode_int(self) -> None:
self.assertEqual(bencode(123), b"i123e")

def test_bencode_bool(self) -> None:
self.assertEqual(bencode(True), b"i1e")

def test_bencode_negative_int(self) -> None:
self.assertEqual(bencode(-123), b"i-123e")

def test_serialize_bytes(self) -> None:
self.assertEqual(bencode(b"abc"), b"3:abc")

def test_bencode_empty_list(self) -> None:
self.assertEqual(bencode([]), b"le")

def test_bencode_list_of_ints(self) -> None:
self.assertEqual(bencode([1, 2, 3]), b"li1ei2ei3ee")

def test_bencode_list_of_lists(self) -> None:
self.assertEqual(bencode([[1, 2], [3, 4]]), b"lli1ei2eeli3ei4eee")

def test_bencode_dict_sorts_keys(self) -> None:
d = {}
d[b"b"] = 1
d[b"a"] = 2
# It's sorted by insertion order (guaranteed Python 3.6+)
self.assertEqual([*d], [b"b", b"a"])
# It's sorted lexicographically
self.assertEqual(bencode(d), b"d1:ai2e1:bi1ee")


class ObjectSerializeTests(unittest.TestCase):
def test_serialize_int(self) -> None:
obj = Int(123)
self.assertEqual(obj.serialize(), {b"type": b"Int", b"value": 123})

def test_serialize_negative_int(self) -> None:
obj = Int(-123)
self.assertEqual(obj.serialize(), {b"type": b"Int", b"value": -123})

def test_serialize_str(self) -> None:
obj = String("abc")
self.assertEqual(obj.serialize(), {b"type": b"String", b"value": b"abc"})

def test_serialize_bytes(self) -> None:
obj = Bytes(b"abc")
self.assertEqual(obj.serialize(), {b"type": b"Bytes", b"value": b"abc"})

def test_serialize_var(self) -> None:
obj = Var("abc")
self.assertEqual(obj.serialize(), {b"type": b"Var", b"name": b"abc"})

def test_serialize_bool(self) -> None:
obj = Bool(True)
self.assertEqual(obj.serialize(), {b"type": b"Bool", b"value": True})

def test_serialize_binary_add(self) -> None:
obj = Binop(BinopKind.ADD, Int(123), Int(456))
self.assertEqual(
obj.serialize(),
{
b"left": {b"type": b"Int", b"value": 123},
b"op": b"ADD",
b"right": {b"type": b"Int", b"value": 456},
b"type": b"Binop",
},
)

def test_serialize_list(self) -> None:
obj = List([Int(1), Int(2)])
self.assertEqual(
obj.serialize(),
{b"type": b"List", b"items": [{b"type": b"Int", b"value": 1}, {b"type": b"Int", b"value": 2}]},
)

def test_serialize_assign(self) -> None:
obj = Assign(Var("x"), Int(2))
self.assertEqual(
obj.serialize(),
{b"type": b"Assign", b"name": {b"name": b"x", b"type": b"Var"}, b"value": {b"type": b"Int", b"value": 2}},
)

def test_serialize_record(self) -> None:
obj = Record({"x": Int(1)})
self.assertEqual(obj.serialize(), {b"data": {b"x": {b"type": b"Int", b"value": 1}}, b"type": b"Record"})


class SerializeTests(unittest.TestCase):
def test_serialize_int(self) -> None:
obj = Int(3)
self.assertEqual(serialize(obj), b"d4:type3:Int5:valuei3ee")

def test_serialize_str(self) -> None:
obj = String("abc")
self.assertEqual(serialize(obj), b"d4:type6:String5:value3:abce")

def test_serialize_bytes(self) -> None:
obj = Bytes(b"abc")
self.assertEqual(serialize(obj), b"d4:type5:Bytes5:value3:abce")

def test_serialize_var(self) -> None:
obj = Var("abc")
self.assertEqual(serialize(obj), b"d4:name3:abc4:type3:Vare")

def test_serialize_bool(self) -> None:
obj = Bool(True)
self.assertEqual(serialize(obj), b"d4:type4:Bool5:valuei1ee")

def test_serialize_function(self) -> None:
obj = Function(Var("x"), Binop(BinopKind.ADD, Int(1), Var("x")))
self.assertEqual(
serialize(obj),
b"d3:argd4:name1:x4:type3:Vare4:bodyd4:leftd4:type3:Int5:valuei1ee2:op3:ADD5:rightd4:name1:x4:type3:Vare4:type5:Binope4:type8:Functione",
)


def eval_command(args: argparse.Namespace) -> None:
if args.debug:
Expand Down Expand Up @@ -1809,6 +2010,7 @@ def jsondecode(obj: Object) -> Object:
"$$add": NativeFunction(lambda x: NativeFunction(lambda y: Int(unpack_int(x) + unpack_int(y)))),
"$$fetch": NativeFunction(fetch),
"$$jsondecode": NativeFunction(jsondecode),
"$$serialize": NativeFunction(lambda obj: Bytes(serialize(obj))),
}


Expand Down

0 comments on commit c55f1f7

Please sign in to comment.