Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize code generation for pattern matching #203

Draft
wants to merge 7 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 245 additions & 29 deletions compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Hole,
Int,
List,
MatchCase,
MatchFunction,
Object,
Record,
Expand Down Expand Up @@ -60,9 +61,218 @@ def decl(self) -> str:
return f"struct object* {self.name}({args})"


class MatchKind:
    """Base class for a runtime check rendered as a C condition string.

    Subclasses override :meth:`compile` to produce the text of their check
    for the C expression named by ``arg``.
    """

    def compile(self, arg: str) -> str:
        """Return the C source text of this check applied to ``arg``.

        Raises:
            NotImplementedError: always in the base class. The message names
                the concrete class so that a kind which has not implemented
                ``compile`` yet can be identified from the traceback.
        """
        raise NotImplementedError(f"{type(self).__name__}.compile")


class AcceptAny(MatchKind):
    """Match kind whose condition always holds."""

    def compile(self, arg: str) -> str:
        """Return the constant condition ``true``; ``arg`` is not consulted."""
        return "true"


class IsNumber(MatchKind):
    """Match kind rendered as an ``is_num`` call."""

    def compile(self, arg: str) -> str:
        """Return ``is_num(<arg>)`` as C source text."""
        return f"is_num({arg})"


class IsHole(MatchKind):
    """Match kind rendered as an ``is_hole`` call."""

    def compile(self, arg: str) -> str:
        """Return ``is_hole(<arg>)`` as C source text."""
        return f"is_hole({arg})"


class IsString(MatchKind):
    """Match kind rendered as an ``is_string`` call."""

    def compile(self, arg: str) -> str:
        """Return ``is_string(<arg>)`` as C source text."""
        return f"is_string({arg})"


class IsVariant(MatchKind):
    """Match kind rendered as an ``is_variant`` call."""

    def compile(self, arg: str) -> str:
        """Return ``is_variant(<arg>)`` as C source text."""
        return f"is_variant({arg})"


class IsList(MatchKind):
    """Placeholder match kind for lists; it does not override ``compile``."""


class IsRecord(MatchKind):
    """Placeholder match kind for records; it does not override ``compile``."""


@dataclasses.dataclass
class NumberHasValue(MatchKind):
    """Match kind comparing the argument against a fixed integer."""

    # Integer substituted verbatim into the emitted comparison.
    value: int

    def compile(self, arg: str) -> str:
        """Return an ``is_num_equal_word`` call comparing ``arg`` to ``value``."""
        expected = self.value
        return f"is_num_equal_word({arg}, {expected})"


def coerce_int(object: Object) -> int:
    """Return ``object.value``, asserting first that ``object`` is an ``Int``."""
    assert isinstance(object, Int)
    int_value = object.value
    return int_value


@dataclasses.dataclass
class StringHasValue(MatchKind):
    """Match kind comparing the argument against a fixed string."""

    # String the argument is compared against.
    value: str

    def compile(self, arg: str) -> str:
        """Return C source text that tests ``arg`` for equality with ``value``.

        Strings shorter than 8 characters are compared directly against a
        ``mksmallstring`` value; longer ones go through
        ``string_equal_cstr_len``.
        """
        # json.dumps supplies the quoted, escaped literal placed in the C text.
        literal = json.dumps(self.value)
        # NOTE(review): this is a character count; for non-ASCII text it may
        # differ from the byte length of the C literal -- confirm.
        length = len(self.value)
        if length >= 8:
            return f"string_equal_cstr_len({arg}, {literal}, {length})"
        return f"({arg} == mksmallstring({literal}, {length}))"


def coerce_string(object: Object) -> str:
    """Return ``object.value``, asserting first that ``object`` is a ``String``."""
    assert isinstance(object, String)
    text = object.value
    return text


@dataclasses.dataclass
class VariantHasTag(MatchKind):
    """Match kind comparing a variant's tag against a named tag constant."""

    # Tag name; it is appended to ``Tag_`` in the emitted comparison.
    tag: str

    def compile(self, arg: str) -> str:
        """Return ``(variant_tag(<arg>) == Tag_<tag>)`` as C source text."""
        tag_name = self.tag
        return f"(variant_tag({arg}) == Tag_{tag_name})"


@dataclasses.dataclass(frozen=True)
class CondExpr(Object):
    """One guarded arm of a match: a condition plus the body to evaluate."""

    # Variable the arm refers to. The original author flagged this field as
    # probably unnecessary.
    arg: Var
    # Check that guards the body.
    condition: MatchKind
    # Expression evaluated when the condition holds.
    body: Object


