Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a compiler from scrapscript to JS #90

Draft
wants to merge 14 commits into
base: trunk
Choose a base branch
from
207 changes: 206 additions & 1 deletion scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import code
import dataclasses
import enum
import functools
import http.server
import json
import logging
Expand Down Expand Up @@ -1254,6 +1255,91 @@ def bencode(obj: object) -> bytes:
raise NotImplementedError(f"bencode not implemented for {type(obj)}")


class JSCompiler:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏 you know what I'm thinking

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imports are our friends

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a prototype before doing it in scrap

def compile(self, env: Env, exp: Object) -> str:
if isinstance(exp, Int):
return str(exp.value)
if isinstance(exp, Binop):
left = self.compile(env, exp.left)
right = self.compile(env, exp.right)
return f"({left})" + BinopKind.to_str(exp.op) + f"({right})"
if isinstance(exp, Var):
# assert exp.name in env
return exp.name
if isinstance(exp, Where):
binding = exp.binding
assert isinstance(binding, Assign)
return self.compile_let(env, binding.name.name, binding.value, exp.body)
if isinstance(exp, Assign):
value = self.compile(env, exp.value)
return f"const {exp.name.name} = {value};\n"
if isinstance(exp, Apply):
func = self.compile(env, exp.func)
arg = self.compile(env, exp.arg)
return f"({func})({arg})"
if isinstance(exp, Function):
arg = self.compile(env, exp.arg)
body = self.compile(env, exp.body)
return f"({arg}) => ({body})"
if isinstance(exp, List):
items = [self.compile(env, item) for item in exp.items]
return "[" + ", ".join(items) + "]"
if isinstance(exp, MatchFunction):
err = "(() => {throw 'oh no'})()"
if not exp.cases:
return err
# TODO(max): Gensym arg name or something
arg = "__x"

def per_case(acc: str, case: MatchCase) -> str:
cond, body = self.compile_match_case(env, arg, case)
return f"({cond}) ? ({body}) : ({acc})"

return f"({arg}) => " + functools.reduce(
per_case,
reversed(exp.cases),
err,
)
if isinstance(exp, Symbol):
if exp.value in ("true", "false"):
return exp.value
return repr(exp.value)
if isinstance(exp, String):
return repr(exp.value)
if isinstance(exp, Access):
obj = self.compile(env, exp.obj)
if isinstance(exp.at, Int):
return f"{obj}[{exp.at}]"
assert isinstance(exp.at, Var)
return f"{obj}.{exp.at}"
if isinstance(exp, Record):
result = "{"
for key, rec_value in exp.data.items():
result += repr(key) + ":" + self.compile(env, rec_value) + ","
return result + "}"
raise NotImplementedError(type(exp), exp)

def compile_let(self, env: Env, name: str, value: Object, body: Object) -> str:
body_str = self.compile(env, body)
value_str = self.compile(env, value)
return f"(({name}) => ({body_str}))({value_str})"

def compile_match_case(self, env: Env, arg: str, case: MatchCase) -> Tuple[str, str]:
pattern = case.pattern
body = case.body
if isinstance(pattern, Int):
return f"{arg} === {pattern.value}", self.compile(env, body)
if isinstance(pattern, Var):
return "true", self.compile_let(env, pattern.name, Var(arg), body)
raise NotImplementedError(type(pattern))


def compile_exp_js(env: Env, exp: Object) -> str:
compiler = JSCompiler()
result = compiler.compile(env, exp)
return result


class Bdecoder:
def __init__(self, msg: str) -> None:
self.msg: str = msg
Expand Down Expand Up @@ -4158,6 +4244,105 @@ def test_pretty_print_symbol(self) -> None:
self.assertEqual(str(obj), "#x")


class JSCompilerTests(unittest.TestCase):
def test_compile_int(self) -> None:
exp = Int(123)
self.assertEqual(compile_exp_js({}, exp), "123")

def test_compile_binop_add(self) -> None:
exp = Binop(BinopKind.ADD, Int(3), Int(4))
self.assertEqual(compile_exp_js({}, exp), "(3)+(4)")

def test_compile_binop_rec(self) -> None:
exp = Binop(BinopKind.MUL, Binop(BinopKind.ADD, Int(3), Int(4)), Int(5))
self.assertEqual(compile_exp_js({}, exp), "((3)+(4))*(5)")

def test_compile_where(self) -> None:
exp = Where(Var("x"), Assign(Var("x"), Int(1)))
self.assertEqual(compile_exp_js({}, exp), "((x) => (x))(1)")

def test_compile_nested_where(self) -> None:
exp = parse(tokenize("x + y . x = 1 . y = 2"))
self.assertEqual(compile_exp_js({}, exp), "((y) => (((x) => ((x)+(y)))(1)))(2)")

