Skip to content

Commit

Permalink
feat: simple binary serialization format
Browse files Browse the repository at this point in the history
Add support for binary serialization and deserialization,
using a simple encoding. This encoding should be more stable
and easier to parse than JSON.
  • Loading branch information
govereau committed Nov 14, 2024
1 parent f793a6f commit 0cc3123
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 2 deletions.
1 change: 1 addition & 0 deletions NKL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Authors: Paul Govereau
-/
import NKL.NKI
import NKL.FFI
import NKL.Encode
331 changes: 331 additions & 0 deletions NKL/Encode.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.NKI

/-!
# Serialization and Deserialization
-/

namespace NKL

-- All of the encode function are pure; decoding uses an instance of EStateM.

abbrev DecodeM := EStateM String ByteArray.Iterator

def decode' (f : DecodeM a) (ba : ByteArray) : Option a :=
EStateM.run' f ba.iter

def decode (f : DecodeM a) (ba : ByteArray) : Except String a :=
match EStateM.run f ba.iter with
| .ok x _ => .ok x
| .error s _ => .error s

def decodeFile (f : DecodeM a) (path : System.FilePath) : IO a := do
let buf <- IO.FS.readBinFile path
match decode f buf with
| .ok x => return x
| .error s => throw $ IO.userError s

private def next : DecodeM UInt8 := do
let it <- get
if it.atEnd then throw "unexpected EOF"
let b := it.curr
set it.next
return b

-- TODO: this could be more efficient if we complicate the monad state
private def take (n : Nat) : DecodeM ByteArray :=
let rec f (ba : ByteArray) (n : Nat) : DecodeM ByteArray :=
match n with
| .zero => return ba
| .succ n => next >>= fun b => f (ba.push b) n
f (.mk #[]) n

------------------------------------------------------------------------------
-- Nats are encoded as u64, little-endian
-- TODO: if the nat is too large, it will be silently truncated.

private def encNat (n : Nat) : ByteArray := .mk
#[ UInt8.ofNat $ n >>> 0
, UInt8.ofNat $ n >>> 8
, UInt8.ofNat $ n >>> 16
, UInt8.ofNat $ n >>> 24
, UInt8.ofNat $ n >>> 32
, UInt8.ofNat $ n >>> 40
, UInt8.ofNat $ n >>> 48
, UInt8.ofNat $ n >>> 56
]

private def decNat : DecodeM Nat := do
let nat8 : DecodeM Nat := next >>= fun x => return x.toNat
return (<- nat8) <<< 0 |||
(<- nat8) <<< 8 |||
(<- nat8) <<< 16 |||
(<- nat8) <<< 24 |||
(<- nat8) <<< 32 |||
(<- nat8) <<< 40 |||
(<- nat8) <<< 48 |||
(<- nat8) <<< 56

#guard decode' decNat (.mk #[1,2,3,4,5,6,7]) == none
#guard decode' decNat (.mk #[0,1,0,0,0,0,0,0]) == some 256
#guard decode' decNat (encNat 0) == some 0
#guard decode' decNat (encNat 1) == some 1
#guard decode' decNat (encNat 256) == some 256

------------------------------------------------------------------------------
-- Ints are encoded as a sign tag followed by a Nat
-- TODO: Int64 should be in standard library soon, revisit when available

