Skip to content

Commit

Permalink
chore: port encoder to KLR
Browse files Browse the repository at this point in the history
The KLR language is the official intermediate format. This
patch ports the serialization and deserialization code to
this new data type.
  • Loading branch information
govereau committed Dec 30, 2024
1 parent 3b5c234 commit 3dcd559
Showing 1 changed file with 145 additions and 61 deletions.
206 changes: 145 additions & 61 deletions NKL/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.NKI
import NKL.KLR

/-!
# Serialization and Deserialization
-/

namespace NKL
namespace NKL.KLR

-- All of the encode functions are pure; decoding uses an instance of EStateM.

Expand All @@ -25,10 +25,10 @@ def decode (f : DecodeM a) (ba : ByteArray) : Except String a :=
| .error s _ => .error s

def decodeFile (f : DecodeM a) (path : System.FilePath) : IO a := do
let buf <- IO.FS.readBinFile path
match decode f buf with
| .ok x => return x
| .error s => throw $ IO.userError s
let buf <- IO.FS.readBinFile path
match decode f buf with
| .ok x => return x
| .error s => throw $ IO.userError s

private def next : DecodeM UInt8 := do
let it <- get
Expand Down Expand Up @@ -132,7 +132,6 @@ private def decString : DecodeM String := do

------------------------------------------------------------------------------
-- Lists are encoded as a length followed by a sequence of encoded values
-- TODO: not efficient

private def encList (f : a -> ByteArray) (l : List a) : ByteArray :=
let rec mapa : List a -> ByteArray
Expand All @@ -151,117 +150,202 @@ private def decList (f : DecodeM a) : DecodeM (List a) := do
#guard decode' (decList decInt) (encList encInt [1,2,3]) = some [1,2,3]

------------------------------------------------------------------------------
-- Finally, constants are encoded with a tag followed by the values
-- Options are encoded using a tag followed by the encoded value

