Skip to content

Commit

Permalink
apply feedback + add test cases + update changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
lehugueni committed Dec 17, 2024
1 parent b67c72b commit 3a52e44
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 98 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Changelog
All notable changes to this library are documented in this file.

## [6.x.x] - 16.12.2024
- Refactoring of the InnerSum methods:
- `rlwe.Evaluator.InnerSum` has been replaced by `rlwe.Evaluator.PartialTrace`
- Introduction of the `bgv.Evaluator.InnerSum` and `ckks.Evaluator.InnerSum` methods, which have the same behaviour as the old `InnerSum` method for parameters `n` and `batchSize` s.t. `n*batchSize` divides the number of slots. Parameters not satisfying this condition are rejected.
- Introduction of the `bgv.Evaluator.RotateAndAdd` and `ckks.Evaluator.RotateAndAdd` methods, which have the same behaviour as the old `InnerSum` method for all parameters.

## [6.1.0] - 04.10.2024
- Update of `PrecisionStats` in `ckks/precision.go`:
- The precision is now computed as the min/max/average/... of the log of the error (instead of the log of the min/max/average/... of the error).
Expand Down
33 changes: 10 additions & 23 deletions core/rlwe/inner_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,13 @@ func GaloisElementsForTrace(params ParameterProvider, logN int) (galEls []uint64
return
}

