From a5fbcef35650542e359508efe6b9214b72554949 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Sun, 29 Dec 2024 13:12:41 -0500 Subject: [PATCH] feat: add global references and arguments Add ability to capture global references, including function arguments when parsing kernels. --- NKL/FFI.lean | 14 +++--- NKL/Python.lean | 102 ++++++++++++++++++++++++++++++------------ interop/nkl/parser.py | 93 +++++++++++++++++++++++++++++++------- 3 files changed, 158 insertions(+), 51 deletions(-) diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 591b19f..6f59d36 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -4,8 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ import Lean -import NKL.NKI -import NKL.PrettyPrint import NKL.Python namespace NKL @@ -15,15 +13,15 @@ local instance : MonadLift (Except String) IO where | .ok x => return x | .error s => throw $ .userError s -@[export parse_json_old] -def parse_json_old (json : String) : IO Unit := do - let jsn <- Lean.Json.parse json - let f:Fun <- Lean.fromJson? 
jsn - print_nki f - @[export parse_json] def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s let names := kernel.funcs.map fun x => x.fst let names := String.intercalate "," names IO.println s!"Found functions: {names}" + for x in kernel.args do + IO.println s!"arg: {repr x}" + for x in kernel.kwargs do + IO.println s!"arg: {repr x}" + for x in kernel.globals do + IO.println s!"global: {repr x}" diff --git a/NKL/Python.lean b/NKL/Python.lean index f577ce3..9ee5ff7 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -15,10 +15,8 @@ see: https://docs.python.org/3/library/ast.html namespace NKL namespace Python -deriving instance Repr for Lean.JsonNumber - structure Pos where - lineno : Nat + lineno : Nat := 0 end_lineno : Nat := 0 col_offset : Nat := 0 end_col_offset : Nat := 0 @@ -26,14 +24,15 @@ structure Pos where inductive Const where | none - | bool (value: Bool) - | num (value: Lean.JsonNumber) - | string (value: String) + | bool (value : Bool) + | int (value : Int) + | float (value : Float) + | string (value : String) | ellipsis deriving Repr --- We don't need these, but we preserve them to make copying --- the Python AST over easier. +-- This context comes from the Python AST. +-- The store hint is used by the tracing implementation for simplicity. 
inductive Ctx where | load | store | del deriving Repr @@ -44,8 +43,9 @@ inductive Expr where deriving Repr inductive Expr' where - | const (value: Const) - | name (id: String) (ctx : Ctx) + | const (value : Const) + | tensor (shape : List Expr) (dtype : String) + | name (id : String) (ctx : Ctx) | attr (value : Expr) (id : String) (ctx : Ctx) | tuple (xs: List Expr) (ctx : Ctx) | list (xs: List Expr) (ctx : Ctx) @@ -111,6 +111,13 @@ structure Args where kwarg : Option String deriving Repr +def Args.names (ax : Args) : List String := + let xs := ax.posonlyargs.append ax.args + let xs := match ax.vararg with | none => xs | some x => xs.append [x] + let xs := xs.append ax.kwonlyargs + let xs := match ax.kwarg with | none => xs | some x => xs.append [x] + xs + /- In addition to the defaults above from the AST, we also collect the values from f.__defaults__ here in the Fun structure. These @@ -124,10 +131,22 @@ structure Fun where body: List Stmt deriving Repr +/- +A kernel is a collection of: + - the name of the main kernel function: `entry` + - functions, including the primary function and any functions + called by the primary func that we are able to parse + - arguments to the primary function, the positional arguments + are in the field `args` and the keyword arguments are in the + field `kwargs` + - global variables referenced by any of the functions +-/ structure Kernel where entry : String funcs : List (String × Fun) - globals : List (String × Option String) + args : List Expr' + kwargs : List (String × Expr') + globals : List (String × Expr') ------------------------------------------------------------------------------- -- Converting Python AST from Json @@ -191,29 +210,30 @@ private def withPos (p : String -> Json -> Parser b) (f : b -> Pos -> a) : Json return (f exp pos) | _ => throw "expecting object" +def genError (source err : String) (pos : Pos) : String := + let lines := source.splitOn "\n" + let lineno := pos.lineno - 1 + let colno := pos.col_offset + let 
line := if lines.length <= lineno /- lineno must be a valid index: lines[lineno]! panics when lineno = lines.length -/ + then "" + else lines[lineno]! + let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString + s!"line {lineno}:\n{line}\n{indent}^-- {err}" + private def withSrc (source : String) (p : Parser a) : Parser a := try set { lineno := 0 : Pos } ; p - catch e => get >>= throw ∘ genError e -where - genError (err : String) (pos : Pos) : String := - let lines := source.splitOn "\n" - let lineno := pos.lineno - 1 - let colno := pos.col_offset - let line := if lines.length < lineno - then "" - else lines[lineno]! - let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString - s!"line {lineno}:\n{line}\n{indent}^-- {err}" + catch e => get >>= throw ∘ genError source e ------------------------------------------------------------------------------- -- Python AST Json objects def const : Json -> Parser Const | .null => return .none - | .bool b => return (.bool b) - | .num jn => return (.num jn) + | .bool b => return .bool b + | .num { mantissa := m, exponent := 0 } => return .int m + | .num jn => return .float jn.toFloat | .str "..." => return .ellipsis - | .str s => return (.string s) + | .str s => return .string s | _ => throw "expecting constant" def exprCtx : Json -> Parser Ctx @@ -302,14 +322,40 @@ def function (j : Json) : Parser Fun := do let body <- field (list stmt) j "body" return Fun.mk source args defaults body +-- Both global references and arguments are processed in the global +-- environment. These terms do not have a position, and must be +-- evaluable in the default environment. 
+partial def global : Json -> Parser Expr' + | .null => return .const .none + | .obj (.node _ _ "fun" (.str s) _) => return .name s .load + | .obj (.node _ _ "mod" (.str s) _) => return .name s .load + | .obj (.node _ _ "bool" (.bool b) _) => return .const (.bool b) + | .obj (.node _ _ "float" (.num n) _) => return .const (.float n.toFloat) + | .obj (.node _ _ "int" (.num ⟨m,0⟩) _) => return .const (.int m) + | .obj (.node _ _ "str" (.str s) _) => return .const (.string s) + | .obj (.node _ _ "tuple" (.arr arr) _) => return .tuple (<- globals arr) .load + | .obj (.node _ _ "list" (.arr arr) _) => return .list (<- globals arr) .load + | .obj (.node _ _ "tensor" kvs _) => do + let dtype <- field global kvs "dtype" + let shape <- field global kvs "shape" + match dtype, shape with + | .const (.string s), .tuple l _ => return .tensor l s + | _, _ => throw "malformed tensor type" + | j => throw s!"malformed global environment '{j}'" +where + globals (arr : Array Json) : Parser (List Expr) := + arr.toList.mapM fun x => return .exprPos (<- global x) {} + def kernel (j : Json) : Parser Kernel := do let name <- field str j "entry" let funcs <- field (dict function) j "funcs" - let globals <- field (dict (opt str)) j "globals" - return Kernel.mk name funcs globals + let args <- field (list global) j "args" + let kwargs <- field (dict global) j "kwargs" + let globals <- field (dict global) j "globals" + return Kernel.mk name funcs args kwargs globals def parse (s : String) : Except String Kernel := do let jsn <- Json.parse s - match kernel jsn { lineno := 0 } with + match kernel jsn {} with | .ok x _ => .ok x | .error s _ => .error s diff --git a/interop/nkl/parser.py b/interop/nkl/parser.py index 230136f..baaa81e 100644 --- a/interop/nkl/parser.py +++ b/interop/nkl/parser.py @@ -6,6 +6,7 @@ import inspect import ast import json +import numpy as np from textwrap import dedent from itertools import chain @@ -42,19 +43,52 @@ def default(self, obj): except Exception: return "..." 
+# Referenced names that are not functions are placed in the +# global environment. Unlike functions, these values cannot be +# reflected on using the ast module (the inspect module can only +# fetch sources for a limited number of types). This function +# provides an encoding for the global environment for a common +# set of types. + +class Unsupported(Exception): pass + +def encode_for_env(val): + match val: + case bool(b): return {'bool':b} + case int(i): return {'int':i} + case float(f): return {'float':f} + case str(s): return {'str':s} + case types.NoneType(): return None + case tuple(t): return {'tuple': list(map(encode_for_env, t))} + case list(l): return {'list': list(map(encode_for_env, l))} + case types.ModuleType(): return {'mod':val.__name__} + case np.ndarray(): + return { + 'tensor': { + 'dtype': encode_for_env(str(val.dtype)), + 'shape': encode_for_env(val.shape) + } + } + case _: + raise Unsupported(f"global value type: {val.__class__.__name__}") + class Parser(ast.NodeVisitor): def __init__(self, f: types.FunctionType): super().__init__() self.workq = deque() self.funcs = {} self.globals = {} + self.args = [] + self.kwargs = {} self.entry = f.__module__ + "." 
+ f.__name__ - self.reference(self.entry, f) + self.ref_global(self.entry, f) self.do_work() def json(self): d = { 'entry': self.entry , 'funcs': self.funcs + , 'args' : self.args + , 'kwargs' : self.kwargs , 'globals': self.globals } return json.dumps(d, cls=Enc) @@ -63,31 +97,56 @@ def json(self): def load(self): py_to_lean(self.json()) + def apply_args(self, *args, **kwargs): + self.args = [] + self.kwargs = {} + d = {} + for arg in args: + self.reference(d, '_', arg) + try: self.args.append(d.popitem()[1]) + except Exception: + raise Exception("Unsupported argument type") + for k,v in kwargs.items(): + self.ref_arg(k, v) + + def __call__(self, *args, **kwargs): + self.apply_args(*args, **kwargs) + py_to_lean(self.json()) + + def ref_arg(self, refname, val): + return self.reference(self.kwargs, refname, val) + + def ref_global(self, refname, val): + return self.reference(self.globals, refname, val) + # resolve a reference: either populating the environment, # or adding new items to the work queue - def reference(self, refname, val): + def reference(self, env, refname, val): f = None if isinstance(val, types.FunctionType): f = val - val = f.__module__ + "." + f.__name__ - elif isinstance(val, types.ModuleType): - val = val.__name__ + fname = f.__module__ + "." 
+ f.__name__ + val = {'fun': fname} + else: + try: val = encode_for_env(val) + except Exception: + return - if refname in self.globals: - if val != self.globals[refname]: + if refname in env: + if val != env[refname]: assert 0, "global mismatch" else: - self.globals[refname] = val + env[refname] = val if f is None: return try: match ast.parse(dedent(inspect.getsource(f))): case ast.Module([ast.FunctionDef(_, args, body)]): - self.workq.append((val, f, args, body)) + self.workq.append((fname, f, args, body)) case _: assert 0, "expecting function definition" - except Exception as e: + except Exception: pass def do_work(self): @@ -121,9 +180,11 @@ def visit_Name(self, node): return try: y = self.lookup(node.id) - self.reference(node.id, y) + self.ref_global(node.id, y) return node.id, y - except Exception as e: + except Unsupported as e: + raise e + except Exception: return def visit_Attribute(self, node): @@ -134,9 +195,11 @@ def visit_Attribute(self, node): n, x = self.visit(node.value) n = n + "." + node.attr y = getattr(x, node.attr) - self.reference(n, y) + self.ref_global(n, y) return n, y - except Exception as e: + except Unsupported as e: + raise e + except Exception: return def fun_defaults(self, f: types.FunctionType): @@ -152,6 +215,6 @@ def is_ok(x): if isinstance(x, types.FunctionType): # TODO: this could be incorrect if default # is using an alternate name for the function - self.reference(x.__name__, x) + self.ref_global(x.__name__, x) return False return { n:v for (n,v) in tbl.items() if is_ok(v) }