@dataclasses.dataclass(frozen=True)
class MatchExpr(Object):
    """A sequence of guarded arms with an optional fallthrough expression."""

    # Expression being matched. The original author flagged this field as
    # possibly unnecessary.
    arg: Object
    # Guarded arms, in the order they are tried.
    cases: typing.List[CondExpr]
    # Expression used when no arm applies, or None when there is none.
    fallthrough_case: Where | None


@dataclasses.dataclass(frozen=True)
class VariantValueExpr(Object):
    """Expression node wrapping a variant whose value is to be extracted."""

    # Expression that produces the variant.
    variant: Object


def group_cases(
    cases: typing.List[MatchCase],
    keyof: typing.Callable[[MatchCase], typing.Hashable],
    is_fallthrough: typing.Callable[[MatchCase], bool],
) -> tuple[typing.List[typing.List[MatchCase]], MatchCase | None]:
    """Bucket ``cases`` by key, stopping at the first fallthrough case.

    Args:
        cases: Match cases in source order.
        keyof: Returns the hashable grouping key for a case.
        is_fallthrough: Returns True for a case that ends the scan.

    Returns:
        A pair ``(groups, fallthrough)``. ``groups`` lists the buckets in
        order of each key's first appearance, with cases inside a bucket kept
        in source order. ``fallthrough`` is the first case for which
        ``is_fallthrough`` returned True, or None if there was none. Cases
        after the fallthrough case are ignored.
    """
    groups: typing.Dict[typing.Hashable, typing.List[MatchCase]] = {}
    for case in cases:
        if is_fallthrough(case):
            # Nothing can match after the fallthrough, so stop scanning here.
            return list(groups.values()), case
        # dicts preserve insertion order, so buckets come out in the order
        # their keys were first seen.
        groups.setdefault(keyof(case), []).append(case)
    return list(groups.values()), None


def typename(case: MatchCase) -> str:
    """Return the class name of ``case.pattern``."""
    pattern_type = type(case.pattern)
    return pattern_type.__name__


def pattern_is_var(case: MatchCase) -> bool:
    """Return True when the pattern of ``case`` is a ``Var``."""
    pattern = case.pattern
    return isinstance(pattern, Var)


def let(name: Var, value: Object, body: Object) -> Where:
    """Build a ``Where`` node that wraps ``body`` with the binding ``name = value``."""
    binding = Assign(name, value)
    return Where(body, binding)


def compile_match_function(match_fn: MatchFunction) -> Function:
    """Lower a ``MatchFunction`` into a ``Function`` whose body is a ``MatchExpr``.

    The function parameter is bound to a fresh match variable, and the cases
    are grouped by pattern type with a variable pattern as the fallthrough.
    """
    # The two gensym calls stay in this order so generated names are stable.
    param = Var(gensym("fn_arg"))
    scrutinee = Var(gensym("match"))
    cond_exprs, default = compile_ungrouped_match_cases(scrutinee, match_fn.cases, typename, pattern_is_var)
    dispatch = MatchExpr(scrutinee, cond_exprs, default)
    return Function(param, let(scrutinee, param, dispatch))


def compile_ungrouped_match_cases(
    arg: Var, cases: typing.List[MatchCase], group_key: object, is_fallthrough: object
) -> tuple[typing.List[CondExpr], Where | None]:
    """Group ``cases`` and expand each group into a conditional arm.

    Returns:
        The expanded arms, one per group, together with the compiled
        fallthrough case (None when ``cases`` has no fallthrough).
    """
    grouped, fallthrough_case = group_cases(cases, group_key, is_fallthrough)
    cond_exprs = []
    for group in grouped:
        cond_exprs.append(expand_group(arg, group, fallthrough_case))
    default = compile_var_case(arg, fallthrough_case)
    return cond_exprs, default


def compile_var_case(arg: Var, case: MatchCase | None) -> Where | None:
    """Turn a variable-pattern case into a ``Where`` binding the variable to ``arg``.

    Returns None when ``case`` is falsy.
    """
    if not case:
        return None
    assert isinstance(case.pattern, Var)
    binding = Assign(case.pattern, arg)
    return Where(case.body, binding)


def compile_int_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None):
    """Build a ``MatchExpr`` with one value-equality arm per integer case in ``group``."""
    arms = []
    for case in group:
        expected = coerce_int(case.pattern)
        arms.append(CondExpr(arg, NumberHasValue(expected), case.body))
    return MatchExpr(arg, arms, compile_var_case(arg, fallthrough_case))


