Skip to content

Commit

Permalink
added objectstore-based public-key backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Jul 22, 2024
1 parent 6fcba5c commit 7ddec84
Show file tree
Hide file tree
Showing 16 changed files with 121 additions and 37 deletions.
2 changes: 0 additions & 2 deletions circuits/parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package circuits
import (
"encoding/json"
"fmt"
"log"
"maps"
"sync"

Expand Down Expand Up @@ -221,5 +220,4 @@ func (e *circuitParserContext) EvalLocal(needRlk bool, galKeys []uint64, f func(
}

func (e *circuitParserContext) Logf(format string, args ...interface{}) {
log.Printf(format, args...)
}
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (hc *HeliumClient) Register(ctx context.Context) (upstream *coordinator.Cha

ev := api.ToNodeEvent(apiEvent)
eventsStream <- ev
log.Printf("[client] new event: %s", ev)
//log.Printf("[client] new event: %s", ev)
}
}()

Expand Down
2 changes: 1 addition & 1 deletion examples/vec-mul/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func main() {
// simulates loading the secrets. In a real application, the secrets would be loaded from a secure storage.
func loadSecrets(params sessions.Parameters, nid sessions.NodeID) node.SecretProvider {

var sp node.SecretProvider = func(sid sessions.ID) (*sessions.Secrets, error) {
var sp node.SecretProvider = func(sid sessions.ID, nid sessions.NodeID) (*sessions.Secrets, error) {

if sid != params.ID {
return nil, fmt.Errorf("no secret for session %s", sid)
Expand Down
4 changes: 2 additions & 2 deletions helium_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ func TestSetup(t *testing.T) {
t.Fatal(err)
}

setup.CheckTestSetup(ctx, t, lt.TestSession, *app.SetupDescription, helper)
setup.CheckTestSetup(ctx, t, *app.SetupDescription, helper, lt.RlweParams, lt.SkIdeal, ts.N)

for _, cli := range clients {
log.Println("checking setup for", cli.id)
resCheckCtx, runCheckCancel := context.WithTimeout(ctx, time.Second)
setup.CheckTestSetup(resCheckCtx, t, lt.TestSession, *app.SetupDescription, cli)
setup.CheckTestSetup(resCheckCtx, t, *app.SetupDescription, cli, lt.RlweParams, lt.SkIdeal, ts.N)
runCheckCancel()

require.NoError(t, cli.Close())
Expand Down
2 changes: 1 addition & 1 deletion node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,4 @@ type TLSConfig struct {

// SecretProvider is a function that returns the secrets for a session,
// given the session ID.
type SecretProvider func(sessions.ID) (*sessions.Secrets, error)
type SecretProvider func(sessions.ID, sessions.NodeID) (*sessions.Secrets, error)
2 changes: 1 addition & 1 deletion node/localtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func NewLocalTest(config LocalTestConfig) (test *LocalTest, err error) {
test.Nodes[0] = test.HelperNode
for i, nc := range test.SessNodeConfigs {
var err error
test.Nodes[i+1], err = New(nc, test.List, func(_ sessions.ID) (*sessions.Secrets, error) {
test.Nodes[i+1], err = New(nc, test.List, func(_ sessions.ID, _ sessions.NodeID) (*sessions.Secrets, error) {
return secrets[nc.ID], nil
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func (node *Node) createNewSession(sessParams sessions.Parameters, secrets Secre

var sec *sessions.Secrets
if slices.Contains(sessParams.Nodes, node.id) {
sec, err = secrets(sessParams.ID)
sec, err = secrets(sessParams.ID, node.id)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TestNodeSetup(t *testing.T) {

for _, node := range all {
resCheckCtx, runCheckCancel := context.WithTimeout(ctx, time.Second)
setup.CheckTestSetup(resCheckCtx, t, lt.TestSession, *app.SetupDescription, node)
setup.CheckTestSetup(resCheckCtx, t, *app.SetupDescription, node, lt.TestSession.RlweParams, lt.TestSession.SkIdeal, ts.N)
runCheckCancel()
}
})
Expand Down
3 changes: 1 addition & 2 deletions protocols/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package protocols
import (
"context"
"fmt"
"log"
"slices"
"sync"

Expand Down Expand Up @@ -672,7 +671,7 @@ func (s *Executor) getParticipants(sig Signature, sess *sessions.Session) []sess
}

func (s *Executor) Logf(msg string, v ...any) {
log.Printf("%s | [executor] %s\n", s.self, fmt.Sprintf(msg, v...))
//log.Printf("%s | [executor] %s\n", s.self, fmt.Sprintf(msg, v...))
}

func (s *Executor) NodeID() sessions.NodeID {
Expand Down
6 changes: 2 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ func (hsv *HeliumServer) Register(_ *pb.Void, stream pb.Helium_RegisterServer) e
}
hsv.mu.Unlock() // all events after pastEvents will go on the sendQueue

hsv.Logf("registering %s...", nodeID)
err := hsv.helperNode.Register(nodeID)
if err != nil {
panic(err)
Expand All @@ -249,7 +248,7 @@ func (hsv *HeliumServer) Register(_ *pb.Void, stream pb.Helium_RegisterServer) e
break
}
}
hsv.Logf("done sending %d past events to %s, stream is live", present, nodeID)
//hsv.Logf("done sending %d past events to %s, stream is live", present, nodeID)

// Processes the node's sendQueue. The sendQueue channel is closed when exiting the loop
cancelled := stream.Context().Done()
Expand All @@ -265,7 +264,7 @@ func (hsv *HeliumServer) Register(_ *pb.Void, stream pb.Helium_RegisterServer) e
//hsv.Logf("sent to node %s: %v", nodeID, evt)
} else {
done = true
hsv.Logf("update queue for %s closed", nodeID)
//hsv.Logf("update queue for %s closed", nodeID)
}

// stream was terminated by the node or the server
Expand All @@ -283,7 +282,6 @@ func (hsv *HeliumServer) Register(_ *pb.Void, stream pb.Helium_RegisterServer) e
peer.sendQueue = nil
hsv.mu.Unlock()

hsv.Logf("unregistering %s...", nodeID)
err = hsv.helperNode.Unregister(nodeID)
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion services/compute/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,5 @@ func (se *evaluatorRuntime) Parameters() sessions.FHEParameters {
}

func (se *evaluatorRuntime) Logf(msg string, v ...any) {
log.Printf("%s | [%s] %s\n", se.cDesc.Evaluator, se.cDesc.CircuitID, fmt.Sprintf(msg, v...))
log.Printf("%s | [compute][%s] %s\n", se.cDesc.Evaluator, se.cDesc.CircuitID, fmt.Sprintf(msg, v...))
}
3 changes: 1 addition & 2 deletions services/compute/participant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package compute
import (
"context"
"fmt"
"log"
"math/big"

"golang.org/x/exp/maps"
Expand Down Expand Up @@ -272,7 +271,7 @@ func (p *participantRuntime) Parameters() sessions.FHEParameters {
}

func (p *participantRuntime) Logf(msg string, v ...any) {
log.Printf("%s | [%s] %s\n", p.sess.NodeID, p.cd.CircuitID, fmt.Sprintf(msg, v...))
//log.Printf("%s | [%s] %s\n", p.sess.NodeID, p.cd.CircuitID, fmt.Sprintf(msg, v...))
}

func isRLWEPLaintext(in interface{}) bool {
Expand Down
8 changes: 2 additions & 6 deletions services/compute/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,23 +433,19 @@ func (s *Service) Run(ctx context.Context, ip InputProvider, or OutputReceiver,

if ev.ProtocolEvent != nil {
pev := *ev.ProtocolEvent

s.Logf("new coordination event: PROTOCOL %s", pev)

s.Logf("PROTOCOL %s", pev)
s.incoming <- pev

switch pev.EventType {
case protocols.Completed:
if err := s.sendCompletedPdToCircuit(pev.Descriptor); err != nil {
panic(err)
}
}

continue
}

cev := *ev.CircuitEvent
s.Logf("new coordination event: CIRCUIT %s", cev)
s.Logf("CIRCUIT %s", cev)
switch ev.CircuitEvent.EventType {
case circuits.Started:
err := s.createCircuit(serviceCtx, cev.Descriptor)
Expand Down
13 changes: 4 additions & 9 deletions services/setup/description.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,12 @@ func (sd Description) String() string {
}

// CheckTestSetup checks if a public key provider is able to produce valid keys for a given test session and setup description.
func CheckTestSetup(ctx context.Context, t *testing.T, lt *sessions.TestSession, setup Description, n sessions.PublicKeyProvider) {

params := lt.RlweParams
sk := lt.SkIdeal
nParties := len(lt.HelperSession.Nodes)

func CheckTestSetup(ctx context.Context, t *testing.T, setup Description, n sessions.PublicKeyProvider, params rlwe.Parameters, skIdeal *rlwe.SecretKey, nParties int) {
// check CPK
if setup.Cpk {
cpk, err := n.GetCollectivePublicKey(ctx)
require.NoError(t, err)
require.Less(t, rlwe.NoisePublicKey(cpk, sk, params), math.Log2(math.Sqrt(float64(nParties))*params.NoiseFreshSK())+1)
require.Less(t, rlwe.NoisePublicKey(cpk, skIdeal, params), math.Log2(math.Sqrt(float64(nParties))*params.NoiseFreshSK())+1)
}

// check RTG
Expand All @@ -86,7 +81,7 @@ func CheckTestSetup(ctx context.Context, t *testing.T, lt *sessions.TestSession,

decompositionVectorSize := params.BaseRNSDecompositionVectorSize(params.MaxLevelQ(), params.MaxLevelP())
noiseBound := math.Log2(math.Sqrt(float64(decompositionVectorSize))*drlwe.NoiseGaloisKey(params, nParties)) + 1
require.Less(t, rlwe.NoiseGaloisKey(rtk, sk, params), noiseBound, "rtk for galEl %d should be correct", galEl)
require.Less(t, rlwe.NoiseGaloisKey(rtk, skIdeal, params), noiseBound, "rtk for galEl %d should be correct", galEl)

}

Expand All @@ -98,6 +93,6 @@ func CheckTestSetup(ctx context.Context, t *testing.T, lt *sessions.TestSession,
BaseRNSDecompositionVectorSize := params.BaseRNSDecompositionVectorSize(params.MaxLevelQ(), params.MaxLevelP())
noiseBound := math.Log2(math.Sqrt(float64(BaseRNSDecompositionVectorSize))*drlwe.NoiseRelinearizationKey(params, nParties)) + 1

require.Less(t, rlwe.NoiseRelinearizationKey(rlk, sk, params), noiseBound)
require.Less(t, rlwe.NoiseRelinearizationKey(rlk, skIdeal, params), noiseBound)
}
}
99 changes: 99 additions & 0 deletions services/setup/keybackend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package setup

import (
"context"
"fmt"

"github.com/ChristianMct/helium/objectstore"
"github.com/ChristianMct/helium/protocols"
"github.com/ChristianMct/helium/sessions"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
)

type KeyBackend struct {
*objStoreResultBackend
sess *sessions.Session
}

func NewKeyBackend(osc objectstore.Config, sessParams sessions.Parameters) (kb *KeyBackend, err error) {
kb = new(KeyBackend)
os, err := objectstore.NewObjectStoreFromConfig(osc)
if err != nil {
return nil, err
}
kb.objStoreResultBackend = newObjStoreResultBackend(os)
kb.sess, err = sessions.NewSession("", sessParams, nil)
if err != nil {
return nil, err
}
return kb, nil
}

// GetCollectivePublicKey returns the collective public key for the session in ctx.
func (kb *KeyBackend) GetCollectivePublicKey(ctx context.Context) (*rlwe.PublicKey, error) {
out, err := kb.getOutput(ctx, protocols.Signature{Type: protocols.CKG})
if err != nil {
return nil, err
}
return out.(*rlwe.PublicKey), nil
}

// GetGaloisKey returns the galois key for the session in ctx and the given Galois element.
func (kb *KeyBackend) GetGaloisKey(ctx context.Context, galEl uint64) (*rlwe.GaloisKey, error) {
out, err := kb.getOutput(ctx, protocols.Signature{Type: protocols.RTG, Args: map[string]string{"GalEl": fmt.Sprintf("%d", galEl)}})
if err != nil {
return nil, err
}
return out.(*rlwe.GaloisKey), err
}

// GetRelinearizationKey returns the relinearization key for the session in ctx.
func (kb *KeyBackend) GetRelinearizationKey(ctx context.Context) (*rlwe.RelinearizationKey, error) {
out, err := kb.getOutput(ctx, protocols.Signature{Type: protocols.RKG})
if err != nil {
return nil, err
}
return out.(*rlwe.RelinearizationKey), err
}

func (kb *KeyBackend) getOutput(ctx context.Context, sig protocols.Signature) (interface{}, error) {
pd, err := kb.GetProtocolDesc(ctx, sig)
if err != nil {
return nil, err
}
share := sig.Type.Share()
err = kb.GetShare(ctx, sig, share)
if err != nil {
return nil, err
}
aggOut := protocols.AggregationOutput{Descriptor: *pd, Share: protocols.Share{MHEShare: share, ShareMetadata: protocols.ShareMetadata{ProtocolID: pd.ID(), ProtocolType: pd.Type}}}

p, err := protocols.NewProtocol(*pd, kb.sess)
if err != nil {
return nil, err
}
in, err := kb.getProtoInput(ctx, p)
if err != nil {
return nil, err
}
out := protocols.AllocateOutput(sig, *kb.sess.Params.GetRLWEParameters()) // TODO cache ?
if err = p.Output(in, aggOut, out); err != nil {
return nil, err
}
return out, nil
}

func (kb *KeyBackend) getProtoInput(ctx context.Context, p *protocols.Protocol) (protocols.Input, error) {
switch p.Descriptor().Type {
case protocols.CKG, protocols.RTG:
return p.ReadCRP()
case protocols.RKG:
share := p.AllocateShare()
if err := kb.objStoreResultBackend.GetShare(ctx, protocols.Signature{Type: protocols.RKG1}, share); err != nil {
return nil, fmt.Errorf("could not retrieve round 1 share: %w", err)
}
return share.MHEShare, nil
default:
return nil, fmt.Errorf("unsupported protocol type %s", p.Descriptor().Type)
}
}
6 changes: 3 additions & 3 deletions services/setup/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestSetup(t *testing.T) {
require.Nil(t, err)

for _, n := range all {
CheckTestSetup(ctx, t, testSess, sd, n)
CheckTestSetup(ctx, t, sd, n, testSess.RlweParams, testSess.SkIdeal, ts.N)
}
})
}
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestSetupLateConnect(t *testing.T) {
}

for _, n := range all {
CheckTestSetup(ctx, t, testSess, sd, n)
CheckTestSetup(ctx, t, sd, n, testSess.RlweParams, testSess.SkIdeal, ts.N)
}
})
}
Expand Down Expand Up @@ -336,6 +336,6 @@ func TestSetupRetries(t *testing.T) {
require.Nil(t, err)

for _, n := range all {
CheckTestSetup(ctx, t, testSess, Description{Cpk: true}, n)
CheckTestSetup(ctx, t, Description{Cpk: true}, n, testSess.RlweParams, testSess.SkIdeal, ts.N)
}
}

0 comments on commit 7ddec84

Please sign in to comment.