def test_compile_apply(self) -> None:
exp = Apply(Var("f"), Var("x"))
self.assertEqual(compile_exp_js({}, exp), "(f)(x)")

def test_compile_apply_nested(self) -> None:
exp = Apply(Apply(Var("f"), Var("x")), Var("y"))
self.assertEqual(compile_exp_js({}, exp), "((f)(x))(y)")

def test_compile_function(self) -> None:
exp = Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1)))
self.assertEqual(compile_exp_js({}, exp), "(x) => ((x)+(1))")

def test_compile_function_nested(self) -> None:
exp = parse(tokenize("x -> y -> x + y"))
self.assertEqual(compile_exp_js({}, exp), "(x) => ((y) => ((x)+(y)))")

def test_compile_list(self) -> None:
exp = List([Binop(BinopKind.ADD, Int(1), Int(2)), Binop(BinopKind.MUL, Int(3), Int(4))])
self.assertEqual(compile_exp_js({}, exp), "[(1)+(2), (3)*(4)]")

def test_compile_match_function(self) -> None:
exp = parse(tokenize("| 1 -> 2 | 2 -> 3"))
self.assertEqual(
compile_exp_js({}, exp), "(__x) => (__x === 1) ? (2) : ((__x === 2) ? (3) : ((() => {throw 'oh no'})()))"
)

def test_compile_match_function_var(self) -> None:
exp = parse(tokenize("| 1 -> 2 | x -> x"))
self.assertEqual(
compile_exp_js({}, exp),
"(__x) => (__x === 1) ? (2) : ((true) ? (((x) => (x))(__x)) : ((() => {throw 'oh no'})()))",
)

def test_compile_symbol_bool_true(self) -> None:
exp = Symbol("true")
self.assertEqual(compile_exp_js({}, exp), "true")

def test_compile_symbol_bool_false(self) -> None:
exp = Symbol("false")
self.assertEqual(compile_exp_js({}, exp), "false")

def test_compile_symbol(self) -> None:
exp = Symbol("hello")
self.assertEqual(compile_exp_js({}, exp), "'hello'")

def test_compile_string(self) -> None:
exp = String("hello")
self.assertEqual(compile_exp_js({}, exp), "'hello'")

def test_compile_string_single_quotes(self) -> None:
exp = String("'hello'")
self.assertEqual(compile_exp_js({}, exp), "\"'hello'\"")

def test_compile_string_double_quotes(self) -> None:
exp = String('"hello"')
self.assertEqual(compile_exp_js({}, exp), "'\"hello\"'")

def test_compile_access_int(self) -> None:
exp = Access(Var("x"), Int(1))
self.assertEqual(compile_exp_js({}, exp), "x[1]")

def test_compile_access_field(self) -> None:
exp = Access(Var("x"), Var("y"))
self.assertEqual(compile_exp_js({}, exp), "x.y")

def test_compile_nested_access(self) -> None:
exp = Access(Access(Var("x"), Var("y")), Var("z"))
self.assertEqual(compile_exp_js({}, exp), "x.y.z")

def test_compile_empty_record(self) -> None:
exp = Record({})
self.assertEqual(compile_exp_js({}, exp), "{}")

def test_compile_record(self) -> None:
exp = Record({"a": Int(1), "b": Int(2)})
self.assertEqual(compile_exp_js({}, exp), "{'a':1,'b':2,}")


def fetch(url: Object) -> Object:
if not isinstance(url, String):
raise TypeError(f"fetch expected String, but got {type(url).__name__}")
Expand Down Expand Up @@ -4323,6 +4508,25 @@ def runsource(self, source: str, filename: str = "<input>", symbol: str = "singl
return False


class JSRepl(ScrapRepl):
def runsource(self, source: str, filename: str = "<input>", symbol: str = "single") -> bool:
try:
tokens = tokenize(source)
logger.debug("Tokens: %s", tokens)
ast = parse(tokens)
logger.debug("AST: %s", ast)
result = compile_exp_js(self.env, ast)
print(result)
except UnexpectedEOFError:
# Need to read more text
return True
except ParseError as e:
print(f"Parse error: {e}", file=sys.stderr)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return False


def eval_command(args: argparse.Namespace) -> None:
if args.debug:
logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -4352,7 +4556,7 @@ def repl_command(args: argparse.Namespace) -> None:
if args.debug:
logging.basicConfig(level=logging.DEBUG)

repl = ScrapRepl()
repl = JSRepl() if args.js else ScrapRepl()
if readline:
repl.enable_readline()
repl.interact(banner="")
Expand Down Expand Up @@ -4390,6 +4594,7 @@ def main() -> None:
repl = subparsers.add_parser("repl")
repl.set_defaults(func=repl_command)
repl.add_argument("--debug", action="store_true")
repl.add_argument("--js", action="store_true")

test = subparsers.add_parser("test")
test.set_defaults(func=test_command)
Expand Down