Skip to content

Commit

Permalink
feat: add global references and arguments
Browse files Browse the repository at this point in the history
Add ability to capture global references, including function
arguments when parsing kernels. Globals and arguments are directly
encoded as KLR terms.
  • Loading branch information
govereau committed Dec 28, 2024
1 parent 5c76ab9 commit 940909e
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 49 deletions.
1 change: 1 addition & 0 deletions NKL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ Authors: Paul Govereau
-/
import NKL.Encode
import NKL.FFI
import NKL.KLR
import NKL.NKI
import NKL.Python
12 changes: 4 additions & 8 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.NKI
import NKL.PrettyPrint
import NKL.Python

namespace NKL
Expand All @@ -15,15 +13,13 @@ local instance : MonadLift (Except String) IO where
| .ok x => return x
| .error s => throw $ .userError s

@[export parse_json_old]
def parse_json_old (json : String) : IO Unit := do
let jsn <- Lean.Json.parse json
let f:Fun <- Lean.fromJson? jsn
print_nki f

@[export parse_json]
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let names := kernel.funcs.map fun x => x.fst
let names := String.intercalate "," names
IO.println s!"Found functions: {names}"
for x in kernel.args do
IO.println s!"arg: {repr x}"
for x in kernel.globals do
IO.println s!"global: {repr x}"
79 changes: 79 additions & 0 deletions NKL/KLR.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/


/-!
# Abstract syntax of Core NKL language
This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace NKL.KLR

inductive Ty where

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr

inductive IndexExpr where
| var (name : String)
| int (i : Int)
| neg (expr : IndexExpr)
| add (left right : IndexExpr)
| mul (scalar : Int) (expr : IndexExpr)
| floor (expr : IndexExpr) (scalar : Int)
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr

inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
| range (l u step : Option IndexExpr)
deriving Repr

inductive Expr where
| var : String -> Expr
| const : Const -> Expr
| tuple : List Expr -> Expr
| list : List Expr -> Expr
| access : Expr -> List Index -> Expr
| binop (op : String) (left right : Expr)
| unop (op : String) (e : Expr)
| call (f : Expr) (args : List Expr) (keywords : List (String × Expr))
deriving Repr

inductive Stmt where
| pass
| expr (v : Expr)
| ret (v : Expr)
| assign (x : String) (e : Expr)
deriving Repr


-- Python-like rules for conversion to boolean
def Const.isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- TODO: Just place-holders for now
def Expr.toAffine : Expr -> Except String IndexExpr
| .var v => return .var v
| .const (.int i) => return .int i
| e => throw s!"toAffine unimp {repr e}"

def Expr.simplify : Expr -> Expr :=
fun x => x

82 changes: 56 additions & 26 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import Lean

/-!
Expand All @@ -15,25 +16,24 @@ see: https://docs.python.org/3/library/ast.html
namespace NKL
namespace Python

deriving instance Repr for Lean.JsonNumber

structure Pos where
lineno : Nat
lineno : Nat := 0
end_lineno : Nat := 0
col_offset : Nat := 0
end_col_offset : Nat := 0
deriving Repr

inductive Const where
| none
| bool (value: Bool)
| num (value: Lean.JsonNumber)
| string (value: String)
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
| ellipsis
deriving Repr

-- We don't need these, but we preserve them to make copying
-- the Python AST over easier.
-- This context comes from the Python AST.
-- The store hint is used by the tracing implementation for simplicity.
inductive Ctx where
| load | store | del
deriving Repr
Expand Down Expand Up @@ -111,6 +111,13 @@ structure Args where
kwarg : Option String
deriving Repr

def Args.names (ax : Args) : List String :=
let xs := ax.posonlyargs.append ax.args
let xs := match ax.vararg with | none => xs | some x => xs.append [x]
let xs := xs.append ax.kwonlyargs
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

/-
In addition to the defaults above from the AST, we also collect
the values from f.__defaults__ here in the Fun structure. These
Expand All @@ -124,10 +131,19 @@ structure Fun where
body: List Stmt
deriving Repr

/-
A kernel is collection of:
- functions, including the primary function and any functions
called by the primary func
- arguments to the primary function, the positional arguments
are under the key "*"
- global variables referenced by any of the functions
-/
structure Kernel where
entry : String
funcs : List (String × Fun)
globals : List (String × Option String)
args : List (String × KLR.Expr)
globals : List (String × KLR.Expr)

-------------------------------------------------------------------------------
-- Converting Python AST from Json
Expand Down Expand Up @@ -191,29 +207,30 @@ private def withPos (p : String -> Json -> Parser b) (f : b -> Pos -> a) : Json
return (f exp pos)
| _ => throw "expecting object"

def genError (source err : String) (pos : Pos) : String :=
let lines := source.splitOn "\n"
let lineno := pos.lineno - 1
let colno := pos.col_offset
let line := if lines.length < lineno
then "<source not available>"
else lines[lineno]!
let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString
s!"line {lineno}:\n{line}\n{indent}^-- {err}"

private def withSrc (source : String) (p : Parser a) : Parser a :=
try set { lineno := 0 : Pos } ; p
catch e => get >>= throw ∘ genError e
where
genError (err : String) (pos : Pos) : String :=
let lines := source.splitOn "\n"
let lineno := pos.lineno - 1
let colno := pos.col_offset
let line := if lines.length < lineno
then "<source not available>"
else lines[lineno]!
let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString
s!"line {lineno}:\n{line}\n{indent}^-- {err}"
catch e => get >>= throw ∘ genError source e

-------------------------------------------------------------------------------
-- Python AST Json objects

def const : Json -> Parser Const
| .null => return .none
| .bool b => return (.bool b)
| .num jn => return (.num jn)
| .bool b => return .bool b
| .num { mantissa := m, exponent := 0 } => return .int m
| .num jn => return .float jn.toFloat
| .str "..." => return .ellipsis
| .str s => return (.string s)
| .str s => return .string s
| _ => throw "expecting constant"

def exprCtx : Json -> Parser Ctx
Expand Down Expand Up @@ -302,14 +319,27 @@ def function (j : Json) : Parser Fun := do
let body <- field (list stmt) j "body"
return Fun.mk source args defaults body

partial def global : Json -> Parser KLR.Expr
| .null => return .const .none
| .obj (.node _ _ "fun" (.str s) _) => return .var s
| .obj (.node _ _ "mod" (.str s) _) => return .var s
| .obj (.node _ _ "bool" (.bool b) _) => return .const (.bool b)
| .obj (.node _ _ "float" (.num n) _) => return .const (.float n.toFloat)
| .obj (.node _ _ "int" (.num ⟨m,0⟩) _) => return .const (.int m)
| .obj (.node _ _ "str" (.str s) _) => return .const (.string s)
| .obj (.node _ _ "tuple" (.arr arr) _) => return .tuple (<- arr.toList.mapM global)
| .obj (.node _ _ "list" (.arr arr) _) => return .list (<- arr.toList.mapM global)
| j => throw s!"malformed global environment '{j}'"

def kernel (j : Json) : Parser Kernel := do
let name <- field str j "entry"
let funcs <- field (dict function) j "funcs"
let globals <- field (dict (opt str)) j "globals"
return Kernel.mk name funcs globals
let args <- field (dict global) j "args"
let globals <- field (dict global) j "globals"
return Kernel.mk name funcs args globals

def parse (s : String) : Except String Kernel := do
let jsn <- Json.parse s
match kernel jsn { lineno := 0 } with
match kernel jsn {} with
| .ok x _ => .ok x
| .error s _ => .error s
Loading

0 comments on commit 940909e

Please sign in to comment.