private def tag (t : UInt8) : List ByteArray -> ByteArray :=
List.foldl ByteArray.append (.mk #[t])

private def encInt (i : Int) : ByteArray :=
match i with
| .ofNat n => tag 0 [encNat n]
| .negSucc n => tag 1 [encNat n]

private def decInt : DecodeM Int := do
let sign <- next
let n <- decNat
bif sign == 0 then
return (.ofNat n)
else
return (.negSucc n)

private def decInt' : DecodeM (Bool × Nat) := do
let sign <- next
let n <- decNat
return (sign == 1, n + sign.toNat)

local instance : BEq ByteArray where
beq a b := a.data == b.data

#guard encInt 1 == .mk #[0, 1, 0, 0, 0, 0, 0, 0, 0]
#guard encInt 0 == .mk #[0, 0, 0, 0, 0, 0, 0, 0, 0]
#guard encInt (-1) == .mk #[1, 0, 0, 0, 0, 0, 0, 0, 0]
#guard encInt (-2) == .mk #[1, 1, 0, 0, 0, 0, 0, 0, 0]

#guard decode' decInt (encInt 1) == some 1
#guard decode' decInt (encInt 0) == some 0
#guard decode' decInt (encInt $ -1) == some (-1)

#guard decode' decInt' (encInt 1) == some (false, 1)
#guard decode' decInt' (encInt 0) == some (false, 0)
#guard decode' decInt' (encInt (-0)) == some (false, 0)
#guard decode' decInt' (encInt $ -1) == some (true, 1)

------------------------------------------------------------------------------
-- Float values are encoded as (mantissa, exponent), where the
-- mantissa and exponenet are integers.
-- TODO:
-- I am keeping the encode functions pure; encFloat can fail,
-- but this seems unlikely as the Lean parser would have to not understand
-- the result of toString on a Float value

private def encFloat (f : Float) : ByteArray :=
let str := (toString f).toSubstring
let one : String.Pos := { byteIdx := 1 }
let sgn := str.get 0 == '-'
let str := bif sgn then { str with startPos := one } else str
match Lean.Syntax.decodeScientificLitVal? str.toString with
| none => panic! s!"Failed to parse {toString f}"
| some (m, esn, e) =>
let m := toInt sgn m
let e := toInt esn e
(encInt m).append (encInt e)
where
-- this is to handle negative zero
toInt : Bool -> Nat -> Int := fun
| false, n => .ofNat n
| true, 0 => .ofNat 0
| true, n => .negSucc (n-1)

private def decFloat : DecodeM Float := do
let (sgn,m) <- decInt'
let (esn,e) <- decInt'
let f := Float.ofScientific m esn e
return bif sgn then -f else f

#guard decode' decFloat (encFloat 0) == some 0
#guard decode' decFloat (encFloat 0.5) == some 0.5
#guard decode' decFloat (encFloat 1.0) == some 1.0
#guard decode' decFloat (encFloat 1.5) == some 1.5
#guard decode' decFloat (encFloat (-0)) == some (0.0)
#guard decode' decFloat (encFloat (-0.5)) == some (-0.5)
#guard decode' decFloat (encFloat (-1.0)) == some (-1.0)
#guard decode' decFloat (encFloat (-1.5)) == some (-1.5)

------------------------------------------------------------------------------
-- Strings are encoded as a length followed by a sequence of UTF8 bytes

private def encString (s : String) : ByteArray :=
let ba := s.toUTF8
(encNat ba.size).append ba

private def decString : DecodeM String := do
let len <- decNat
let ba <- take len
match String.fromUTF8? ba with
| none => throw "invalid UTF8 string"
| some s => return s

#guard decode' decString (encString "") = some ""
#guard decode' decString (encString "Hello World") = some "Hello World"
#guard decode' decString (encString "Hello • World") = some "Hello • World"

------------------------------------------------------------------------------
-- Lists are encoded as a length followed by a sequence of encoded values
-- TODO: not efficient

private def encList (f : a -> ByteArray) (l : List a) : ByteArray :=
let rec mapa : List a -> ByteArray
| .nil => .mk #[]
| .cons x l => (f x).append (mapa l)
(encNat l.length).append (mapa l)

private def decList (f : DecodeM a) : DecodeM (List a) := do
let rec gena : Nat -> DecodeM (List a)
| .zero => return []
| .succ n => return (<- f) :: (<- gena n)
gena (<- decNat)

-- TODO: there are no types in serialized format, do we need that?
#guard decode' (decList decNat) (encList encFloat []) = some []
#guard decode' (decList decInt) (encList encInt [1,2,3]) = some [1,2,3]

------------------------------------------------------------------------------
-- Finally, constants are encoded with a tag followed by the values

def encConst : Const -> ByteArray
| .nil => tag 0x00 []
| .bool false => tag 0x01 []
| .bool true => tag 0x02 []
| .int (.ofNat n) => tag 0x03 [encNat n]
| .int (.negSucc n) => tag 0x04 [encNat n]
| .float f => tag 0x05 [encFloat f]
| .string s => tag 0x06 [encString s]

def decConst : DecodeM Const := do
let val <- next
match val with
| 0x00 => return .nil
| 0x01 => return .bool false
| 0x02 => return .bool true
| 0x03 => return .int $ .ofNat (<- decNat)
| 0x04 => return .int $ .negSucc (<- decNat)
| 0x05 => return .float (<- decFloat)
| 0x06 => return .string (<- decString)
| _ => throw s!"Unknown Const tag value {val}"

private def chkConst (c: Const) : Bool :=
(decode' decConst $ encConst c) == some c

#guard chkConst .nil
#guard chkConst (.bool true)
#guard chkConst (.bool false)
#guard chkConst (.int 1)
#guard chkConst (.float 1.0)
#guard chkConst (.string "str")

------------------------------------------------------------------------------
-- Expressions

mutual
partial def encExpr : Expr -> ByteArray
| .value c => tag 0x10 [encConst c]
| .bvar s => tag 0x11 [encString s]
| .var s _ => tag 0x12 [encString s]
| .subscript e ix => tag 0x13 [encExpr e, encList encIndex ix]
| .binop op l r => tag 0x14 [encString op, encExpr l, encExpr r]
| .cond c t e => tag 0x15 [encExpr c, encExpr t, encExpr e]
| .tuple es => tag 0x16 [encList encExpr es]
| .list es => tag 0x17 [encList encExpr es]
| .call f ax => tag 0x18 [encExpr f, encList encExpr ax]
| .gridcall f ix ax => tag 0x19 [encExpr f, encList encIndex ix, encList encExpr ax]

partial def encIndex : Index -> ByteArray
| .coord i => tag 0x20 [encExpr i]
| .slice l u step => tag 0x21 [encExpr l, encExpr u, encExpr step]
| .dots => tag 0x22 []
end

mutual
partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x10 => return .value (<- decConst)
| 0x11 => return .bvar (<- decString)
| 0x12 => return .var (<- decString) ""
| 0x13 => return .subscript (<-decExpr) (<- decList decIndex)
| 0x14 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x15 => return .cond (<- decExpr) (<- decExpr) (<- decExpr)
| 0x16 => return .tuple (<- decList decExpr)
| 0x17 => return .list (<- decList decExpr)
| 0x18 => return .call (<- decExpr) (<- decList decExpr)
| 0x19 => return .gridcall (<- decExpr) (<- decList decIndex) (<- decList decExpr)
| t => throw s!"Unknown tag in Expr {t}"

partial def decIndex : DecodeM Index := do
match (<- next) with
| 0x20 => return .coord (<- decExpr)
| 0x21 => return .slice (<- decExpr) (<- decExpr) (<- decExpr)
| 0x22 => return .dots
| t => throw s!"Unknown tag in Index {t}"
end

private def chkExpr (e : Expr) : Bool :=
(decode' decExpr $ encExpr e) == some e
private def chkIndex (i : Index) : Bool :=
(decode' decIndex $ encIndex i) == some i

private def nil := Expr.value .nil
private def ndx := Index.coord nil

#guard chkExpr nil
#guard chkExpr (.bvar "var")
#guard chkExpr (.var "var" "")
#guard chkExpr (.subscript nil [ndx, ndx, ndx])
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.cond nil nil nil)
#guard chkExpr (.tuple [nil, nil, nil])
#guard chkExpr (.list [nil, nil, nil])
#guard chkExpr (.call nil [nil, nil, nil])
#guard chkExpr (.gridcall nil [ndx, ndx, ndx] [nil, nil, nil])

#guard chkIndex ndx
#guard chkIndex (.slice nil nil nil)
#guard chkIndex .dots

------------------------------------------------------------------------------
-- Statements

partial def encStmt : Stmt -> ByteArray
| .ret e => tag 0x30 [encExpr e]
| .assign x e => tag 0x31 [encExpr x, encExpr e]
| .ifstm c t e => tag 0x32 [encExpr c, encList encStmt t, encList encStmt e]
| .forloop x e b => tag 0x33 [encString x, encExpr e, encList encStmt b]
| .check e => tag 0x34 [encExpr e]

partial def decStmt : DecodeM Stmt := do
match (<- next) with
| 0x30 => return .ret (<- decExpr)
| 0x31 => return .assign (<- decExpr) (<- decExpr)
| 0x32 => return .ifstm (<- decExpr) (<- decList decStmt) (<- decList decStmt)
| 0x33 => return .forloop (<- decString) (<- decExpr) (<- decList decStmt)
| 0x34 => return .check (<- decExpr)
| t => throw s!"Unknown tag in Stmt {t}"

private def chkStmt (s : Stmt) : Bool :=
(decode' decStmt $ encStmt s) == some s

private def stm := Stmt.check nil

#guard chkStmt (.ret nil)
#guard chkStmt (.assign nil nil)
#guard chkStmt (.ifstm nil [stm, stm, stm] [stm, stm, stm])
#guard chkStmt (.forloop "x" nil [stm, stm, stm])
#guard chkStmt (.check nil)
2 changes: 1 addition & 1 deletion lake-manifest.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{"version": "1.1.0",
"packagesDir": ".lake/packages",
"packages": [],
"name": "ANPU",
"name": "NKL",
"lakeDir": ".lake"}
2 changes: 1 addition & 1 deletion lean-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
leanprover/lean4:v4.13.0-rc4
leanprover/lean4:v4.13.0

0 comments on commit 0cc3123

Please sign in to comment.