// InnerSum applies an optimized inner sum on the Ciphertext (log2(n) + HW(n) rotations with double hoisting).
// The operation assumes that `ctIn` encrypts Slots/`batchSize` sub-vectors of size `batchSize` and will add them together (in parallel) in groups of `n`.
// It outputs in opOut a [Ciphertext] for which the "leftmost" sub-vector of each group is equal to the sum of the group.
//
// The inner sum is computed in a tree fashion. Example for batchSize=2 & n=4 (garbage slots are marked by 'x'):
//
// 1. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
//
// 2. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// +
// [{c, d}, {e, f}, {g, h}, {x, x}, {c, d}, {e, f}, {g, h}, {x, x}] (rotate batchSize * 2^{0})
// =
// [{a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}, {a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}]
//
// 3. [{a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}, {a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}] (rotate batchSize * 2^{1})
// +
// [{e+g, f+h}, {x, x}, {x, x}, {x, x}, {e+g, f+h}, {x, x}, {x, x}, {x, x}] =
// =
// [{a+c+e+g, b+d+f+h}, {x, x}, {x, x}, {x, x}, {a+c+e+g, b+d+f+h}, {x, x}, {x, x}, {x, x}]
func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Ciphertext) (err error) {
// PartialTrace applies a partial trace on the input ciphertext with the automorphisms phi(i*offset, X), 0 <= i < n, where phi(k, X): X -> X^{5^k}
// i.e. opOut = \sum_{i = 0}^{n-1} phi(i*offset, ctIn).
// At the scheme level, this function is used to perform inner sums or efficiently replicate slots.
func (eval Evaluator) PartialTrace(ctIn *Ciphertext, offset, n int, opOut *Ciphertext) (err error) {
if n == 0 || offset == 0 {
return fmt.Errorf("partialtrace: invalid parameter (n = 0 or batchSize = 0)")
}

params := eval.GetRLWEParameters()

Expand Down Expand Up @@ -236,7 +223,7 @@ func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Cipher
if j&1 == 1 {

k := n - (n & ((2 << i) - 1))
k *= batchSize
k *= offset

// If the rotation is not zero
if k != 0 {
Expand Down Expand Up @@ -281,7 +268,7 @@ func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Cipher

if !state {

rot := params.GaloisElement((1 << i) * batchSize)
rot := params.GaloisElement((1 << i) * offset)

// ctInNTT = ctInNTT + Rotate(ctInNTT, 2^i)
if err = eval.AutomorphismHoisted(levelQ, ctInNTT, eval.BuffDecompQP, rot, cQ); err != nil {
Expand Down Expand Up @@ -486,7 +473,7 @@ func GaloisElementsForInnerSum(params ParameterProvider, batch, n int) (galEls [
// two consecutive sub-vectors to replicate.
// This method is faster than Replicate when the number of rotations is large and it uses log2(n) + HW(n) instead of n.
func (eval Evaluator) Replicate(ctIn *Ciphertext, batchSize, n int, opOut *Ciphertext) (err error) {
return eval.InnerSum(ctIn, -batchSize, n, opOut)
return eval.PartialTrace(ctIn, -batchSize, n, opOut)
}

// GaloisElementsForReplicate returns the list of Galois elements necessary to perform the
Expand Down
5 changes: 3 additions & 2 deletions core/rlwe/rlwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ func testSlotOperations(tc *TestContext, level, bpw2 int, t *testing.T) {
enc := tc.enc
dec := tc.dec

t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/InnerSum"), func(t *testing.T) {
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/PartialTrace"), func(t *testing.T) {

if params.MaxLevelP() == -1 {
t.Skip("test requires #P > 0")
Expand All @@ -1095,14 +1095,15 @@ func testSlotOperations(tc *TestContext, level, bpw2 int, t *testing.T) {
ringQ := tc.params.RingQ().AtLevel(level)

pt := genPlaintext(params, level, 1<<30)
pt.LogDimensions = ring.Dimensions{Rows: 1, Cols: params.logN - 1}
ptInnerSum := *pt.Value.CopyNew()
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)

// Galois Keys
evk := NewMemEvaluationKeySet(nil, kgen.GenGaloisKeysNew(GaloisElementsForInnerSum(params, batch, n), sk)...)

require.NoError(t, eval.WithKey(evk).InnerSum(ct, batch, n, ct))
require.NoError(t, eval.WithKey(evk).PartialTrace(ct, batch, n, ct))

dec.Decrypt(ct, pt)

Expand Down
42 changes: 35 additions & 7 deletions examples/singleparty/tutorials/ckks/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,13 +590,13 @@ func main() {

// The `circuits/lintrans` package provides a multiple handy linear transformations.
// We will start with the inner sum.
// Thus method allows to aggregate `n` sub-vectors of size `batch`.
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 3
// it will return the vector [x0+x2+x4, x1+x3+x5, x2+x4+x6, x3+x5+x7, x4+x6+x0, x5+x7+x1, x6+x0+x2, x7+x1+x3]
// Observe that the inner sum wraps around the vector, this behavior must be taken into account.
// This method allows to aggregate `n` sub-vectors of size `batch` and it stores the result in the leftmost sub-vector of each "group".
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 4
// it will return the vector [x0+x2+x4+x6, x1+x3+x5+x7, X, X, X, X, X, X], where X marks garbage slots.
// Note that n*batch must divide the length of the vector (i.e. the number of slots).

batch := 37
n := 127
batch := 32
n := 128

// The innersum operations is carried out with log2(n) + HW(n) automorphisms and we need to
// generate the corresponding Galois keys and provide them to the `Evaluator`.
Expand All @@ -619,7 +619,35 @@ func main() {
// apply the innersum and then only apply the rescaling.
fmt.Printf("Innersum %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())

// The replicate operation is exactly the same as the innersum operation, but in reverse
// Sometimes we wish to compute an inner sum on the first values of the vector only.
// In this case, n*batch does not necessarily divide the length of the vector and the RotateAndAdd function must be used instead.
// This method allows to repeatedly shift the vector by batch values and add (i.e. \sum_{i=0}^{n-1} v << (i*batch), where v is the input vector).
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 3
// it will return the vector [x0+x2+x4, x1+x3+x5, x2+x4+x6, x3+x5+x7, x4+x6+x0, x5+x7+x1, x6+x0+x2, x7+x1+x3].
// Observe that the inner sum wraps around the vector, this behavior must be taken into account.

batch = 37
n = 127
eval = eval.WithKey(rlwe.NewMemEvaluationKeySet(rlk, kgen.GenGaloisKeysNew(params.GaloisElementsForInnerSum(batch, n), sk)...))

// Plaintext circuit
copy(want, values1)
for i := 1; i < n; i++ {
for j, vi := range utils.RotateSlice(values1, i*batch) {
want[j] += vi
}
}

if err := eval.RotateAndAdd(ct1, batch, n, res); err != nil {
panic(err)
}

// Note that this method can obviously be used to average values.
// For a good noise management, it is recommended to first multiply the values by 1/n, then
// apply the inner sum and then only apply the rescaling.
fmt.Printf("RotateAndAdd %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())

// The replicate operation is exactly the same as the rotate and add operation, but in reverse
eval = eval.WithKey(rlwe.NewMemEvaluationKeySet(rlk, kgen.GenGaloisKeysNew(params.GaloisElementsForReplicate(batch, n), sk)...))

// Plaintext circuit
Expand Down
113 changes: 64 additions & 49 deletions schemes/bgv/bgv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,77 +668,92 @@ func testEvaluatorBvg(tc *TestContext, t *testing.T) {
}

// Naive implementation of the inner sum for reference
innersum := func(values []uint64, n, batchSize int) {
tmp := make([]uint64, len(values))
copy(tmp, values)
innersum := func(values []uint64, n, batchSize int, rotateAndAdd bool) {
aggregate := false
if n*batchSize == len(values) && !rotateAndAdd {
aggregate = true
n = n / 2
}
halfN := len(values) >> 1
tmp1 := make([]uint64, halfN)
tmp2 := make([]uint64, halfN)
copy(tmp1, values[:halfN])
copy(tmp2, values[halfN:])
for i := 1; i < n; i++ {
rot := utils.RotateSlice(tmp, i*batchSize)
for j := range values {
values[j] = (values[j] + rot[j]) % tc.Params.PlaintextModulus()
rot1 := utils.RotateSlice(tmp1, i*batchSize)
rot2 := utils.RotateSlice(tmp2, i*batchSize)
for j := range rot1 {
values[j] = (values[j] + rot1[j]) % tc.Params.PlaintextModulus()
values[j+halfN] = (values[j+halfN] + rot2[j]) % tc.Params.PlaintextModulus()
}
}
if aggregate {
for i := range tmp1 {
values[i] = (values[i] + values[i+halfN]) % tc.Params.PlaintextModulus()
}
}
}

for _, N := range []int{tc.Params.N(), tc.Params.MaxSlots()} {
for _, lvl := range testLevel {
t.Run(name("Evaluator/InnerSum/N slots", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
n := N >> 2
batchSize := 1 << 2
for _, i := range []int{0, 1, 2} {
// n*batchSize = N, N/2, N/8
for _, offset := range []int{0, 1, 3} {
for _, lvl := range testLevel {
t.Run(name("Evaluator/InnerSum/", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
n := tc.Params.MaxSlots() >> (i + offset)
batchSize := 1 << i

galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))
galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))

want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))

innersum(want, n, batchSize)
innersum(want, n, batchSize, false)

receiver := NewCiphertext(tc.Params, 1, lvl)
receiver := NewCiphertext(tc.Params, 1, lvl)

require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver))
require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver))

have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))
have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))

for i := 0; i < len(want); i += n * batchSize {
require.Equal(t, want[i:i+batchSize], have[i:i+batchSize])
}
})
for i := 0; i < len(want); i += n * batchSize {
require.Equal(t, want[i:i+batchSize], have[i:i+batchSize])
}
})
}
}
}

for _, lvl := range testLevel {
t.Run(name("Evaluator/InnerSum/N/2 slots", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
n := 7
batchSize := 13
l := n * batchSize
halfN := tc.Params.MaxSlots() >> 1
// Test RotateAndAdd with n*batchSize dividing and not dividing #slots
for _, n := range []int{tc.Params.MaxSlots() >> 3, 7} {
for _, batchSize := range []int{8, 3} {
for _, lvl := range testLevel {
t.Run(name("Evaluator/RotateAndAdd/", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}

galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))
galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))

want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))

innersum(want[:halfN], n, batchSize)
innersum(want[halfN:], n, batchSize)
innersum(want, n, batchSize, true)

receiver := NewCiphertext(tc.Params, 1, lvl)
receiver := NewCiphertext(tc.Params, 1, lvl)

require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver))
require.NoError(t, evl.RotateAndAdd(ciphertext0, batchSize, n, receiver))

have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))
have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))