/--
Build a tagged record: a byte array that starts with the single tag byte `t`
and is followed by each of the given byte arrays, appended in list order.
-/
private def tag (t : UInt8) : List ByteArray -> ByteArray :=
  fun parts => parts.foldl (fun acc part => acc.append part) (ByteArray.mk #[t])

/--
Encode an optional value: tag byte `0` with no payload for `none`, or tag
byte `1` followed by the payload produced by `f` for `some`.
-/
private def encOption (f : a -> ByteArray) : Option a -> ByteArray :=
  fun opt =>
    match opt with
    | .none => tag 0 []
    | .some v => tag 1 [f v]

/--
Decode an optional value in the layout produced by `encOption`: take one
tag value from `next`, then yield `none` for tag `0`, or run `f` and wrap
its result in `some` for tag `1`. Any other tag fails with an
"invalid option tag" error.
-/
private def decOption (f : DecodeM a) : DecodeM (Option a) := do
  let t <- next
  match t with
  | 0 => return none
  | 1 => return some (<- f)
  | _ => throw s!"invalid option tag {t}"

-- Round-trip checks for the option codec: an encoded `none` and an encoded
-- `some 1` must both decode back to the value that was encoded.
#guard decode' (decOption decInt) (encOption encInt none) = some none
#guard decode' (decOption decInt) (encOption encInt $ some 1) = some (some 1)

------------------------------------------------------------------------------
-- Constants are encoded with a tag followed by the values

def encConst : Const -> ByteArray
| .nil => tag 0x00 []
| .none => tag 0x00 []
| .bool false => tag 0x01 []
| .bool true => tag 0x02 []
| .int i => tag 0x03 [encInt i]
| .float f => tag 0x04 [encFloat f]
| .string s => tag 0x05 [encString s]
| .dots => tag 0x06 []

def decConst : DecodeM Const := do
let val <- next
match val with
| 0x00 => return .nil
| 0x00 => return .none
| 0x01 => return .bool false
| 0x02 => return .bool true
| 0x03 => return .int (<- decInt)
| 0x04 => return .float (<- decFloat)
| 0x05 => return .string (<- decString)
| 0x06 => return .dots
| _ => throw s!"Unknown Const tag value {val}"

/-- Round-trip check: encoding `c` and decoding the result must give back `c`. -/
private def chkConst (c: Const) : Bool :=
  let bytes := encConst c
  decode' decConst bytes == some c

#guard chkConst .nil
#guard chkConst .none
#guard chkConst (.bool true)
#guard chkConst (.bool false)
#guard chkConst (.int 1)
#guard chkConst (.float 1.0)
#guard chkConst (.string "str")
#guard chkConst .dots

------------------------------------------------------------------------------
-- Affine Expressions

/--
Serialize an index expression as a one-byte constructor tag (`0x10`..`0x17`)
followed by the encoded constructor arguments, in the order they appear in
the pattern. Nested index expressions are encoded recursively.
-/
def encIndexExpr : IndexExpr -> ByteArray
  | .var x         => tag 0x10 [encString x]
  | .int n         => tag 0x11 [encInt n]
  | .neg inner     => tag 0x12 [encIndexExpr inner]
  | .add lhs rhs   => tag 0x13 [encIndexExpr lhs, encIndexExpr rhs]
  | .mul k inner   => tag 0x14 [encInt k, encIndexExpr inner]
  | .floor inner k => tag 0x15 [encIndexExpr inner, encInt k]
  | .ceil inner k  => tag 0x16 [encIndexExpr inner, encInt k]
  | .mod inner k   => tag 0x17 [encIndexExpr inner, encInt k]

/--
Decode an index expression in the layout produced by `encIndexExpr`: take
one tag value from `next`, then decode that constructor's arguments in
order. A tag outside `0x10`..`0x17` fails with an "Unknown tag" error.
-/
partial def decIndexExpr : DecodeM IndexExpr := do
  let t <- next
  match t with
  | 0x10 => do
    let x <- decString
    return .var x
  | 0x11 => do
    let n <- decInt
    return .int n
  | 0x12 => do
    let inner <- decIndexExpr
    return .neg inner
  | 0x13 => do
    -- Operands are read left-to-right, matching the encoder's output order.
    let lhs <- decIndexExpr
    let rhs <- decIndexExpr
    return .add lhs rhs
  | 0x14 => do
    let k <- decInt
    let inner <- decIndexExpr
    return .mul k inner
  | 0x15 => do
    let inner <- decIndexExpr
    let k <- decInt
    return .floor inner k
  | 0x16 => do
    let inner <- decIndexExpr
    let k <- decInt
    return .ceil inner k
  | 0x17 => do
    let inner <- decIndexExpr
    let k <- decInt
    return .mod inner k
  | _ => throw s!"Unknown tag in IndexExpr {t}"

/-- Round-trip check: encoding `e` and decoding the result must give back `e`. -/
private def chkIE (e: IndexExpr) : Bool :=
  let bytes := encIndexExpr e
  decode' decIndexExpr bytes == some e

-- Sample variable expression reused as an operand in the round-trip checks.
private def ie_var : IndexExpr := .var "s"

-- One round-trip check for each arm of `encIndexExpr`.
#guard chkIE (.var "v")
#guard chkIE (.int 1)
#guard chkIE (.neg ie_var)
#guard chkIE (.add ie_var ie_var)
#guard chkIE (.mul 2 ie_var)
#guard chkIE (.floor ie_var 2)
#guard chkIE (.ceil ie_var 2)
#guard chkIE (.mod ie_var 2)

/--
Serialize an index as a one-byte tag (`0x20`..`0x22`) followed by its
components. Every component is an optional index expression and is written
with the option encoding (`encOption`) over `encIndexExpr`.
-/
def encIndex : Index -> ByteArray
  | .ellipsis          => tag 0x20 []
  | .coord c           => tag 0x21 [enc c]
  | .range lo hi step  => tag 0x22 [enc lo, enc hi, enc step]
where
  /-- Encoder for one optional index-expression component. -/
  enc : Option IndexExpr -> ByteArray := encOption encIndexExpr

/--
Decode an index in the layout produced by `encIndex`: take one tag value
from `next`, then decode the optional index-expression components for that
tag in order. A tag outside `0x20`..`0x22` fails with an "Unknown tag" error.
-/
def decIndex : DecodeM Index := do
  let t <- next
  match t with
  | 0x20 => return .ellipsis
  | 0x21 => do
    let c <- dec
    return .coord c
  | 0x22 => do
    -- Components are read in the same order the encoder wrote them.
    let lo <- dec
    let hi <- dec
    let step <- dec
    return .range lo hi step
  | _ => throw s!"Unknown tag in Index {t}"
where
  /-- Decoder for one optional index-expression component. -/
  dec : DecodeM (Option IndexExpr) := decOption decIndexExpr

/-- Round-trip check: encoding `i` and decoding the result must give back `i`. -/
private def chkIndex (i : Index) : Bool :=
  let bytes := encIndex i
  decode' decIndex bytes == some i

-- Round-trip checks covering each arm of `encIndex`, with both absent (`none`)
-- and present (`some`) components.
#guard chkIndex .ellipsis
#guard chkIndex (.coord none)
#guard chkIndex (.coord $ some ie_var)
#guard chkIndex (.range (some ie_var) none none)

------------------------------------------------------------------------------
-- Expressions

partial def encExpr : Expr -> ByteArray
| .value c => tag 0x10 [encConst c]
| .bvar s => tag 0x11 [encString s]
| .var s _ => tag 0x12 [encString s]
| .subscript e ix => tag 0x13 [encExpr e, encList encExpr ix]
| .slice l u step => tag 0x14 [encExpr l, encExpr u, encExpr step]
| .binop op l r => tag 0x15 [encString op, encExpr l, encExpr r]
| .cond c t e => tag 0x16 [encExpr c, encExpr t, encExpr e]
| .tuple es => tag 0x17 [encList encExpr es]
| .list es => tag 0x18 [encList encExpr es]
| .call f ax => tag 0x19 [encExpr f, encList encExpr ax]
| .gridcall f ix ax => tag 0x1a [encExpr f, encList encExpr ix, encList encExpr ax]
| .var s => tag 0x30 [encString s]
| .tensor t s => tag 0x31 [encString t, encList encInt s]
| .const c => tag 0x32 [encConst c]
| .tuple es => tag 0x33 [encList encExpr es]
| .list es => tag 0x34 [encList encExpr es]
| .access e ix => tag 0x35 [encExpr e, encList encIndex ix]
| .binop op l r => tag 0x36 [encString op, encExpr l, encExpr r]
| .unop op e => tag 0x37 [encString op, encExpr e]
| .call f ax kw => tag 0x38 [encExpr f, encList encExpr ax, encList encKeyword kw]
where
encKeyword : String × Expr -> ByteArray
| (key, expr) => (encString key).append (encExpr expr)

partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x10 => return .value (<- decConst)
| 0x11 => return .bvar (<- decString)
| 0x12 => return .var (<- decString) ""
| 0x13 => return .subscript (<- decExpr) (<- decList decExpr)
| 0x14 => return .slice (<- decExpr) (<- decExpr) (<- decExpr)
| 0x15 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x16 => return .cond (<- decExpr) (<- decExpr) (<- decExpr)
| 0x17 => return .tuple (<- decList decExpr)
| 0x18 => return .list (<- decList decExpr)
| 0x19 => return .call (<- decExpr) (<- decList decExpr)
| 0x1a => return .gridcall (<- decExpr) (<- decList decExpr) (<- decList decExpr)
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor (<- decString) (<- decList decInt)
| 0x32 => return .const (<- decConst)
| 0x33 => return .tuple (<- decList decExpr)
| 0x34 => return .list (<- decList decExpr)
| 0x35 => return .access (<- decExpr) (<- decList decIndex)
| 0x36 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x37 => return .unop (<- decString) (<- decExpr)
| 0x38 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
| t => throw s!"Unknown tag in Expr {t}"
where
decKeyword : DecodeM (String × Expr) :=
return ((<- decString), (<- decExpr))

/-- Round-trip check: encoding `e` and decoding the result must give back `e`. -/
private def chkExpr (e : Expr) : Bool :=
  let bytes := encExpr e
  decode' decExpr bytes == some e

private def nil := Expr.value .nil
private def nil := Expr.const .none
private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.bvar "var")
#guard chkExpr (.var "var" "")
#guard chkExpr (.subscript nil [nil, nil, nil])
#guard chkExpr (.slice nil nil nil)
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.cond nil nil nil)
#guard chkExpr (.var "var")
#guard chkExpr (.tensor "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.tuple [nil, nil, nil])
#guard chkExpr (.list [nil, nil, nil])
#guard chkExpr (.call nil [nil, nil, nil])
#guard chkExpr (.gridcall nil [nil, nil, nil] [nil, nil, nil])
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.unop "op" nil)
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])