def compile_string_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None):
    """Build a ``MatchExpr`` with one value-equality arm per string case in ``group``."""
    arms = []
    for case in group:
        expected = coerce_string(case.pattern)
        arms.append(CondExpr(arg, StringHasValue(expected), case.body))
    return MatchExpr(arg, arms, compile_var_case(arg, fallthrough_case))


def compile_variant_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None):
    """Build a ``MatchExpr`` that dispatches on variant tag, then on payload.

    Cases in ``group`` are bucketed by tag. For each tag, the payload patterns
    are lifted into their own match cases and compiled as a nested match
    against a fresh variable bound to the variant's value.

    Args:
        arg: Variable holding the value being matched.
        group: Cases whose patterns are all ``Variant`` instances.
        fallthrough_case: Variable-pattern case used when no tag matches,
            or None.

    Returns:
        A ``MatchExpr`` with one ``VariantHasTag`` arm per distinct tag.
    """

    def case_tag(case: MatchCase):
        assert isinstance(case.pattern, Variant)
        return case.pattern.tag

    # No case acts as a fallthrough at this level; only the grouping is used.
    grouped_by_tag, _ = group_cases(group, case_tag, lambda x: False)
    cond_exprs = []
    for tag_group in grouped_by_tag:
        lifted_matches = [MatchCase(case.pattern.value, case.body) for case in tag_group]
        inner_arg = Var(gensym("variant_match"))
        # NOTE(review): the nested match only receives a fallthrough that
        # comes from the lifted payload cases; the outer `fallthrough_case`
        # is not consulted when the tag matches but the payload does not --
        # confirm this is intended.
        expanded_cases, inner_fallthrough_case = compile_ungrouped_match_cases(
            inner_arg, lifted_matches, typename, pattern_is_var
        )
        match_expr = let(inner_arg, VariantValueExpr(arg), MatchExpr(inner_arg, expanded_cases, inner_fallthrough_case))
        cond_exprs.append(CondExpr(arg, VariantHasTag(tag_group[0].pattern.tag), match_expr))

    return MatchExpr(arg, cond_exprs, compile_var_case(arg, fallthrough_case))


def expand_group(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None):
    """Expand one group of same-typed cases into a type-guarded arm.

    The first case of the group decides which type check guards the arm.

    Raises:
        Exception: if the group starts with a ``Var`` pattern.
        NotImplementedError: for pattern types without an expansion; List and
            Record patterns are among those not handled yet.
    """
    if not group:
        assert fallthrough_case
        return compile_var_case(arg, fallthrough_case)
    first_case = group[0]
    pattern = first_case.pattern
    if isinstance(pattern, Int):
        int_match = compile_int_cases(arg, group, fallthrough_case)
        return CondExpr(arg, IsNumber(), int_match)
    if isinstance(pattern, Hole):
        # Only the first hole case is kept; later ones are discarded.
        return CondExpr(arg, IsHole(), first_case.body)
    if isinstance(pattern, Var):
        raise Exception("saw a var")
    if isinstance(pattern, Variant):
        variant_match = compile_variant_cases(arg, group, fallthrough_case)
        return CondExpr(arg, IsVariant(), variant_match)
    if isinstance(pattern, String):
        string_match = compile_string_cases(arg, group, fallthrough_case)
        return CondExpr(arg, IsString(), string_match)
    raise NotImplementedError("expand_group", pattern)


# Module-wide counter backing gensym; each call consumes one number.
gensym_counter = 0


def gensym(stem: str = "tmp") -> str:
    """Return ``<stem>_<n>`` where ``n`` is the current counter value, then advance it."""
    global gensym_counter
    current = gensym_counter
    gensym_counter = current + 1
    return f"{stem}_{current}"


class Compiler:
def __init__(self, main_fn: CompiledFunction) -> None:
self.gensym_counter: int = 0
# self.gensym_counter: int = 0
self.functions: typing.List[CompiledFunction] = [main_fn]
self.function: CompiledFunction = main_fn
self.record_keys: Dict[str, int] = {}
Expand Down Expand Up @@ -105,8 +315,7 @@ def variant_tag(self, key: str) -> int:
return result

def gensym(self, stem: str = "tmp") -> str:
self.gensym_counter += 1
return f"{stem}_{self.gensym_counter-1}"
return gensym(stem)

def _emit(self, line: str) -> None:
self.function.code.append(line)
Expand Down Expand Up @@ -152,7 +361,7 @@ def compile_assign(self, env: Env, exp: Assign) -> Env:
return {**env, name: value}
if isinstance(exp.value, MatchFunction):
# Named match function
value = self.compile_match_function(env, exp.value, name)
value = self.compile_function(env, compile_match_function(exp.value), name)
return {**env, name: value}
value = self.compile(env, exp.value)
return {**env, name: value}
Expand Down Expand Up @@ -262,29 +471,6 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En
return updates
raise NotImplementedError("try_match", pattern)

