Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic parser for NKI kernels #4

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.lake/**
__pycache__/
4 changes: 1 addition & 3 deletions Export.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,12 @@ def struct(name, args):
"

run_meta
let h <- IO.FS.Handle.mk "python/lean_types.py" IO.FS.Mode.write
let h <- IO.FS.Handle.mk "interop/nkl/lean_types.py" IO.FS.Mode.write
h.putStr header
flip List.forM (genPython h)
[ `NKL.Const
, `NKL.BinOp
, `NKL.Expr
, `NKL.Index
, `NKL.Stmt
, `NKL.Arg
, `NKL.Fun
]
10 changes: 8 additions & 2 deletions Main.lean
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
def main : IO Unit :=
IO.println s!"Hello, NKL!"
import NKL

/--
Program entry point.

With no command-line arguments, print a greeting. Otherwise read the file named
by the first argument (any further arguments are ignored) and pass its contents
to `NKL.parse_json`.
-/
def main (args : List String) : IO Unit :=
  match args with
  | [] => IO.println s!"Hello, NKL!"
  | path :: _ => do
    let contents <- IO.FS.readFile path
    NKL.parse_json contents
5 changes: 2 additions & 3 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Paul Govereau
-/
import Lean
import NKL.NKI
import NKL.PrettyPrint

namespace NKL

Expand All @@ -17,6 +18,4 @@ def parse_json (json : String) : IO Unit := do
| .ok jsn => do
match Lean.fromJson? jsn with
| .error str => throw $ .userError str
| .ok (_:Fun) => do
IO.println "parse successsful"
return ()
| .ok (f:Fun) => print_nki f
42 changes: 21 additions & 21 deletions NKL/NKI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ Authors: Paul Govereau
import Lean

/-!
# Concrete Syntax of NKI kernels
# Syntax of NKI kernels

Representation of the "concrete" syntax of NKI kernels
Representation for the abstract syntax of NKI kernels
generated by the python frontend.
-/

Expand All @@ -22,43 +22,43 @@ inductive Const where
| string (value: String)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

inductive BinOp where
| And | Or
| Eq | NotEq | Lt | LtE | Gt | GtE
| Add | Sub | Mul | Div
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

mutual
inductive Expr where
| value (c: Const)
| bvar (name: String)
| var (name: String)
| subscript (tensor: String) (ix: Array Index)
| binop (op: BinOp) (left right: Expr)
| call (f: String) (args: Array Expr)
| var (name value: String)
| subscript (tensor: Expr) (ix: List Index)
| binop (op: String) (left right: Expr)
| cond (e thn els: Expr)
| tuple (xs: List Expr)
| list (xs: List Expr)
| call (f: Expr) (args: List Expr)
| gridcall (f: Expr) (ix: List Index) (args: List Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

inductive Index where
| coord (i : Expr)
| slice (l u step: Expr)
| dots
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
end

inductive Stmt where
| ret(e: Expr)
| assign (x: String) (e: Expr)
| ret (e: Expr)
| assign (x: Expr) (e: Expr)
| ifstm (e : Expr) (thn els: List Stmt)
| forloop (x: String) (iter: Expr) (body: List Stmt)
| gridcall (f: String) (ix: Array Index) (args: Array Expr)
| check (e : Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

structure Arg where
name : String
type : Option String := .none
value : Option Const := .none
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
--structure Arg where
-- name : String
-- type : Option String := .none
-- value : Option Const := .none
-- deriving Repr, BEq, Lean.ToJson, Lean.FromJson

structure Fun where
name : String
args : Array Arg
args : List String
body : List Stmt
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
68 changes: 68 additions & 0 deletions NKL/PrettyPrint.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/

import NKL.NKI

namespace NKL

/--
Textual rendering of a constant.

`nil` prints as `None`, a string constant is emitted verbatim (no surrounding
quotes), and every other payload is rendered with its own `toString`.
-/
instance : ToString Const where
  toString
    | .nil => "None"
    | .bool b => toString b
    | .int i => toString i
    | .float f => toString f
    | .string s => s


mutual
/-- Render every expression in `es` and join the results with `sep`. -/
private partial def exps_ (sep : String) (es : List Expr) : String :=
  sep.intercalate (es.map expr)

/-- Comma-separated rendering of a list of expressions. -/
private partial def exps (es : List Expr) : String :=
  exps_ "," es

/-- Comma-separated rendering of a list of indices. -/
private partial def ndxs (ixs : List Index) : String :=
  ",".intercalate (ixs.map ndx)

/-- Render a single expression as text. -/
private partial def expr (e : Expr) : String :=
  match e with
  | .value c => toString c
  -- both variable forms print only their name
  | .bvar name | .var name _ => name
  | .subscript tensor ix => s!"{expr tensor}[{ndxs ix}]"
  -- binary operators are printed in prefix form: op(left,right)
  | .binop op l r => s!"{op}({expr l},{expr r})"
  | .cond c thn els => s!"{expr thn} if {expr c} else {expr els}"
  | .tuple es => s!"({exps es})"
  | .list es => s!"[{exps es}]"
  | .call f es => s!"{expr f}({exps es})"
  | .gridcall f ix es => s!"{expr f}[{ndxs ix}]({exps es})"

/-- Render a single index as text. -/
private partial def ndx (i : Index) : String :=
  match i with
  | .coord e => expr e
  -- a slice prints its three components separated by colons
  | .slice lo hi step => exps_ ":" [lo, hi, step]
  | .dots => "..."
end

/-- Expressions are displayed with the pretty-printer `expr`. -/
instance : ToString Expr := ⟨expr⟩

/-- Indices are displayed with the pretty-printer `ndx`. -/
instance : ToString Index := ⟨ndx⟩

mutual
/-- Render each statement of `l` with the indentation prefix `sp`, one per line. -/
private partial def stmts (sp : String) (l : List Stmt) : String :=
  String.intercalate "\n" $ List.map (stmt sp) l

/--
Render a single statement, prefixed with the indentation string `sp`.

Nested statement lists (the branches of an `if`, the body of a `for`) are
rendered with `sp` extended by one space.
-/
private partial def stmt (sp : String) (stmt : Stmt) : String :=
  -- shadow `stmts` with a version that already carries the deeper indentation
  let stmts := stmts (sp ++ " ")
  sp ++ match stmt with
  | .ret e => s!"ret {e}"
  | .assign x e => s!"{x} = {e}"
  -- `else:` must start on its own line, at the same indentation as the `if`
  | .ifstm e thn els => s!"if ({e}):\n{stmts thn}\n{sp}else:\n{stmts els}"
  | .forloop x e b => s!"for {x} in {expr e}:\n{stmts b}"
  | .check e => "assert(" ++ expr e ++ ")"
end

/-- Statements are displayed with the pretty-printer `stmt`, starting with no indentation. -/
instance : ToString Stmt := ⟨stmt ""⟩

/--
Print a function to standard output: first a header line of the form
`name(arg1,arg2,...)`, then the body statements rendered with `stmts` using a
one-space indentation prefix.
-/
def print_nki (f : Fun) : IO Unit := do
  let params := String.intercalate "," f.args
  IO.println (f.name ++ "(" ++ params ++ ")")
  IO.println (stmts " " f.body)

15 changes: 15 additions & 0 deletions interop/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Names exported when this package is star-imported (`from <package> import *`).
__all__ = [
'getting_started',
'layout',
'index',
'mm',
'prof',
'average_pool',
'fused_mamba',
'layernorm',
'matmul',
'rmsnorm',
'sd_attention',
'tensor_addition',
'transpose2d',
]
66 changes: 66 additions & 0 deletions interop/examples/average_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

NKI implementation for average pool 2D NKI tutorial.

"""
import numpy as np
import nki
import nki.language as nl

def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
    """NKI kernel to compute a 2D avg-pool operation

    Args:
        in_tensor: an input tensor, of shape C x H x W
        out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
        pool_size: an integer representing a (square) pool-window size
    """

    # Get input/output dimensions
    sz_cin, sz_hin, sz_win = in_tensor.shape
    sz_cout, sz_hout, sz_wout = out_tensor.shape
    assert sz_cin == sz_cout

    # Set relevant sizes
    sz_p = sz_cin
    sz_pool = pool_size

    # Generate tensor h/w index patterns
    # 3D indexing according to [C, H, W]
    i_p = nl.arange(sz_p)[:, None, None]  # index along C
    i_win = nl.arange(sz_win)[None, None, :]
    i_hin = nl.arange(sz_hin)[None, :, None]

    i_wout = nl.arange(sz_wout)[None, None, :]
    i_hout = nl.arange(sz_hout)[None, :, None]

    # Generate pool index patterns (requires two extra dimensions, for the pool window)
    i_0 = nl.arange(sz_p)[:, None, None, None, None]  # index along C (p_dim)
    i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None]  # y_outer
    i_2 = nl.arange(sz_pool)[None, None, :, None, None]  # y_inner
    i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None]  # x_outer
    i_4 = nl.arange(sz_pool)[None, None, None, None, :]  # x_inner

    # Load input data from external memory to on-chip memory
    # Declare ndarray to force a 3D tensor (temporary requirement)
    in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
    in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

    # Perform the pooling operation:
    # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension.
    # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
    # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
    # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
    # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
    out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

    # Store the results back to external memory
    nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)


# Reference NumPy implementation
def np_average_pool_2D(in_tensor, pool_size):
    """Average-pool a 3D array over non-overlapping square windows.

    Args:
        in_tensor: array whose ``shape`` unpacks into three dimensions
            ``(c, h_in, w_in)``.
        pool_size: side length of the square pooling window.

    Returns:
        Array of shape ``(c, h_in // pool_size, w_in // pool_size)`` holding the
        mean of each window; NaN entries are ignored by ``np.nanmean``.

    Raises:
        ValueError: raised by ``reshape`` when ``h_in`` or ``w_in`` is not a
            multiple of ``pool_size``.
    """
    c, h_in, w_in = in_tensor.shape
    # Split each spatial axis into (outer, inner) so that the inner axes 2 and 4
    # enumerate the elements of one pooling window.
    reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
    return np.nanmean(reshaped, axis=(2, 4))
Loading