------------------------------------------------------------------------------
-- Statements

partial def encStmt : Stmt -> ByteArray
| .ret e => tag 0x30 [encExpr e]
| .assign x e => tag 0x31 [encExpr x, encExpr e]
| .ifstm c t e => tag 0x32 [encExpr c, encList encStmt t, encList encStmt e]
| .forloop x e b => tag 0x33 [encString x, encExpr e, encList encStmt b]
| .check e => tag 0x34 [encExpr e]
| .pass => tag 0x40 []
| .expr e => tag 0x41 [encExpr e]
| .ret e => tag 0x42 [encExpr e]
| .assign x e => tag 0x43 [encString x, encExpr e]
| .loop x l u step body =>
tag 0x44 [ encString x,
encIndexExpr l, encIndexExpr u, encIndexExpr step,
encList encStmt body ]

partial def decStmt : DecodeM Stmt := do
match (<- next) with
| 0x30 => return .ret (<- decExpr)
| 0x31 => return .assign (<- decExpr) (<- decExpr)
| 0x32 => return .ifstm (<- decExpr) (<- decList decStmt) (<- decList decStmt)
| 0x33 => return .forloop (<- decString) (<- decExpr) (<- decList decStmt)
| 0x34 => return .check (<- decExpr)
| 0x40 => return .pass
| 0x41 => return .expr (<- decExpr)
| 0x42 => return .ret (<- decExpr)
| 0x43 => return .assign (<- decString) (<- decExpr)
| 0x44 => do
let x <- decString
let l <- decIndexExpr
let u <- decIndexExpr
let step <- decIndexExpr
let body <- decList decStmt
return .loop x l u step body
| t => throw s!"Unknown tag in Stmt {t}"

/-- Round-trip check: encoding `s` and decoding the result must give back `s`. -/
private def chkStmt (s : Stmt) : Bool :=
  let bytes := encStmt s
  decode' decStmt bytes == some s

private def stm := Stmt.check nil

#guard chkStmt .pass
#guard chkStmt (.expr nil)
#guard chkStmt (.ret nil)
#guard chkStmt (.assign nil nil)
#guard chkStmt (.ifstm nil [stm, stm, stm] [stm, stm, stm])
#guard chkStmt (.forloop "x" nil [stm, stm, stm])
#guard chkStmt (.check nil)
#guard chkStmt (.assign "x" nil)
#guard chkStmt (.loop "x" ie_var ie_var ie_var [.pass, .pass])

0 comments on commit 3dcd559

Please sign in to comment.