for i, j := 0, halfN; i < halfN; i, j = i+l, j+l {
require.Equal(t, want[i:i+batchSize], have[i:i+batchSize])
require.Equal(t, want[j:j+batchSize], have[j:j+batchSize])
require.Equal(t, want, have)
})
}
}
})
}
}
}

Expand Down
74 changes: 58 additions & 16 deletions schemes/bgv/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,41 +1505,83 @@ func (eval Evaluator) RotateHoistedLazyNew(level int, rotations []int, op0 *rlwe
return
}

// InnerSum computes the inner sum of the underlying slots (see [rlwe.Evaluator.InnerSum]).
// NB: in the slot encoding of BGV/BFV, the underlying N slots are arranged as 2 rows of N/2 slots.
// If n*batchSize is a multiple of N, InnerSum computes the [rlwe.Evaluator.InnerSum] on the N slots.
// NOTE: In this case, InnerSum performs an addition and a [Evaluator.RotateRowsNew] on top.
// Otherwise, InnerSum computes the [rlwe.Evaluator.InnerSum] of each row separately.
// InnerSum divides each row of the underlying plaintext in sub-vectors of size batchSize and add n of these together.
// If n*batchSize = ctIn.Slots(), the inner sum is computed as if the plaintext was a 1-D vector of dimension ctIn.Slots()
// (we recall that a BGV/BFV plaintext is represented as a 2 x ctIn.Slots()/2 matrix).
//
// WARNING: 0 < n*batchSize <= ctIn.Slots() must divide the number of slots ctIn.Slots(). For other parameters, consider using [Evaluator.RotateAndAdd].
//
// Example for batchSize=2, n=4 and 32 slots (garbage slots are marked as X):
//
// Input:
//
// [[{a1, b1}, {c1, d1}, {e1, f1}, {g1, h1}, {i1, j1}, {k1, l1}, {m1, n1}, {o1, p1}]
//
// [{a2, b2}, {c2, d2}, {e2, f2}, {g2, h2}, {i2, j2}, {k2, l2}, {m2, n2}, {o2, p2}]]
//
// Output:
//
// [[{a1+c1+e1+g1, b1+d1+f1+h1}, {X, X}, {X, X}, {X, X}, {i1+k1+m1+o1, j1+l1+n1+p1}, {X, X}, {X, X}, {X, X}]
//
// [{a2+c2+e2+g2, b2+d2+f2+h2}, {X, X}, {X, X}, {X, X}, {i2+k2+m2+o2, j2+l2+n2+p2}, {X, X}, {X, X}, {X, X}]]
func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) {
N := eval.parameters.MaxSlots()
N := ctIn.Slots()
l := n * batchSize

if l%N == 0 {
if n <= 0 || batchSize <= 0 {
return fmt.Errorf("innersum: invalid parameter (n <= 0 or batchSize <= 0)")
}
if l > N {
return fmt.Errorf("innersum: invalid parameters (n*batchSize=%d > #slots=%d)", l, N)
}
if l&(l-1) != 0 {
return fmt.Errorf("innersum: invalid parameters (n*batchSize=%d does not divide #slots=%d)", l, N)
}

if l == N {
if n == 1 {
if ctIn != opOut {
opOut.Copy(ctIn)
}
opOut.Copy(ctIn)
return
}

if err = eval.Evaluator.InnerSum(ctIn, batchSize, n/2, opOut); err != nil {
if err = eval.Evaluator.PartialTrace(ctIn, batchSize, n/2, opOut); err != nil {
return
}

var ctRot *rlwe.Ciphertext
ctRot, err = eval.RotateRowsNew(opOut)
if err != nil {
ctTmp := &rlwe.Ciphertext{Element: rlwe.Element[ring.Poly]{Value: []ring.Poly{eval.BuffQP[2].Q, eval.BuffQP[3].Q}}}
ctTmp.MetaData = opOut.MetaData
if err = eval.RotateRows(opOut, ctTmp); err != nil {
return
}

if err = eval.Add(opOut, ctRot, opOut); err != nil {
if err = eval.Add(opOut, ctTmp, opOut); err != nil {
return
}

return
}

err = eval.Evaluator.InnerSum(ctIn, batchSize, n, opOut)
err = eval.Evaluator.PartialTrace(ctIn, batchSize, n, opOut)
return
}

// RotateAndAdd computes the sum of pt_i, 0 <= i < n, where pt_i is the underlying plaintext rotated ([Evaluator.RotateRows]) by batchSize*i slots.
//
// Example: for batchSize=3, n=2, ctIn.Slots()=16:
//
// Input (recall that a BGV/BFV plaintext is represented as a 2 x ctIn.Slots()/2 matrix):
//
// [[a, b, c, d, e, f, g, h]
// [i, j, k, l, m, n, o, p]]
//
// Output:
//
// [[a, b, c, d, e, f, g, h] + [[d, e, f, g, h, a, b, c] = [[a+d, b+e, c+f, d+g, e+h, f+a, g+b, h+c]
// [i, j, k, l, m, n, o, p]] [l, m, n, o, p, i, j, k]] [i+l, j+m, k+n, l+o, m+p, n+i, o+j, p+k]]
//
// Calling RotateAndAdd(ctIn, 1, n, opOut) can be used to compute the inner sum of the first n slots of a plaintext.
func (eval Evaluator) RotateAndAdd(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) {
err = eval.Evaluator.PartialTrace(ctIn, batchSize, n, opOut)
return
}

Expand Down
Loading

0 comments on commit 3a52e44

Please sign in to comment.