def compile_match_function(self, env: Env, exp: MatchFunction, name: Optional[str]) -> str:
    """Compile a match function into its own C function and return a closure for it.

    Each case is emitted in order: its pattern test jumps to the next case's
    label on failure, otherwise the case body is compiled and returned. When
    every case fails, the emitted code prints a message and aborts.
    """
    arg = self.gensym()
    fn = self.make_compiled_function(arg, exp, name)
    self.functions.append(fn)
    # Emit into the new function, remembering the enclosing one to restore.
    enclosing = self.function
    self.function = fn
    funcenv = self.compile_function_env(fn, name)
    last_index = len(exp.cases) - 1
    for index, case in enumerate(exp.cases):
        if index < last_index:
            fallthrough = f"case_{index+1}"
        else:
            fallthrough = "no_match"
        env_updates = self.try_match(funcenv, arg, case.pattern, fallthrough)
        case_env = {**funcenv, **env_updates}
        case_result = self.compile(case_env, case.body)
        self._emit(f"return {case_result};")
        self._emit(f"{fallthrough}:;")
    self._emit(r'fprintf(stderr, "no matching cases\n");')
    self._emit("abort();")
    # Pacify the C compiler
    self._emit("return NULL;")
    self.function = enclosing
    if fn.fields:
        return self.make_closure(env, fn)
    # TODO(max): Closure over freevars but only consts
    return self._const_closure(fn)

def make_closure(self, env: Env, fn: CompiledFunction) -> str:
name = self._mktemp(f"mkclosure(heap, {fn.name}, {len(fn.fields)})")
for i, field in enumerate(fn.fields):
Expand Down Expand Up @@ -448,8 +634,38 @@ def compile(self, env: Env, exp: Object) -> str:
return self.compile_function(env, exp, name=None)
if isinstance(exp, MatchFunction):
# Anonymous match function
return self.compile_match_function(env, exp, name=None)
raise NotImplementedError(f"exp {type(exp)} {exp}")
return self.compile_function(env, compile_match_function(exp), name=None)
if isinstance(exp, MatchExpr):
return self.compile_match_expr(env, exp)
if isinstance(exp, VariantValueExpr):
value = self.compile(env, exp.variant)
return self._mktemp(f"variant_value({value});")
raise NotImplementedError(f"exp {type(exp)} {exp!r}")

def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str:
    """Emit C for ``match_expr`` and return the name of its result variable.

    Each arm is emitted as a negated condition that jumps to a per-arm label,
    followed by the arm body, an assignment to the result variable and a jump
    to the shared ``done`` label. If no arm applies, the fallthrough
    expression is used when present; otherwise the emitted code prints a
    message and aborts.
    """
    scrutinee = self.compile(env, match_expr.arg)
    result = self.gensym("result")
    done = self.gensym("done")
    self._emit(f"struct object* {result} = NULL;")
    for cond in match_expr.cases:
        check = cond.condition
        if isinstance(check, VariantHasTag):
            # Called for its side effect only; presumably this registers the
            # tag so the Tag_<name> constant exists -- confirm in variant_tag.
            self.variant_tag(check.tag)
        next_case = self.gensym("case")
        c_cond = check.compile(scrutinee)
        self._emit(f"if (!{c_cond}) goto {next_case};")
        body_result = self.compile(env, cond.body)
        for line in (f"{result} = {body_result};", f"goto {done};", f"{next_case}:;"):
            self._emit(line)
    default = match_expr.fallthrough_case
    if not default:
        self._emit(r'fprintf(stderr, "no matching cases\n");')
        self._emit("abort();")
    else:
        c_name = self.compile(env, default)
        self._emit(f"{result} = {c_name};")
        self._emit(f"goto {done};")
    self._emit(f"{done}:;")
    return result


def compile_to_string(program: Object, debug: bool) -> str:
Expand Down
4 changes: 3 additions & 1 deletion scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,9 @@ def free_in(exp: Object) -> Set[str]:
if isinstance(exp, Closure):
# TODO(max): Should this remove the set of keys in the closure env?
return free_in(exp.func)
raise NotImplementedError(("free_in", type(exp)))
# :'(
return set()
# raise NotImplementedError(("free_in", type(exp)))


def improve_closure(closure: Closure) -> Closure:
Expand Down
Loading