Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add global references and arguments #10

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.NKI
import NKL.PrettyPrint
import NKL.Python

namespace NKL
Expand All @@ -15,15 +13,15 @@ local instance : MonadLift (Except String) IO where
| .ok x => return x
| .error s => throw $ .userError s

@[export parse_json_old]
def parse_json_old (json : String) : IO Unit := do
let jsn <- Lean.Json.parse json
let f:Fun <- Lean.fromJson? jsn
print_nki f

@[export parse_json]
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let names := kernel.funcs.map fun x => x.fst
let names := String.intercalate "," names
IO.println s!"Found functions: {names}"
for x in kernel.args do
IO.println s!"arg: {repr x}"
for x in kernel.kwargs do
IO.println s!"arg: {repr x}"
for x in kernel.globals do
IO.println s!"global: {repr x}"
102 changes: 74 additions & 28 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@ see: https://docs.python.org/3/library/ast.html
namespace NKL
namespace Python

deriving instance Repr for Lean.JsonNumber

structure Pos where
lineno : Nat
lineno : Nat := 0
end_lineno : Nat := 0
col_offset : Nat := 0
end_col_offset : Nat := 0
deriving Repr

inductive Const where
| none
| bool (value: Bool)
| num (value: Lean.JsonNumber)
| string (value: String)
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
| ellipsis
deriving Repr

-- We don't need these, but we preserve them to make copying
-- the Python AST over easier.
-- This context comes from the Python AST.
-- The store hint is used by the tracing implementation for simplicity.
govereau marked this conversation as resolved.
Show resolved Hide resolved
inductive Ctx where
| load | store | del
deriving Repr
Expand All @@ -44,8 +43,9 @@ inductive Expr where
deriving Repr

inductive Expr' where
| const (value: Const)
| name (id: String) (ctx : Ctx)
| const (value : Const)
| tensor (shape : List Expr) (dtype : String)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| name (id : String) (ctx : Ctx)
| attr (value : Expr) (id : String) (ctx : Ctx)
| tuple (xs: List Expr) (ctx : Ctx)
| list (xs: List Expr) (ctx : Ctx)
Expand Down Expand Up @@ -111,6 +111,13 @@ structure Args where
kwarg : Option String
deriving Repr

def Args.names (ax : Args) : List String :=
let xs := ax.posonlyargs.append ax.args
let xs := match ax.vararg with | none => xs | some x => xs.append [x]
let xs := xs.append ax.kwonlyargs
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

/-
In addition to the defaults above from the AST, we also collect
the values from f.__defaults__ here in the Fun structure. These
Expand All @@ -124,10 +131,22 @@ structure Fun where
body: List Stmt
deriving Repr

