From 2aaf9922423a928017a433f3da12b3a5ca1da914 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Mon, 30 Dec 2024 14:31:01 -0500 Subject: [PATCH] chore: remove unused code --- NKL.lean | 2 +- NKL/NKI.lean | 58 ---------- NKL/PrettyPrint.lean | 60 ----------- NKL/Python.lean | 24 ++++- interop/mk.sh | 4 +- interop/nkl/__init__.py | 9 +- interop/nkl/lean.py | 5 +- interop/nkl/loader.py | 227 ---------------------------------------- interop/nkl/parser.py | 2 +- lakefile.lean | 2 +- 10 files changed, 30 insertions(+), 363 deletions(-) delete mode 100644 NKL/NKI.lean delete mode 100644 NKL/PrettyPrint.lean delete mode 100644 interop/nkl/loader.py diff --git a/NKL.lean b/NKL.lean index a8d195a..baf227c 100644 --- a/NKL.lean +++ b/NKL.lean @@ -5,5 +5,5 @@ Authors: Paul Govereau -/ import NKL.Encode import NKL.FFI -import NKL.NKI +import NKL.KLR import NKL.Python diff --git a/NKL/NKI.lean b/NKL/NKI.lean deleted file mode 100644 index 8f124b8..0000000 --- a/NKL/NKI.lean +++ /dev/null @@ -1,58 +0,0 @@ -/- -Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Paul Govereau --/ -import Lean - -/-! -# Syntax of NKI kernels - -Representation for the abstract syntax of NKI kernels -generated by the python frontend. --/ - -namespace NKL - -inductive Const where - | nil - | bool (value: Bool) - | int (value: Int) - | float (value: Float) - | string (value: String) - | dots - deriving Repr, BEq, Lean.ToJson, Lean.FromJson - -inductive Expr where - | value (c: Const) - | bvar (name: String) - | var (name value: String) - | subscript (tensor: Expr) (ix: List Expr) - | slice (l u step: Expr) - | binop (op: String) (left right: Expr) - | cond (e thn els: Expr) - | tuple (xs: List Expr) - | list (xs: List Expr) - | call (f: Expr) (args: List Expr) - | gridcall (f: Expr) (ix: List Expr) (args: List Expr) - deriving Repr, BEq, Lean.ToJson, Lean.FromJson - -inductive Stmt where - | ret (e: Expr) - | assign (x: Expr) (e: Expr) - | ifstm (e : Expr) (thn els: List Stmt) - | forloop (x: String) (iter: Expr) (body: List Stmt) - | check (e : Expr) - deriving Repr, BEq, Lean.ToJson, Lean.FromJson - ---structure Arg where --- name : String --- type : Option String := .none --- value : Option Const := .none --- deriving Repr, BEq, Lean.ToJson, Lean.FromJson - -structure Fun where - name : String - args : List String - body : List Stmt - deriving Repr, BEq, Lean.ToJson, Lean.FromJson diff --git a/NKL/PrettyPrint.lean b/NKL/PrettyPrint.lean deleted file mode 100644 index 92ccf7d..0000000 --- a/NKL/PrettyPrint.lean +++ /dev/null @@ -1,60 +0,0 @@ -/- -Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Paul Govereau --/ - -import NKL.NKI - -namespace NKL - -instance : ToString Const where - toString - | .nil => "None" - | .bool b => toString b - | .int i => toString i - | .float f => toString f - | .string s => s - | .dots => "..." - -mutual -private partial def exps_ s l := String.intercalate s (List.map expr l) -private partial def exps := exps_ "," - -private partial def expr : Expr -> String - | .value c => toString c - | .bvar s | .var s _ => s - | .subscript e ix => expr e ++ "[" ++ exps ix ++ "]" - | .slice l u s => exps_ ":" [l,u,s] - | .binop op l r => op ++ "(" ++ expr l ++ "," ++ expr r ++ ")" - | .cond e thn els => expr thn ++ " if " ++ expr e ++ " else " ++ expr els - | .tuple es => "(" ++ exps es ++ ")" - | .list es => "[" ++ exps es ++ "]" - | .call f es => expr f ++ "(" ++ exps es ++ ")" - | .gridcall f ix es => expr f ++ "[" ++ exps ix ++ "](" ++ exps es ++ ")" -end - -instance : ToString Expr where - toString := expr - -mutual -private partial def stmts sp l := - String.intercalate "\n" $ List.map (stmt sp) l - -private partial def stmt (sp : String) (stmt : Stmt) : String := - let stmts := stmts (sp ++ " ") - sp ++ match stmt with - | .ret e => s!"ret {e}" - | .assign x e => s!"{x} = {e}" - | .ifstm e thn els => s!"if ({e}):\n{stmts thn}¬{sp}else:\n{stmts els}" - | .forloop x e b => s!"for {x} in {expr e}:\n{stmts b}" - | .check e => "assert(" ++ expr e ++ ")" -end - -instance : ToString Stmt where - toString := stmt "" - -def print_nki (f : Fun) : IO Unit := do - IO.println $ f.name ++"("++ String.intercalate "," f.args ++")" - IO.println $ stmts " " f.body - diff --git a/NKL/Python.lean b/NKL/Python.lean index 9ee5ff7..ae380c9 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -31,8 +31,19 @@ inductive Const where | ellipsis deriving Repr --- This context comes from the Python AST. --- The store hint is used by the tracing implementation for simplicity. +/- +This context comes from the Python AST. The different hints +indicated how an l-value term is being used. For example: + + x = 1 # x is store context + return x + 5 # x is load context + del x # x is del context + +The store hint is used by the tracing implementation for +simplicity: we do not try to resolve names that are being +"stored" to. +-/ + inductive Ctx where | load | store | del deriving Repr @@ -140,6 +151,15 @@ A kernel is collection of: are in the field `args` and the keyword argument are in the field `kwargs` - global variables referenced by any of the functions + +An example of a global is: + + use_fancy_thing = true # this will end up in globals + def kernel(): + if use_fancy_thing: + ... + else: + ... -/ structure Kernel where entry : String diff --git a/interop/mk.sh b/interop/mk.sh index d49c02f..f95b9ac 100644 --- a/interop/mk.sh +++ b/interop/mk.sh @@ -6,8 +6,8 @@ set -x # Need to decide which of lake or setuptools is better to use, # and how we will distribute everything -# make sure libNKL.a and lean_types.py are generated -(cd ..; lake build NKL Export) +# make sure libNKL.a is generated +(cd ..; lake build NKL) LEAN_PREFIX=$(lean --print-prefix) LEAN_CFLAGS="-I${LEAN_PREFIX}/include" diff --git a/interop/nkl/__init__.py b/interop/nkl/__init__.py index aeea995..9b93aaf 100644 --- a/interop/nkl/__init__.py +++ b/interop/nkl/__init__.py @@ -3,11 +3,4 @@ # Authors: Paul Govereau from .lean import load, to_json -from .loader import Loader - -def parse(f): - F = Loader(f) - return F.translate(F.ast) - -def parse_and_load(f): - load(parse(f)) +from .parser import Parser diff --git a/interop/nkl/lean.py b/interop/nkl/lean.py index d9be1fe..4afecaa 100644 --- a/interop/nkl/lean.py +++ b/interop/nkl/lean.py @@ -3,7 +3,6 @@ # Authors: Paul Govereau import json -from nkl.lean_types import * from nkl.lean_rffi import * def to_json_dict(obj): @@ -20,8 +19,8 @@ def to_json_dict(obj): return d return obj -def to_json(f: Fun): +def to_json(f): return json.dumps(to_json_dict(f)) -def load(f: Fun): +def load(f): py_to_lean(to_json(f)) diff --git a/interop/nkl/loader.py b/interop/nkl/loader.py deleted file mode 100644 index 5a4557f..0000000 --- a/interop/nkl/loader.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# Released under Apache 2.0 license as described in the file LICENSE. -# Authors: Paul Govereau - -import types -import typing -import inspect -import ast -import json - -from textwrap import dedent -from operator import add -from functools import reduce - -from nkl.lean import * - -def flat_map(f, l): return reduce(add, map(f, l), []) - -class NKLError(Exception): pass -class NKLInvalidConst(NKLError): pass -class NKLUnknownModule(NKLError): pass -class NKLUnknownSig(NKLError): pass -class NKLInvalidArgument(NKLError): pass -class NKLUnsupportred(NKLError): pass - -def opr(e: ast.AST): - assert e._fields == (), f"not an op? {op}" - return e.__class__.__name__ - -def const(c): - if c is None: return Nil() - elif c is Ellipsis: return Dots() - elif isinstance(c, bool): return Bool(c) - elif isinstance(c, int): return Int(c) - elif isinstance(c, float): return Float(c) - elif isinstance(c, str): return String(c) - else: raise NKLInvalidConst(c) - -def value(c): - return Value(const(c)) - -def arguments(ax): - def arg(x): - if isinstance(x, Expr): - return [x] - if isinstance(x, tuple): - return arguments(x) - if isinstance(x, dict): - if len(x) == 0: return [] - else: raise NKLInvalidArgument(x) - return [value(x)] - return flat_map(arg, ax) - -def apply_signature(f, args, kwargs): - if not isinstance(f, typing.Callable): - raise NKLUnknownSig(name) - sig = inspect.signature(f) - b = sig.bind(*args, **kwargs) - b.apply_defaults() - return arguments(b.arguments.values()) - -def apply(f, args, kwargs): - match f: - case Var(name, value): - return Call(f, apply_signature(value, args, kwargs)) - case Subscript(e, ix): - match apply(e, args, kwargs): - case Call(f, ax): - return Gridcall(f, ix, ax) - case _: assert 0, "internal error" - case _: - if len(kwargs) > 0: - raise NKLUnknownSig(ast.unparse(f)) - return Call(f, args) - -def check_module(s): - if s not in [ - 'math', 'numpy', - 'nki', 'nki.language', 'nki.isa' - ]: - raise NKLUnknownModule(s) - -class Loader: - def __init__(self, f: types.FunctionType): - super().__init__() - self.f = f - self.ast = ast.parse(dedent(inspect.getsource(f))) - - def translate(self, tree: ast.mod) -> Fun: - match tree: - case ast.Module([ast.FunctionDef(name, argsx, body, d, r, t)]): - args = [ a.arg for a in argsx.posonlyargs + argsx.args + argsx.kwonlyargs ] - return Fun(name, args, self.stmts(body)) - case _: - assert 0, "expecting function definition" - - def expr(self, e: ast.expr): - if e is None: - return Value(Nil()) - - def sorry(): assert 0, f"unsupporred expr {ast.dump(e)}" - def oper(op,x,y): return Binop(opr(op), self.expr(x), self.expr(y)) - def land(x,y): return Binop("And", x, y) - def compare(ops, l, rs): - match ops, rs: - case [op], [r]: - return oper(op, l, r) - case [op, *ops], [r, *rs]: - return land(oper(op, l, r), compare(ops, r, rs)) - case _: - assert 0, "invalid compare node" - - match e: - # constants - case ast.Constant(c): - return value(c) - - # variables - case ast.Name(name): - if name == "_": - return Bvar(name) - - if name in self.f.__code__.co_varnames: - return Bvar(name) - - if name not in self.f.__code__.co_names: - raise NameError(name) - - val = self.f.__globals__.get(name) or self.f.__builtins__.get(name) - if isinstance(val, types.ModuleType): - name = val.__name__ - check_module(name) - return Var(name, val) - - case ast.Attribute(n, a): - match self.expr(n): - case Bvar(n): - return Bvar(n + "." + a) - case Var(n, val): - if not hasattr(val, a): - raise AttributeError(f"{n} has no attribute {a}", - obj=val, name=a) - return Var(n + "." + a, getattr(val, a)) - case _: - raise NKLUnsupported(e) - - # subscript - case ast.Subscript(l, ast.Tuple(ix)): - return Subscript(self.expr(l), list(map(self.expr, ix))) - case ast.Subscript(l, ix): - return Subscript(self.expr(l), [self.expr(ix)]) - # only appears under subscript - case ast.Slice(l,u,s): - return Slice(self.expr(l), self.expr(u), self.expr(s)) - - # literals - case ast.Tuple(es): - return Tuple(self.exprs(es)) - case ast.List(es): - return List(self.exprs(es)) - - # binary operations - case ast.BoolOp(op, values): - op = opr(op) - values = map(self.expr, values) - return reduce(lambda x, y: Binop(op, x, y), values) - case ast.BinOp(l, op, r): - return oper(op, l, r) - case ast.UnaryOp(ast.USub(), val): - return oper(ast.Sub(), ast.Constant(0), val) - case ast.Compare(l, ops, rs): - return compare(ops, l, rs) - - # function calls - case ast.Call(f, args, kwargs): - args = self.exprs(args) - kwargs = {a.arg:self.expr(a.value) for a in kwargs} - return apply(self.expr(f), args, kwargs) - - # conditional expressions - case ast.IfExp(tst, tru, els): - return Cond(self.expr(tst), self.expr(tru), self.expr(els)) - - case e: - raise NKLUnsupported(e) - - def exprs(self, es): return list(map(self.expr, es)) - - # l-values - def lval(self, e: ast.expr): - return self.expr(e) - - def stmt(self, s: ast.stmt) -> [Stmt]: - match s: - case ast.Return(e): - return [Ret(self.expr(e))] - - # assignments - case ast.Assign(l, r): - x = self.lval(l[0]) - return [Assign(x, self.expr(r))] + list(map(lambda y: Assign(self.lval(y), x), l[1:])) - case ast.AugAssign(l, op, r): - return self.stmt(ast.Assign([l], ast.BinOp(l, op, r))) - case ast.AnnAssign(l, _, r): - return self.stmt(ast.Assign([l], r)) - case ast.Expr(ast.Constant()): - return [] - case ast.Expr(e): - return self.stmt(ast.Assign([ast.Name("_")], e)) - - # note: because we do not support break, we can handle orelse - case ast.For(t, i, body, orelse): - return [ Forloop(self.expr(t), self.expr(i), self.stmts(body)) ] + self.stmts(orelse) - - # if statements - case ast.If(c, t, e): - return [ Ifstm(self.expr(c), self.stmts(t), self.stmts(e)) ] - - # static assertions - case ast.Assert(e): - return [ Check(self.expr(e)) ] - - case s: - raise NKLUnsupported(s) - - def stmts(self, ss: [ast.stmt]) -> [Stmt]: - return flat_map(self.stmt, ss) diff --git a/interop/nkl/parser.py b/interop/nkl/parser.py index baaa81e..16f5df1 100644 --- a/interop/nkl/parser.py +++ b/interop/nkl/parser.py @@ -44,7 +44,7 @@ def default(self, obj): return "..." # Referenced names, that are not functions are placed in the -# global environment. Unlike functions, these values cannot +# global environment. Unlike functions, these values cannot be # reflected on using the ast module (the inspect module can only # fetch sources for a limited number of types). This function # provides an encoding for the global environment for a common diff --git a/lakefile.lean b/lakefile.lean index 02a5bb6..036f399 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -6,7 +6,7 @@ package "NKL" where lean_lib "NKL" where defaultFacets := #[LeanLib.staticFacet] -lean_lib "Export" where +--lean_lib "Export" where @[default_target] lean_exe "nkl" where