/-
A kernel is collection of:
- the name of the main kernel function: `entry`
- functions, including the primary function and any functions
called by the primary func that we are able to parse
- arguments to the primary function, the positional arguments
are in the field `args` and the keyword argument are in the
field `kwargs`
- global variables referenced by any of the functions
-/
structure Kernel where
entry : String
funcs : List (String × Fun)
globals : List (String × Option String)
args : List Expr'
kwargs : List (String × Expr')
globals : List (String × Expr')
govereau marked this conversation as resolved.
Show resolved Hide resolved

-------------------------------------------------------------------------------
-- Converting Python AST from Json
Expand Down Expand Up @@ -191,29 +210,30 @@ private def withPos (p : String -> Json -> Parser b) (f : b -> Pos -> a) : Json
return (f exp pos)
| _ => throw "expecting object"

def genError (source err : String) (pos : Pos) : String :=
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an example of what you're parsing here would be nice.

let lines := source.splitOn "\n"
let lineno := pos.lineno - 1
let colno := pos.col_offset
let line := if lines.length < lineno
then "<source not available>"
else lines[lineno]!
let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString
s!"line {lineno}:\n{line}\n{indent}^-- {err}"

private def withSrc (source : String) (p : Parser a) : Parser a :=
try set { lineno := 0 : Pos } ; p
catch e => get >>= throw ∘ genError e
where
genError (err : String) (pos : Pos) : String :=
let lines := source.splitOn "\n"
let lineno := pos.lineno - 1
let colno := pos.col_offset
let line := if lines.length < lineno
then "<source not available>"
else lines[lineno]!
let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString
s!"line {lineno}:\n{line}\n{indent}^-- {err}"
catch e => get >>= throw ∘ genError source e

-------------------------------------------------------------------------------
-- Python AST Json objects

def const : Json -> Parser Const
| .null => return .none
| .bool b => return (.bool b)
| .num jn => return (.num jn)
| .bool b => return .bool b
| .num { mantissa := m, exponent := 0 } => return .int m
| .num jn => return .float jn.toFloat
| .str "..." => return .ellipsis
| .str s => return (.string s)
| .str s => return .string s
| _ => throw "expecting constant"

def exprCtx : Json -> Parser Ctx
Expand Down Expand Up @@ -302,14 +322,40 @@ def function (j : Json) : Parser Fun := do
let body <- field (list stmt) j "body"
return Fun.mk source args defaults body

-- Both global references and arguments are processed in the global
-- environment. These terms do not have a position, and must be
-- evaluable in the default environment.
partial def global : Json -> Parser Expr'
| .null => return .const .none
| .obj (.node _ _ "fun" (.str s) _) => return .name s .load
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand what .load does.

| .obj (.node _ _ "mod" (.str s) _) => return .name s .load
| .obj (.node _ _ "bool" (.bool b) _) => return .const (.bool b)
| .obj (.node _ _ "float" (.num n) _) => return .const (.float n.toFloat)
| .obj (.node _ _ "int" (.num ⟨m,0⟩) _) => return .const (.int m)
| .obj (.node _ _ "str" (.str s) _) => return .const (.string s)
| .obj (.node _ _ "tuple" (.arr arr) _) => return .tuple (<- globals arr) .load
| .obj (.node _ _ "list" (.arr arr) _) => return .list (<- globals arr) .load
| .obj (.node _ _ "tensor" kvs _) => do
let dtype <- field global kvs "dtype"
let shape <- field global kvs "shape"
match dtype, shape with
| .const (.string s), .tuple l _ => return .tensor l s
| _, _ => throw "malformed tensor type"
| j => throw s!"malformed global environment '{j}'"
where
globals (arr : Array Json) : Parser (List Expr) :=
arr.toList.mapM fun x => return .exprPos (<- global x) {}

def kernel (j : Json) : Parser Kernel := do
let name <- field str j "entry"
let funcs <- field (dict function) j "funcs"
let globals <- field (dict (opt str)) j "globals"
return Kernel.mk name funcs globals
let args <- field (list global) j "args"
let kwargs <- field (dict global) j "kwargs"
let globals <- field (dict global) j "globals"
return Kernel.mk name funcs args kwargs globals

def parse (s : String) : Except String Kernel := do
let jsn <- Json.parse s
match kernel jsn { lineno := 0 } with
match kernel jsn {} with
| .ok x _ => .ok x
| .error s _ => .error s
93 changes: 78 additions & 15 deletions interop/nkl/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import ast
import json
import numpy as np

from textwrap import dedent
from itertools import chain
Expand Down Expand Up @@ -42,19 +43,52 @@ def default(self, obj):
except Exception:
return "..."

# Referenced names, that are not functions are placed in the
# global environment. Unlike functions, these values cannot
# reflected on using the ast module (the inspect module can only
govereau marked this conversation as resolved.
Show resolved Hide resolved
# fetch sources for a limited number of types). This function
# provides an encoding for the global environment for a common
# set of types.

class Unsupported(Exception): pass

def encode_for_env(val):
match val:
case bool(b): return {'bool':b}
case int(i): return {'int':i}
case float(f): return {'float':f}
case str(s): return {'str':s}
case types.NoneType(): return None
case tuple(t): return {'tuple': list(map(encode_for_env, t))}
case list(l): return {'list': list(map(encode_for_env, l))}
case types.ModuleType(): return {'mod':val.__name__}
case np.ndarray():
return {
'tensor': {
'dtype': encode_for_env(str(val.dtype)),
'shape': encode_for_env(val.shape)
}
}
case _:
raise Unsupported(f"global value type: {val.__class__.__name__}")

class Parser(ast.NodeVisitor):
def __init__(self, f: types.FunctionType):
super().__init__()
self.workq = deque()
self.funcs = {}
self.globals = {}
self.args = []
self.kwargs = {}
self.entry = f.__module__ + "." + f.__name__
self.reference(self.entry, f)
self.ref_global(self.entry, f)
self.do_work()

def json(self):
d = { 'entry': self.entry
, 'funcs': self.funcs
, 'args' : self.args
, 'kwargs' : self.kwargs
, 'globals': self.globals
}
return json.dumps(d, cls=Enc)
Expand All @@ -63,31 +97,56 @@ def json(self):
def load(self):
py_to_lean(self.json())

def apply_args(self, *args, **kwargs):
self.args = []
self.kwargs = {}
d = {}
for arg in args:
self.reference(d, '_', arg)
try: self.args.append(d.popitem()[1])
except Exception:
raise Exception("Unsupported argument type")
for k,v in kwargs.items():
self.ref_arg(k, v)

def __call__(self, *args, **kwargs):
self.apply_args(*args, **kwargs)
py_to_lean(self.json())

def ref_arg(self, refname, val):
return self.reference(self.kwargs, refname, val)

def ref_global(self, refname, val):
return self.reference(self.globals, refname, val)

# resolve a reference: either populating the environment,
# or adding new items to the work queue
def reference(self, refname, val):
def reference(self, env, refname, val):
f = None
if isinstance(val, types.FunctionType):
f = val
val = f.__module__ + "." + f.__name__
elif isinstance(val, types.ModuleType):
val = val.__name__
fname = f.__module__ + "." + f.__name__
val = {'fun': fname}
else:
try: val = encode_for_env(val)
except Exception:
return

if refname in self.globals:
if val != self.globals[refname]:
if refname in env:
if val != env[refname]:
assert 0, "global mismatch"
else:
self.globals[refname] = val
env[refname] = val

if f is None:
return
try:
match ast.parse(dedent(inspect.getsource(f))):
case ast.Module([ast.FunctionDef(_, args, body)]):
self.workq.append((val, f, args, body))
self.workq.append((fname, f, args, body))
case _:
assert 0, "expecting function definition"
except Exception as e:
except Exception:
pass

def do_work(self):
Expand Down Expand Up @@ -121,9 +180,11 @@ def visit_Name(self, node):
return
try:
y = self.lookup(node.id)
self.reference(node.id, y)
self.ref_global(node.id, y)
return node.id, y
except Exception as e:
except Unsupported as e:
raise e
except Exception:
return

def visit_Attribute(self, node):
Expand All @@ -134,9 +195,11 @@ def visit_Attribute(self, node):
n, x = self.visit(node.value)
n = n + "." + node.attr
y = getattr(x, node.attr)
self.reference(n, y)
self.ref_global(n, y)
return n, y
except Exception as e:
except Unsupported as e:
raise e
except Exception:
return

def fun_defaults(self, f: types.FunctionType):
Expand All @@ -152,6 +215,6 @@ def is_ok(x):
if isinstance(x, types.FunctionType):
# TODO: this could be incorrect if default
# is using an alternate name for the function
self.reference(x.__name__, x)
self.ref_global(x.__name__, x)
return False
return { n:v for (n,v) in tbl.items() if is_ok(v) }
Loading