From 20e7fca174f052da30f155735bc68f3b81181786 Mon Sep 17 00:00:00 2001 From: fuyao-w <815775754@qq.com> Date: Wed, 26 Apr 2023 18:38:59 +0800 Subject: [PATCH] feat:init --- .gitignore | 5 + LICENSE | 21 + api.go | 253 +++++++ cmd.go | 123 ++++ command.go | 393 +++++++++++ config.go | 93 +++ configuration.go | 80 +++ file_snapshot.go | 332 +++++++++ fsm.go | 116 ++++ future.go | 267 +++++++ go.mod | 13 + go.sum | 14 + log.go | 31 + mem_transport.go | 258 +++++++ memory_log.go | 163 +++++ men_fsm.go | 126 ++++ net_protocol.go | 68 ++ net_transport.go | 435 ++++++++++++ raft.go | 1722 ++++++++++++++++++++++++++++++++++++++++++++++ raft_test.go | 242 +++++++ rpc.go | 107 +++ snapshot.go | 67 ++ state.go | 43 ++ store.go | 41 ++ t_test.go | 1 + tcp_transport.go | 67 ++ transport.go | 77 +++ util.go | 198 ++++++ 28 files changed, 5356 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 api.go create mode 100644 cmd.go create mode 100644 command.go create mode 100644 config.go create mode 100644 configuration.go create mode 100644 file_snapshot.go create mode 100644 fsm.go create mode 100644 future.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 log.go create mode 100644 mem_transport.go create mode 100644 memory_log.go create mode 100644 men_fsm.go create mode 100644 net_protocol.go create mode 100644 net_transport.go create mode 100644 raft.go create mode 100644 raft_test.go create mode 100644 rpc.go create mode 100644 snapshot.go create mode 100644 state.go create mode 100644 store.go create mode 100644 t_test.go create mode 100644 tcp_transport.go create mode 100644 transport.go create mode 100644 util.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9163e11 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea +testsnapshot +snapshot +example/testsnapshot +example/log_*.log \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7bd6a2c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Wangfuyao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/api.go b/api.go new file mode 100644 index 0000000..f9f3f5e --- /dev/null +++ b/api.go @@ -0,0 +1,253 @@ +package papillon + +import ( + "errors" + "io" + "time" +) + +var ( + ErrNotExist = errors.New("not exist") + ErrPipelineReplicationNotSupported = errors.New("pipeline replication not supported") + ErrNotFoundLog = customError{"not found log"} + ErrNotLeader = errors.New("not leader") + ErrCantBootstrap = errors.New("bootstrap only works on new clusters") + ErrIllegalConfiguration = errors.New("illegal configuration") + ErrShutDown = errors.New("shut down") + ErrLeadershipTransferInProgress = errors.New("leader ship transfer in progress") + // ErrAbortedByRestore is returned when a leader fails to commit a log + // entry because it's been superseded by a user snapshot restore. + ErrAbortedByRestore = errors.New("snapshot restored while committing log") + ErrEnqueueTimeout = errors.New("timed out enqueuing operation") + ErrTimeout = errors.New("time out") + ErrPipelineShutdown = errors.New("append pipeline closed") + ErrNotVoter = errors.New("not voter") + ErrLeadershipTransferFail = errors.New("not found transfer peer") + ErrLeadershipLost = errors.New("leadership lost") + ErrNothingNewToSnapshot = errors.New("nothing new to snapshot") +) + +type customError struct{ string } + +func (e customError) Error() string { + return e.string +} + +func (e customError) Is(err error) bool { + if err == nil { + return false + } + return e.Error() == err.Error() +} + +func (r *Raft) BootstrapCluster(configuration Configuration) defaultFuture { + future := &bootstrapFuture{ + configuration: configuration, + } + future.init() + select { + case <-r.shutDown.C: + future.fail(ErrShutDown) + case r.commandCh <- &command{typ: commandBootstrap, item: future}: + } + return future +} + +func (r *Raft) LeaderInfo() (ServerID, ServerAddr) { + info := r.leaderInfo.Get() + return info.ID, info.Addr +} + +// Apply 向 raft 提交日志 +func (r *Raft) Apply(data []byte, timeout time.Duration) ApplyFuture { + return r.apiApplyLog(&LogEntry{Data: data, Type: LogCommand}, timeout) +} +func (r *Raft) apiApplyLog(entry *LogEntry, timeout time.Duration) ApplyFuture { + var tm <-chan time.Time + if timeout > 0 { + tm = time.After(timeout) + } + var applyFuture = &LogFuture{ + log: entry, + } + applyFuture.init() + select { + case <-tm: + return &errFuture[nilRespFuture]{errors.New("apply log time out")} + case <-r.shutDown.C: + return &errFuture[nilRespFuture]{ErrShutDown} + case r.apiLogApplyCh <- applyFuture: //batch apply + return applyFuture + case r.commandCh <- &command{ // 正常提交 + typ: commandLogApply, + item: applyFuture, + }: + } + return applyFuture +} + +// VerifyLeader 验证当前节点是否是领导人 +func (r *Raft) VerifyLeader() Future[bool] { + vf := &verifyFuture{} + vf.init() + select { + case <-r.shutDown.C: + vf.fail(ErrShutDown) + return &vf.deferResponse + case r.commandCh <- &command{typ: commandVerifyLeader, item: vf}: + return &vf.deferResponse + } +} + +// GetConfiguration 获取集群配置 +func (r *Raft) GetConfiguration() Configuration { + return r.configuration.Load() +} + +func (r *Raft) requestClusterChange(req configurationChangeRequest, timeout time.Duration) IndexFuture { + var tm <-chan time.Time + if timeout > 0 { + tm = time.After(timeout) + } + var ccf = &configurationChangeFuture{ + req: &req, + } + ccf.init() + select { + case <-tm: + return &errFuture[nilRespFuture]{err: errors.New("apply log time out")} + case <-r.shutDown.C: + return &errFuture[nilRespFuture]{err: ErrShutDown} 
+ case r.commandCh <- &command{typ: commandClusterChange, item: ccf}: + return ccf + } +} + +func (r *Raft) AddServer(peer ServerInfo, prevIndex uint64, timeout time.Duration) IndexFuture { + return r.requestClusterChange(configurationChangeRequest{ + command: addServer, + peer: peer, + pervIndex: prevIndex, + }, timeout) +} +func (r *Raft) RemoveServer(peer ServerInfo, prevIndex uint64, timeout time.Duration) IndexFuture { + return r.requestClusterChange(configurationChangeRequest{ + command: removeServer, + peer: peer, + pervIndex: prevIndex, + }, timeout) +} + +func (r *Raft) UpdateServer(peer ServerInfo, prevIndex uint64, timeout time.Duration) IndexFuture { + return r.requestClusterChange(configurationChangeRequest{ + command: updateServer, + peer: peer, + pervIndex: prevIndex, + }, timeout) +} + +func (r *Raft) SnapShot() Future[OpenSnapShot] { + fu := &apiSnapshotFuture{} + fu.init() + select { + case <-r.shutDown.C: + return &errFuture[OpenSnapShot]{ErrShutDown} + case r.apiSnapshotBuildCh <- fu: + return fu + } +} + +func (r *Raft) StateCh() <-chan State { + return r.stateChangeCh +} + +func (r *Raft) LastContact() time.Time { + return r.lastContact.Load() +} + +func (r *Raft) LatestIndex() uint64 { + return r.getLatestIndex() +} + +func (r *Raft) LastApplied() uint64 { + return r.getLastApplied() +} + +func (r *Raft) LeaderTransfer(id ServerID, address ServerAddr) defaultFuture { + future := &leadershipTransferFuture{ + Peer: ServerInfo{ + ID: id, + Addr: address, + }, + } + if len(id) > 0 && id == r.localInfo.ID { + future.fail(errors.New("can't transfer to itself")) + return future + } + future.init() + select { + case r.commandCh <- &command{typ: commandLeadershipTransfer, item: future}: + return future + case <-r.shutDown.C: + return &errFuture[nilRespFuture]{ErrShutDown} + default: + return &errFuture[nilRespFuture]{ErrEnqueueTimeout} + } +} + +func (r *Raft) ReloadConfig(rc ReloadableConfig) error { + r.confReloadMu.Lock() + defer r.confReloadMu.Unlock() + oldConf := *r.Conf() + newConf := rc.apply(oldConf) + ok, hint := ValidateConfig(&newConf) + if !ok { + return errors.New(hint) + } + r.conf.Store(&newConf) + if newConf.HeartbeatTimeout <= oldConf.HeartbeatTimeout { + return nil + } + select { + case <-r.shutDown.C: + return ErrShutDown + case r.commandCh <- &command{ + typ: commandConfigReload, + item: newConf, + }: + + } + return nil +} + +func (r *Raft) ReStoreSnapshot(meta *SnapShotMeta, reader io.ReadCloser) error { + fu := &userRestoreFuture{ + meta: meta, + reader: reader, + } + fu.init() + select { + case r.commandCh <- &command{typ: commandSnapshotRestore, item: fu}: + case <-r.shutDown.C: + return ErrShutDown + } + _, err := fu.Response() + if err != nil { + return err + } + applyFu := r.apiApplyLog(&LogEntry{Type: LogNoop}, 0) + _, err = applyFu.Response() + return err +} + +func (r *Raft) ShutDown() defaultFuture { + var resp *shutDownFuture + r.shutDown.done(func(oldState bool) { + resp = new(shutDownFuture) + if !oldState { + resp.raft = r + } + r.setShutDown() + }) + return resp +} diff --git a/cmd.go b/cmd.go new file mode 100644 index 0000000..0e2310b --- /dev/null +++ b/cmd.go @@ -0,0 +1,123 @@ +package papillon + +import ( + "io" + "time" +) + +type ( + Processor interface { + Do(rpcType, interface{}, io.Reader) (interface{}, error) + SetFastPath(cb fastPath) + } + // ProcessorProxy 服务器接口 handler 代理,提供将序列化数据,解析成接口 struct 指针的功能 + ProcessorProxy struct { + Processor + } + // ServerProcessor 服务器接口 handler ,提供具体的接口处理逻辑 + ServerProcessor struct { + cmdChan 
chan *RPC + fastPath fastPath + } +) + +func (d *ProcessorProxy) SetFastPath(cb fastPath) { + d.Processor.SetFastPath(cb) +} +func (d *ServerProcessor) SetFastPath(cb fastPath) { + d.fastPath = cb +} + +// Do ServerProcessor 不关心上层协议,所以不用处理第一个参数(rpcType) +func (d *ServerProcessor) Do(typ rpcType, req interface{}, reader io.Reader) (resp interface{}, err error) { + resCh := make(chan any, 1) + cmd := &RPC{ + Request: req, + Response: resCh, + } + switch typ { + case CmdAppendEntry: + request := req.(*AppendEntryRequest) + if len(request.Entries) == 0 && d.fastPath != nil && d.fastPath(cmd) { + return <-resCh, nil + } + case CmdInstallSnapshot: + cmd.Reader = io.LimitReader(reader, req.(*InstallSnapshotRequest).SnapshotMeta.Size) + } + + d.cmdChan <- cmd + return <-resCh, nil +} + +type processorOption struct { + Processor + CmdConvert +} + +func withProcessor(p Processor) func(opt *processorOption) { + return func(opt *processorOption) { + opt.Processor = p + } +} +func withCmdConvert(c CmdConvert) func(opt *processorOption) { + return func(opt *processorOption) { + opt.CmdConvert = c + } +} +func newProcessorProxy(cmdCh chan *RPC, options ...func(opt *processorOption)) Processor { + proxy := &ProcessorProxy{ + Processor: &ServerProcessor{ + cmdChan: cmdCh, + }, + } + var opt processorOption + for _, do := range options { + do(&opt) + } + if opt.Processor != nil { + proxy.Processor = opt.Processor + } + //if opt.CmdConvert != nil { + // proxy.CmdConvert = opt.CmdConvert + //} + return proxy +} + +func (p *ProcessorProxy) Do(cmdType rpcType, reqBytes interface{}, reader io.Reader) (respBytes interface{}, err error) { + date := reqBytes.([]byte) + var req interface{} + + switch cmdType { + case CmdVoteRequest: + req = new(VoteRequest) + case CmdAppendEntry: + req = new(AppendEntryRequest) + case CmdAppendEntryPipeline: + req = new(AppendEntryRequest) + case CmdInstallSnapshot: + req = new(InstallSnapshotRequest) + } + err = defaultCmdConverter.Deserialization(date, req) + if err != nil { + return + } + resp, err := p.Processor.Do(cmdType, req, reader) + if err != nil { + return nil, err + } + return defaultCmdConverter.Serialization(resp) +} + +func doWithTimeout(timeout time.Duration, do func()) bool { + wrapper := func() chan struct{} { + done := make(chan struct{}) + go do() + return done + } + select { + case <-time.After(timeout): + return false + case <-wrapper(): + return true + } +} diff --git a/command.go b/command.go new file mode 100644 index 0000000..77aaf90 --- /dev/null +++ b/command.go @@ -0,0 +1,393 @@ +package papillon + +import ( + "errors" + "fmt" + . 
"github.com/fuyao-w/common-util" + "io" + "sync" + "time" +) + +type ( + commandTyp int + command struct { + typ commandTyp + item interface{} + } + commandMap map[commandTyp]map[State]func(*Raft, interface{}) +) +type testFuture struct { + deferResponse[string] +} + +const ( + commandTest commandTyp = iota + 1 // 用于单测 + commandClusterGet + commandBootstrap + commandLogApply + commandSnapshotRestore + commandClusterChange + commandVerifyLeader + commandConfigReload + commandLeadershipTransfer +) + +var channelCommand commandMap + +func init() { + channelCommand = commandMap{ + commandTest: { + Leader: func(raft *Raft, i interface{}) { + i.(*testFuture).responded("succ", nil) + }, + }, + commandClusterGet: { + Leader: (*Raft).processClusterGet, + }, + commandLogApply: { + Leader: (*Raft).processLogApply, + }, + commandBootstrap: { + Follower: (*Raft).processBootstrap, + }, + commandSnapshotRestore: { + Leader: (*Raft).processSnapshotRestore, + }, + commandLeadershipTransfer: { + Leader: (*Raft).processLeadershipTransfer, + }, + commandConfigReload: { + Follower: (*Raft).processReloadConfig, + Leader: (*Raft).processReloadConfig, + Candidate: (*Raft).processReloadConfig, + }, + commandVerifyLeader: { + Leader: (*Raft).processVerifyLeader, + }, + commandClusterChange: { + Leader: (*Raft).processClusterChange, + }, + } +} + +func (r *Raft) processCommand(cmd *command) { + cc, ok := channelCommand[cmd.typ] + if !ok { + panic(fmt.Sprintf("command type :%d not register", cmd.typ)) + } + state := r.state.Get() + f, ok := cc[state] + if ok { + f(r, cmd.item) + } else { + cmd.item.(reject).reject(state) + } +} + +// processClusterGet 用于内部从主线程获取集群配置 +func (r *Raft) processClusterGet(i interface{}) { + fu := i.(*clusterGetFuture) + if r.leaderState.getLeadershipTransfer() { + fu.fail(ErrLeadershipTransferInProgress) + return + } + fu.responded(r.cluster.Clone(), nil) +} + +// processLogApply 只在 cycleLeader 中调用,将日志提交到本次,并通知跟随者进行复制 +func (r *Raft) processLogApply(item interface{}) { + var ( + fu = item.(*LogFuture) + futures = []*LogFuture{fu} + batch = r.Conf().ApplyBatch + ) + if r.leaderState.getLeadershipTransfer() { + fu.fail(ErrLeadershipTransferInProgress) + return + } +BREAK: + for i := 0; i < batch; i++ { + select { + case applyFu := <-r.apiLogApplyCh: + futures = append(futures, applyFu) + default: + break BREAK + } + } + r.applyLog(futures) +} + +// processBootstrap 引导集群启动,节点必须是干净的(日志、快照、任期都是 0 ) +// 将配置作为第一条日志存储到本地,然后当前节点就可以发起选举,并最终将日志复制到集群副本中 +func (r *Raft) processBootstrap(item interface{}) { + var ( + fu = item.(*bootstrapFuture) + ) + if !validateConfiguration(fu.configuration) { + fu.fail(ErrIllegalConfiguration) + return + } + exist, err := r.hasExistTerm() + if err != nil { + fu.fail(err) + return + } + if exist { + fu.fail(ErrCantBootstrap) + return + } + entry := &LogEntry{ + Index: 1, + Term: 1, + Data: EncodeConfiguration(fu.configuration), + Type: LogConfiguration, + CreatedAt: time.Now(), + } + if err = r.logStore.SetLogs([]*LogEntry{entry}); err != nil { + r.logger.Errorf("processBootstrap|SetLogs err:%s", err) + fu.fail(err) + return + } + r.setCurrentTerm(1) + r.setLatestLog(1, 1) + r.saveConfiguration(entry) + fu.success() +} + +// processSnapshotRestore 从外部接收一个快照并应用,但是不进行日志压缩,外部看起来就像多执行了一些命令一样 +// 当前节点会基于最新的配置和日志索引创建快照,并且会把最新索引 + 1 。这样做的目的是更快的触发快照发送 +// 所有的集群副本最终在状态机中应用了两个命令集合:一个原有的 + 从外部引入的 +// 注意:这个请求执行后的最新快照索引和最新日志索引均是最新的。在下一次完整快照生成之前,如果我们重启节点, +// 将会丢失最新快照索引 SnapShotMeta.Index 之前的所有日志。该命令仅应在崩溃恢复时使用 +// TODO 可以在崩溃恢复后的下一次快照检查强制生成快照,尽量保证日志的安全 +func (r *Raft) 
processSnapshotRestore(item interface{}) { + var ( + fu = item.(*userRestoreFuture) + ) + if r.leaderState.getLeadershipTransfer() { + fu.fail(ErrLeadershipTransferInProgress) + return + } + if r.cluster.commitIndex != r.cluster.latestIndex { + fu.fail(fmt.Errorf("cannot restore snapshot now, wait until the configuration entry at %v has been applied (have applied %v)", + r.cluster.latestIndex, r.cluster.commitIndex)) + return + } + for e := r.leaderState.inflight.Front(); e != nil; e = e.Next() { + e.Value.(*LogFuture).fail(ErrAbortedByRestore) + r.leaderState.inflight.Remove(e) + } + index := r.getLatestIndex() + 1 + sink, err := r.snapshotStore.Create(SnapShotVersionDefault, index, r.getCurrentTerm(), r.cluster.latest, r.cluster.latestIndex, nil) + if err != nil { + fu.fail(err) + return + } + defer fu.reader.Close() + written, err := io.Copy(sink, fu.reader) + if err != nil { + fu.fail(err) + sink.Cancel() + return + } + if written != fu.meta.Size { + fu.fail(fmt.Errorf("failed to write snapshot, size didn't match (%d != %d)", written, fu.meta.Size)) + sink.Cancel() + return + } + if err := sink.Close(); err != nil { + fu.fail(err) + return + } + r.logger.Info("copied to local snapshot", "bytes", written) + + fsmFu := &restoreFuture{ID: sink.ID()} + fsmFu.init() + select { + case r.fsmRestoreCh <- fsmFu: + case <-r.shutDown.C: + fu.fail(ErrShutDown) + return + } + if _, err = fsmFu.Response(); err != nil { + panic(fmt.Errorf("failed to restore snapshot: %v", err)) + } + + r.setLatestLog(r.getCurrentTerm(), index) + r.setLatestSnapshot(r.getCurrentTerm(), index) + r.setLastApplied(index) + r.logger.Info("restored user snapshot", "index", index) + fu.success() +} + +// pickLatestPeer 寻找进度最新的跟随者 +func (r *Raft) pickLatestPeer() *replication { + var ( + latest *replication + latestIndex uint64 + rep = r.leaderState.replicate + ) + for _, info := range r.getLatestConfiguration() { + if !info.isVoter() { + continue + } + fr, ok := rep[info.ID] + if !ok { + continue + } + if fr.getNextIndex() > latestIndex { + latestIndex = fr.getNextIndex() + latest = fr + } + } + + return latest +} + +// leadershipTransfer 进行领导权转移,必须等到跟随者的进度赶上当前节点才可以发起请求 +func (r *Raft) leadershipTransfer(fr *replication) error { + var ( + rounds = r.Conf().LeadershipCatchUpRounds + i uint + ) + + for ; i < rounds && r.getLatestIndex() > fr.getNextIndex(); i++ { + fu := new(defaultDeferResponse) + fu.init() + select { + case <-r.shutDown.C: + return ErrShutDown + case <-fr.stop: + return nil + case fr.trigger <- fu: + if _, err := fu.Response(); err != nil { + return err + } + } + } + if i >= rounds { + return errors.New("reach the maximum number of catch-up rounds") + } + resp, err := r.rpc.FastTimeout(Ptr(fr.peer.Get()), &FastTimeoutRequest{ + RPCHeader: r.buildRPCHeader(), + Term: r.getCurrentTerm(), + LeaderShipTransfer: true, + }) + if err != nil { + return err + } + if !resp.Success { + return errors.New("peer reject time out") + } + return nil +} + +// processLeadershipTransfer 处理领导权转移,只能由领导人执行 +// 首先挑选出最新的跟随者,然后停止接收日志提交请求并等待跟随者赶上当前领导人的进度, +// 最后调用 RpcInterface.FastTimeout 通知其快速超时发起选举 +// 追赶进度时将等待固定的轮次,超过次数则返回失败 +func (r *Raft) processLeadershipTransfer(item interface{}) { + var ( + fu = item.(*leadershipTransferFuture) + fr *replication + ) + if id := fu.Peer.ID; len(id) > 0 { + fr = r.leaderState.replicate[id] + } else { + fr = r.pickLatestPeer() + } + if fr == nil { + fu.fail(errors.New("no suitable peer")) + return + } + if !r.leaderState.setupLeadershipTransfer(true) { + fu.fail(ErrLeadershipTransferInProgress) + 
return + } + go func() { + fu.responded(nil, r.leadershipTransfer(fr)) + r.leaderState.setupLeadershipTransfer(false) + }() +} + +// processReloadConfig 验证当前节点是否还是领导人 +func (r *Raft) processReloadConfig(item interface{}) { + var ( + oldConf = item.(*Config) + newConf = r.Conf() + ) + switch r.state.Get() { + case Follower: + if oldConf.HeartbeatTimeout != newConf.HeartbeatTimeout { + r.heartbeatTimeout = time.After(0) + } + case Leader: + if oldConf.HeartbeatTimeout != newConf.HeartbeatTimeout { + for _, replication := range r.leaderState.replicate { + asyncNotify(replication.notifyCh) + } + } + + case Candidate: + if oldConf.ElectionTimeout != newConf.ElectionTimeout { + r.electionTimeout = randomTimeout(newConf.ElectionTimeout) + } + default: + panic(fmt.Errorf("except state :%d ", r.state.Get())) + } +} + +// processVerifyLeader 验证当前节点是否还是领导人 +func (r *Raft) processVerifyLeader(item interface{}) { + var ( + fu = item.(*verifyFuture) + ) + // 先计算自己一票 + fu.quorumCount = r.quorumSize() + fu.voteGranted = 0 + fu.reportOnce = new(sync.Once) + fu.stepDown = r.leaderState.stepDown + fu.vote(true) + for _, repl := range r.leaderState.replicate { + repl.observe(fu) + asyncNotify(repl.notifyCh) + } +} + +// processClusterChange 集群配置更新,只能等到上次更新已提交后才可以开始新的变更 +func (r *Raft) processClusterChange(item interface{}) { + var ( + fu = item.(*configurationChangeFuture) + ) + if r.cluster.commitIndex != r.cluster.latestIndex { + fu.fail(errors.New("no stable configuration")) + return + } + if r.leaderState.getLeadershipTransfer() { + fu.fail(ErrLeadershipTransferInProgress) + return + } + if r.cluster.latestIndex != fu.req.pervIndex { + fu.fail(errors.New("configuration index not match")) + return + } + newConfiguration, err := r.clacNewConfiguration(fu.req) + if err != nil { + fu.fail(err) + return + } + logFu := &LogFuture{ + log: &LogEntry{ + Data: EncodeConfiguration(newConfiguration), + Type: LogCommand, + }, + } + logFu.init() + r.applyLog([]*LogFuture{logFu}) + + r.cluster.setLatest(logFu.Index(), newConfiguration) + r.onConfigurationUpdate() + r.reloadReplication() + +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..72356c5 --- /dev/null +++ b/config.go @@ -0,0 +1,93 @@ +package papillon + +import ( + "fmt" + "time" +) + +const ( + minCheckInterval = 10 * time.Millisecond +) + +type Config struct { + ElectionTimeout time.Duration + HeartbeatTimeout time.Duration + LeaderLeaseTimeout time.Duration + ApplyBatch int + MaxAppendEntries int + CommitTimeout time.Duration + SnapshotInterval time.Duration + SnapshotThreshold uint64 + TrailingLogs uint64 + Logger Logger + LocalID string + LeadershipCatchUpRounds uint +} +type ReloadableConfig struct { + TrailingLogs uint64 + SnapshotInterval time.Duration + SnapshotThreshold uint64 + HeartbeatTimeout time.Duration + ElectionTimeout time.Duration +} + +func DefaultConfig() *Config { + return &Config{ + HeartbeatTimeout: 1000 * time.Millisecond, + ElectionTimeout: 1000 * time.Millisecond, + CommitTimeout: 50 * time.Millisecond, + MaxAppendEntries: 64, + TrailingLogs: 10240, + SnapshotInterval: 120 * time.Second, + SnapshotThreshold: 8192, + LeaderLeaseTimeout: 500 * time.Millisecond, + } +} +func ValidateConfig(c *Config) (bool, string) { + if len(c.LocalID) == 0 { + return false, "LocalID is blank" + } + if c.TrailingLogs < 1 { + return false, "TrailingLogs must greater than 1" + } + if c.SnapshotThreshold < 0 { + return false, "SnapshotThreshold must greater than 0" + } + if c.MaxAppendEntries < 1 { + return false, "MaxAppendEntries 
must greater than 1" + } + maximumAppendEntries := 1024 + if c.MaxAppendEntries > maximumAppendEntries { + return false, fmt.Sprintf("MaxAppendEntries must less than or equal to %d", maximumAppendEntries) + } + if c.ApplyBatch < 1 { + return false, "ApplyBatch must greater than 1" + } + minimumTimeout := 5 * time.Millisecond + if c.HeartbeatTimeout < minimumTimeout { + return false, fmt.Sprintf("HeartbeatTimeout must greater than :%s", minimumTimeout) + } + if c.ElectionTimeout < minimumTimeout { + return false, fmt.Sprintf("ElectionTimeout must greater than :%s", minimumTimeout) + } + + if c.SnapshotInterval < minimumTimeout { + return false, fmt.Sprintf("SnapshotInterval must greater than :%s", minimumTimeout) + } + if c.LeaderLeaseTimeout < minCheckInterval { + return false, fmt.Sprintf("LeaderLeaseTimeout must greater than :%s", minCheckInterval) + } + minimumCommitTimeout := time.Millisecond + if c.CommitTimeout < minimumCommitTimeout { + return false, fmt.Sprintf("CommitTimeout must greater than :%s", minimumCommitTimeout) + } + // 处理投票时实现一种租期机制,如果能正常接受到了心跳,则拒绝投票请求,如果选举超时果断则会多发送无效请求 + if c.ElectionTimeout < c.HeartbeatTimeout { + return false, fmt.Sprintf("ElectionTimeout must greater than or equal HeartbeatTimeout") + } + // 确保领导者至少可以先下台 + if c.LeaderLeaseTimeout > c.HeartbeatTimeout { + return false, fmt.Sprintf("LeaderLeaseTimeout must greater than or equal HeartbeatTimeout") + } + return true, "" +} diff --git a/configuration.go b/configuration.go new file mode 100644 index 0000000..bb3e0a2 --- /dev/null +++ b/configuration.go @@ -0,0 +1,80 @@ +package papillon + +import ( + "encoding/json" + "fmt" +) + +type ( + Configuration struct { + Servers []ServerInfo + } + cluster struct { + latest Configuration + latestIndex uint64 + commit Configuration + commitIndex uint64 + } +) + +func (c *Configuration) Clone() (copy Configuration) { + copy.Servers = append([]ServerInfo(nil), c.Servers...) + return +} + +func (c *cluster) Clone() cluster { + return cluster{ + commit: c.commit.Clone(), + latest: c.latest.Clone(), + commitIndex: c.commitIndex, + latestIndex: c.latestIndex, + } +} + +func DecodeConfiguration(data []byte) (c Configuration) { + if err := json.Unmarshal(data, &c); err != nil { + panic(fmt.Errorf("failed to decode Configuration: %s ,%s", err, data)) + } + return +} +func EncodeConfiguration(c Configuration) (data []byte) { + data, err := json.Marshal(c) + if err != nil { + panic(fmt.Errorf("failed to encode Configuration :%s", err)) + } + return +} +func (c *cluster) setCommit(index uint64, configuration Configuration) { + c.commitIndex = index + c.commit = configuration +} +func (c *cluster) setLatest(index uint64, configuration Configuration) { + c.latestIndex = index + c.latest = configuration +} + +// validateConfiguration 校验配置是否合法 1. 可选举节点数大于 0 2. 
节点不能重复 +func validateConfiguration(configuration Configuration) bool { + var ( + voter int + set = map[ServerID]bool{} + ) + for _, server := range configuration.Servers { + if set[server.ID] { + return false + } + set[server.ID] = true + if server.isVoter() { + voter++ + } + } + return voter > 0 +} +func (rc *ReloadableConfig) apply(to Config) Config { + to.TrailingLogs = rc.TrailingLogs + to.SnapshotInterval = rc.SnapshotInterval + to.SnapshotThreshold = rc.SnapshotThreshold + to.HeartbeatTimeout = rc.HeartbeatTimeout + to.ElectionTimeout = rc.ElectionTimeout + return to +} diff --git a/file_snapshot.go b/file_snapshot.go new file mode 100644 index 0000000..3c43366 --- /dev/null +++ b/file_snapshot.go @@ -0,0 +1,332 @@ +package papillon + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "hash" + "hash/crc64" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "sort" + "strings" +) + +const ( + dirMode = 0755 + testFile = "snapShotTest" + tmpFileSuffix = ".tmp" + metaFile = "meta.json" + snapshotFile = "snapshot.bin" +) + +type ( + FileWithSync interface { + fs.File + Sync() error + } + FileSnapshot struct { + dir string + noSync bool + retainCount int + } + + FileSnapshotSink struct { + snapshotStore *FileSnapshot + close bool + meta *fileSnapshotMeta + dir string + parentPath string + noSync bool + hash hash.Hash64 + file FileWithSync + writer *bufio.Writer + } + fileSnapshotMeta struct { + *SnapShotMeta + CRC []byte + } + bufferReader struct { + buf *bufio.Reader + file FileWithSync + } +) + +func newCRC64() hash.Hash64 { + return crc64.New(crc64.MakeTable(crc64.ECMA)) +} +func sortMetaList(list []*fileSnapshotMeta) { + sort.Slice(list, func(i, j int) bool { + if list[i].Term != list[j].Term { + return list[i].Term > list[j].Term + } + if list[i].Index != list[j].Index { + return list[i].Index > list[j].Index + } + return list[i].ID > list[j].ID + }) +} + +func (b *bufferReader) Read(p []byte) (n int, err error) { + return b.buf.Read(p) +} + +func (b *bufferReader) Close() error { + return b.file.Close() +} + +func NewFileSnapshot(dirPath string, noSync bool, retainCount int) (*FileSnapshot, error) { + if err := os.MkdirAll(dirPath, dirMode); err != nil && !os.IsExist(err) { + return nil, err + } + + snapshot := &FileSnapshot{ + dir: dirPath, + noSync: noSync, + retainCount: retainCount, + } + if err := snapshot.testCreatePermission(); err != nil { + return nil, fmt.Errorf("test create permissions failed :%s", err) + } + return snapshot, nil +} + +func (f *FileSnapshot) testCreatePermission() error { + filePath := filepath.Join(f.dir, testFile) + file, err := os.Create(filePath) + if err != nil { + return err + } + _ = file.Close() + return os.Remove(filePath) +} + +func (f *FileSnapshot) readMeta(id string) (meta *fileSnapshotMeta, err error) { + metaFilePath := filepath.Join(f.dir, id, metaFile) + file, err := os.Open(metaFilePath) + if err != nil { + return nil, err + } + err = json.NewDecoder(file).Decode(&meta) + return +} + +func (f *FileSnapshot) Open(id string) (*SnapShotMeta, io.ReadCloser, error) { + meta, err := f.readMeta(id) + if err != nil { + return nil, nil, err + } + + snapshotFilePath := filepath.Join(f.dir, id, snapshotFile) + file, err := os.Open(snapshotFilePath) + defer func() { + if err != nil { + file.Close() + } + }() + if err != nil { + return nil, nil, err + } + hash64 := newCRC64() + if _, err = io.Copy(hash64, file); err != nil { + return nil, nil, err + } + if !bytes.Equal(hash64.Sum(nil), meta.CRC) { + return nil, nil, 
errors.New("CRC mismatch") + } + if _, err = file.Seek(0, 0); err != nil { + return nil, nil, err + } + return meta.SnapShotMeta, &bufferReader{ + buf: bufio.NewReader(file), + file: file, + }, nil +} + +func (f *FileSnapshot) List() (list []*SnapShotMeta, err error) { + metaList, err := f.getSnapshots() + if err != nil { + return nil, err + } + for _, meta := range metaList { + list = append(list, meta.SnapShotMeta) + if len(list) == f.retainCount { + break + } + } + return +} + +func (f *FileSnapshot) getSnapshots() (metaList []*fileSnapshotMeta, err error) { + dirList, err := os.ReadDir(f.dir) + if err != nil { + return nil, err + } + for _, entry := range dirList { + if !entry.IsDir() { + continue + } + if strings.HasSuffix(entry.Name(), tmpFileSuffix) { + continue + } + meta, err := f.readMeta(entry.Name()) + if err != nil { + return nil, err + } + metaList = append(metaList, meta) + } + sortMetaList(metaList) + return +} +func (s *FileSnapshotSink) writeMeta() error { + fileName := filepath.Join(s.dir, metaFile) + file, err := os.Create(fileName) + if err != nil { + return err + } + defer file.Close() + writer := bufio.NewWriter(file) + enc := json.NewEncoder(file) + enc.SetIndent("", " ") + if err = enc.Encode(s.meta); err != nil { + return err + } + if err = writer.Flush(); err != nil { + return err + } + if !s.noSync { + if err = file.Sync(); err != nil { + return err + } + } + return nil +} + +func (f *FileSnapshot) Create(version SnapShotVersion, index, term uint64, configuration Configuration, configurationIndex uint64, rpc RpcInterface) (SnapshotSink, error) { + name := snapshotName(term, index) + snapshotDir := filepath.Join(f.dir, name+tmpFileSuffix) + if err := os.MkdirAll(snapshotDir, dirMode); err != nil { + return nil, err + } + sink := &FileSnapshotSink{ + snapshotStore: f, + dir: snapshotDir, + noSync: f.noSync, + parentPath: f.dir, + meta: &fileSnapshotMeta{ + SnapShotMeta: &SnapShotMeta{ + Version: version, + ID: name, + Index: index, + Term: term, + Configuration: configuration, + ConfigurationIndex: configurationIndex, + }, + }, + hash: newCRC64(), + } + if err := sink.writeMeta(); err != nil { + return nil, err + } + snapshotPath := filepath.Join(snapshotDir, snapshotFile) + if file, err := os.Create(snapshotPath); err != nil { + return nil, err + } else { + sink.file = file + sink.writer = bufio.NewWriter(io.MultiWriter(sink.hash, file)) + } + return sink, nil +} + +func (f *FileSnapshotSink) Write(p []byte) (n int, err error) { + return f.writer.Write(p) +} + +func (f *FileSnapshotSink) Close() error { + if f.close { + return nil + } + f.close = true + if err := f.finalize(); err != nil { + return err + } + if err := f.writeMeta(); err != nil { + return err + } + newPath := strings.TrimSuffix(f.dir, tmpFileSuffix) + + if err := os.Rename(f.dir, newPath); err != nil { + return err + } + if err := func() error { + if !(!f.noSync && runtime.GOOS != "windows") { + return nil + } + // 对于目录也需要执行 fsync 操作 https://man7.org/linux/man-pages/man2/fsync.2.html + file, err := os.Open(f.parentPath) + if err != nil { + return err + } + defer file.Close() + return file.Sync() + }(); err != nil { + return err + } + + return f.snapshotStore.reapSnapshot() + +} + +func (f *FileSnapshotSink) ID() string { + return f.meta.ID +} + +func (f *FileSnapshotSink) Cancel() error { + if err := f.finalize(); err != nil { + return err + } + return os.RemoveAll(f.dir) +} + +func (f *FileSnapshotSink) finalize() error { + if err := f.writer.Flush(); err != nil { + return err + } + if !f.noSync 
{ + if err := f.file.Sync(); err != nil { + return err + } + } + fileInfo, err := f.file.Stat() + if err != nil { + return err + } + if err = f.file.Close(); err != nil { + return err + } + f.meta.Size = fileInfo.Size() + f.meta.CRC = f.hash.Sum(nil) + return nil +} + +func (f *FileSnapshot) reapSnapshot() error { + metaList, err := f.getSnapshots() + if err != nil { + return err + } + if len(metaList) < f.retainCount { + return nil + } + for _, meta := range metaList[f.retainCount:] { + filePath := filepath.Join(f.dir, meta.ID) + if err := os.RemoveAll(filePath); err != nil { + return err + } + } + return nil +} diff --git a/fsm.go b/fsm.go new file mode 100644 index 0000000..90b9723 --- /dev/null +++ b/fsm.go @@ -0,0 +1,116 @@ +package papillon + +import "io" + +type FSM interface { + Apply(*LogEntry) interface{} + ReStore(reader io.ReadCloser) error + Snapshot() (FsmSnapshot, error) +} + +type FsmSnapshot interface { + Persist(sink SnapshotSink) error + Release() +} + +type BatchFSM interface { + FSM + BatchApply([]*LogEntry) []interface{} +} + +// runFSM 状态机线程 +func (r *Raft) runFSM() { + batchFSM, canBatchApply := r.fsm.(BatchFSM) + configurationStore, canConfigurationStore := r.kvStore.(ConfigurationStorage) + var ( + lastAppliedIdx, lastAppliedTerm uint64 //用于创建快照 + ) + canApply := func(future *LogFuture) bool { + switch future.log.Type { + case LogCommand: + return true + } + return false + } + processConfiguration := func(fu *LogFuture) { + if fu.log.Type != LogConfiguration || !canConfigurationStore { + return + } + configurationStore.SetConfiguration(fu.log.Index, DecodeConfiguration(fu.log.Data)) + + } + applyBatch := func(futures []*LogFuture) { + var ( + logs []*LogEntry + respFutures []*LogFuture + ) + + for _, fu := range futures { + processConfiguration(fu) + if canApply(fu) { + logs = append(logs, fu.log) + respFutures = append(respFutures, fu) + } else { + fu.success() + } + } + if len(logs) == 0 { + return + } + resp := batchFSM.BatchApply(logs) + for i, fu := range respFutures { + fu.responded(resp[i], nil) + } + } + applySingle := func(fu *LogFuture) error { + processConfiguration(fu) + if !canApply(fu) { + fu.responded(nil, nil) + return nil + } + fu.responded(r.fsm.Apply(fu.log), nil) + return nil + } + snapshot := func(fu *fsmSnapshotFuture) { + if lastAppliedIdx == 0 { + fu.fail(ErrNothingNewToSnapshot) + return + } + snapshot, err := r.fsm.Snapshot() + if err != nil { + r.logger.Errorf("") + } + fu.responded(&SnapShotFutureResp{ + term: lastAppliedTerm, + index: lastAppliedIdx, + fsmSnapshot: snapshot, + }, err) + } + for { + select { + case <-r.shutDown.C: + return + case futures := <-r.fsmApplyCh: + if canBatchApply { + applyBatch(futures) + if len(futures) > 0 { + future := futures[len(futures)-1] + lastAppliedIdx, lastAppliedTerm = future.log.Index, future.log.Term + } + } else { + for _, future := range futures { + applySingle(future) + lastAppliedIdx, lastAppliedTerm = future.log.Index, future.log.Term + } + } + case fu := <-r.fsmRestoreCh: + meta, err := r.recoverSnapshotByID(fu.ID) + lastAppliedIdx = meta.Index + lastAppliedTerm = meta.Term + fu.responded(nil, err) + + case fu := <-r.fsmSnapshotCh: + snapshot(fu) + } + } +} diff --git a/future.go b/future.go new file mode 100644 index 0000000..8a6f8c4 --- /dev/null +++ b/future.go @@ -0,0 +1,267 @@ +package papillon + +import ( + "errors" + "fmt" + common_util "github.com/fuyao-w/common-util" + "io" + "sync" + "time" +) + +var ( + FutureErrTimeout = errors.New("time out") + FutureErrNotLeader = 
errors.New("not leader") +) + +// OpenSnapShot 用于 API 请求执行完快照后再需要的时候延迟打开快照 +type OpenSnapShot = func() (*SnapShotMeta, io.ReadCloser, error) + +// nilRespFuture Future 默认不需要返回值的类型 +type nilRespFuture = interface{} + +// Future 用于异步提交,Response 会同步返回,可以重复调用 +type Future[T any] interface { + Response() (T, error) +} + +// defaultFuture 默认不需要返回值的 Future +type defaultFuture = Future[nilRespFuture] + +type defaultDeferResponse = deferResponse[nilRespFuture] + +type reject interface { + reject(state State) +} + +type deferResponse[T any] struct { + err error + once *sync.Once + errCh chan error + response T + ShutdownCh <-chan struct{} +} + +func (d *deferResponse[_]) reject(state State) { + if state == ShutDown { + d.fail(ErrShutDown) + return + } + d.fail(fmt.Errorf("current state %s can't process", state.String())) +} + +func (d *deferResponse[_]) init() { + d.errCh = make(chan error, 1) + d.once = new(sync.Once) +} +func (d *deferResponse[_]) setTimeout() { + d.errCh = make(chan error, 1) + d.once = new(sync.Once) +} + +func (d *deferResponse[T]) Response() (T, error) { + d.once.Do(func() { + select { + case d.err = <-d.errCh: + case <-d.ShutdownCh: + d.err = ErrShutDown + } + }) + return d.response, d.err +} + +type LogFuture struct { + deferResponse[any] + log *LogEntry +} + +func (l *LogFuture) Index() uint64 { + return l.log.Index +} + +// responded 返回响应结果,在调用该方法后 Response 就会返回,该方法不支持重复调用 +func (d *deferResponse[T]) responded(resp T, err error) { + d.response = resp + select { + case d.errCh <- err: + default: + panic("defer response not init") + } + close(d.errCh) +} + +func (d *deferResponse[T]) success() { + d.responded(common_util.Zero[T](), nil) +} +func (d *deferResponse[T]) fail(err error) { + d.responded(common_util.Zero[T](), err) +} + +type AppendEntriesFuture interface { + Future[*AppendEntryResponse] + StartAt() time.Time + Request() *AppendEntryRequest +} + +type appendEntriesFuture struct { + deferResponse[*AppendEntryResponse] + startAt time.Time + req *AppendEntryRequest +} + +func newAppendEntriesFuture(req *AppendEntryRequest) *appendEntriesFuture { + af := &appendEntriesFuture{ + startAt: time.Now(), + req: req, + } + af.init() + return af +} +func (a *appendEntriesFuture) StartAt() time.Time { + return a.startAt +} + +func (a *appendEntriesFuture) Request() *AppendEntryRequest { + return a.req +} + +type configurationChangeFuture struct { + LogFuture + req *configurationChangeRequest +} + +type configurationChangeCommend uint64 + +const ( + addServer configurationChangeCommend = iota + 1 + removeServer + updateServer +) + +type configurationChangeRequest struct { + command configurationChangeCommend + peer ServerInfo + pervIndex uint64 +} + +type ( + verifyFuture struct { + deferResponse[bool] + sync.Mutex + quorumCount uint + voteGranted uint + reportOnce *sync.Once + stepDown chan struct{} + } +) + +func (v *verifyFuture) report(leadership bool) { + v.reportOnce.Do(func() { + v.responded(leadership, nil) + if !leadership { + asyncNotify(v.stepDown) + } + }) +} +func (v *verifyFuture) vote(leadership bool) { + v.Lock() + defer v.Unlock() + if leadership { + v.voteGranted++ + if v.voteGranted >= v.quorumCount { + v.report(true) + } + } else { + v.report(false) + } +} + +type userRestoreFuture struct { + defaultDeferResponse + meta *SnapShotMeta + reader io.ReadCloser +} + +type leadershipTransferFuture struct { + defaultDeferResponse + Peer ServerInfo +} + +type clusterGetFuture struct { + deferResponse[cluster] +} + +type apiClusterGetFuture struct { + 
deferResponse[Configuration] +} + +// bootstrapFuture is used to attempt a live bootstrap of the cluster. See the +// Raft object's BootstrapCluster member function for more details. +type bootstrapFuture struct { + defaultDeferResponse + + // configuration is the proposed bootstrap configuration to apply. + configuration Configuration +} + +type ( + fsmSnapshotFuture struct { + deferResponse[*SnapShotFutureResp] + } + SnapShotFutureResp struct { + term, index uint64 + fsmSnapshot FsmSnapshot + } +) + +// apiSnapshotFuture is used for waiting on a user-triggered snapshot to +// complete. +type apiSnapshotFuture struct { + deferResponse[OpenSnapShot] +} + +// restoreFuture is used for requesting an FSM to perform a +// snapshot restore. Used internally only. +type restoreFuture struct { + defaultDeferResponse + ID string +} + +type shutDownFuture struct { + raft *Raft +} + +func (s *shutDownFuture) Response() (nilRespFuture, error) { + if s.raft == nil { + return nil, nil + } + s.raft.waitShutDown() + + if inter, ok := s.raft.rpc.(interface { + Close() error + }); ok { + inter.Close() + } + return nil, nil +} + +type ApplyFuture interface { + IndexFuture + Future[nilRespFuture] +} +type IndexFuture interface { + Index() uint64 + defaultFuture +} + +type errFuture[T any] struct { + err error +} + +func (e *errFuture[T]) Index() uint64 { + return 0 +} + +func (e *errFuture[T]) Response() (t T, _ error) { + return t, e.err +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..89e9696 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/fuyao-w/papillon + +go 1.20 + +require ( + github.com/fuyao-w/common-util v0.0.0-20230325061500-2195cc554530 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect + github.com/sasha-s/go-deadlock v0.3.1 // indirect + github.com/smartystreets/assertions v1.13.1 // indirect + github.com/smartystreets/goconvey v1.8.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..11b9f1c --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/fuyao-w/common-util v0.0.0-20230325061500-2195cc554530 h1:xW/fNP/gyfKUshlorjhEF7ccgHwQGwb1q36mG9iZqK4= +github.com/fuyao-w/common-util v0.0.0-20230325061500-2195cc554530/go.mod h1:kV74J/K4zGi+Utc156GHNI2T6oJJKIi+aSVc0CFkFb8= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= +github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0= +github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM= +github.com/smartystreets/assertions v1.13.1 h1:Ef7KhSmjZcK6AVf9YbJdvPYG9avaF0ZxudX+ThRdWfU= +github.com/smartystreets/assertions v1.13.1/go.mod h1:cXr/IwVfSo/RbCSPhoAPv73p3hlSdrBH/b3SdnW/LMY= +github.com/smartystreets/goconvey v1.8.0 h1:Oi49ha/2MURE0WexF052Z0m+BNSGirfjg5RL+JXWq3w= +github.com/smartystreets/goconvey v1.8.0/go.mod 
h1:EdX8jtrTIj26jmjCOVNMVSIYAtgexqXKHOXW2Dx9JLg= diff --git a/log.go b/log.go new file mode 100644 index 0000000..1f2f503 --- /dev/null +++ b/log.go @@ -0,0 +1,31 @@ +package papillon + +import ( + "time" +) + +var ( + ErrKeyNotFound = customError{"not found"} + ErrKeyIsNil = customError{"key is nil"} + ErrValueIsNil = customError{"value is nil"} + ErrRange = customError{"from must no bigger than to"} +) + +type ( + LogType uint8 + LogEntry struct { + Index uint64 + Term uint64 + Data []byte + Type LogType + CreatedAt time.Time + } +) + +const ( + LogCommand LogType = iota + 1 + LogBarrier + // LogNoop 只用于确认 leader + LogNoop + LogConfiguration +) diff --git a/mem_transport.go b/mem_transport.go new file mode 100644 index 0000000..8dfad56 --- /dev/null +++ b/mem_transport.go @@ -0,0 +1,258 @@ +package papillon + +import ( + "container/list" + "errors" + "io" + "sync" + "time" +) + +type memRPC struct { + sync.Mutex + consumerCh chan *RPC + localAddr ServerAddr + peerMap map[ServerAddr]RpcInterface + pipeline list.List + timeout time.Duration + shutDown shutDown + fastPath fastPath +} + +func newMemRpc(localAddr string) *memRPC { + return &memRPC{ + localAddr: ServerAddr(localAddr), + consumerCh: make(chan *RPC), + peerMap: map[ServerAddr]RpcInterface{}, + timeout: time.Second, + shutDown: newShutDown(), + } +} + +func (m *memRPC) Connect(addr ServerAddr, rpc RpcInterface) { + m.Lock() + defer m.Unlock() + if _, ok := m.peerMap[addr]; ok { + return + } + m.peerMap[addr] = rpc +} + +func (m *memRPC) Disconnect(addr ServerAddr) { + m.Lock() + defer m.Unlock() + delete(m.peerMap, addr) + for e := m.pipeline.Front(); e != nil; e.Next() { + if p := e.Value.(*menAppendEntryPipeline); p.peer.localAddr == addr { + m.pipeline.Remove(e) + } + } +} + +func (m *memRPC) DisconnectAll() { + m.Lock() + defer m.Unlock() + m.peerMap = map[ServerAddr]RpcInterface{} + for e := m.pipeline.Front(); e != nil; e = e.Next() { + e.Value.(*menAppendEntryPipeline).Close() + } + m.pipeline.Init() +} + +type menAppendEntryPipeline struct { + peer, rpc *memRPC + processedCh chan AppendEntriesFuture + inProgressCh chan *memAppendEntriesInflight + shutDownCh chan struct{} + shutDownOnce sync.Once +} + +type memAppendEntriesInflight struct { + af *appendEntriesFuture + cmd *RPC +} + +func newMenAppendEntryPipeline(peer, rpc *memRPC) *menAppendEntryPipeline { + return &menAppendEntryPipeline{ + peer: peer, + rpc: rpc, + shutDownCh: make(chan struct{}), + inProgressCh: make(chan *memAppendEntriesInflight), + processedCh: make(chan AppendEntriesFuture), + } +} + +func (pipe *menAppendEntryPipeline) decodeResponse() { + timeout := pipe.rpc.timeout + for { + select { + case <-pipe.shutDownCh: + return + case inflight := <-pipe.inProgressCh: + var timeoutCh <-chan time.Time + if timeout > 0 { + timeoutCh = time.After(timeout) + } + select { + case rpcResp := <-inflight.cmd.Response: + resp := rpcResp.(*AppendEntryResponse) + inflight.af.responded(resp, nil) + select { + case pipe.processedCh <- inflight.af: + case <-pipe.shutDownCh: + return + } + case <-timeoutCh: + inflight.af.responded(nil, ErrTimeout) + select { + case pipe.processedCh <- inflight.af: + case <-pipe.shutDownCh: + return + } + case <-pipe.shutDownCh: + return + } + } + } +} +func (pipe *menAppendEntryPipeline) AppendEntries(request *AppendEntryRequest) (AppendEntriesFuture, error) { + var ( + af = newAppendEntriesFuture(request) + timeout <-chan time.Time + ) + if t := pipe.rpc.timeout; t > 0 { + timeout = time.After(t) + } + + cmd := RPC{ + CmdType: 
CmdAppendEntry, + Request: request, + Response: make(chan interface{}, 1), + } + + select { + case pipe.peer.consumerCh <- &cmd: + case <-timeout: + return nil, ErrTimeout + case <-pipe.shutDownCh: + return nil, ErrShutDown + } + select { + case pipe.inProgressCh <- &memAppendEntriesInflight{af: af, cmd: &cmd}: + case <-pipe.shutDownCh: + return nil, ErrPipelineShutdown + } + return af, nil +} + +func (pipe *menAppendEntryPipeline) Consumer() <-chan AppendEntriesFuture { + return pipe.processedCh +} + +func (pipe *menAppendEntryPipeline) Close() error { + pipe.shutDownOnce.Do(func() { + close(pipe.shutDownCh) + }) + return nil +} + +func (m *memRPC) getPeer(addr ServerAddr) *memRPC { + m.Lock() + defer m.Unlock() + return m.peerMap[addr].(*memRPC) +} + +func (m *memRPC) Consumer() <-chan *RPC { + return m.consumerCh +} +func (m *memRPC) doRpc(cmdType rpcType, peer *memRPC, request interface{}, reader io.Reader) (interface{}, error) { + timeout := m.timeout + cmd := &RPC{ + CmdType: cmdType, + Request: request, + Reader: reader, + Response: make(chan interface{}), + } + now := time.Now() + select { + case peer.consumerCh <- cmd: + timeout = time.Now().Sub(now) + case <-time.After(timeout): + } + + select { + case resp := <-cmd.Response: + return resp, nil + case <-time.After(m.timeout): + return nil, errors.New("time out") + } +} + +func (m *memRPC) VoteRequest(info *ServerInfo, request *VoteRequest) (*VoteResponse, error) { + resp, err := m.doRpc(CmdVoteRequest, m.getPeer(info.Addr), request, nil) + if err != nil { + return nil, err + } + return resp.(*VoteResponse), nil +} + +func (m *memRPC) AppendEntries(info *ServerInfo, request *AppendEntryRequest) (*AppendEntryResponse, error) { + resp, err := m.doRpc(CmdAppendEntry, m.getPeer(info.Addr), request, nil) + if err != nil { + return nil, err + } + return resp.(*AppendEntryResponse), nil +} + +func (m *memRPC) AppendEntryPipeline(info *ServerInfo) (AppendEntryPipeline, error) { + peer := m.getPeer(info.Addr) + m.Lock() + defer m.Unlock() + + pipe := newMenAppendEntryPipeline(peer, m) + m.pipeline.PushBack(pipe) + go pipe.decodeResponse() + return pipe, nil +} + +func (m *memRPC) InstallSnapShot(info *ServerInfo, request *InstallSnapshotRequest, reader io.Reader) (*InstallSnapshotResponse, error) { + peer := m.getPeer(info.Addr) + resp, err := m.doRpc(CmdInstallSnapshot, peer, request, reader) + if err != nil { + return nil, err + } + return resp.(*InstallSnapshotResponse), nil +} + +func (m *memRPC) SetHeartbeatFastPath(cb fastPath) { + m.fastPath = cb +} + +func (m *memRPC) FastTimeout(info *ServerInfo, request *FastTimeoutRequest) (*FastTimeoutResponse, error) { + resp, err := m.doRpc(CmdFastTimeout, m.getPeer(info.Addr), request, nil) + if err != nil { + return nil, err + } + return resp.(*FastTimeoutResponse), nil +} + +func (m *memRPC) LocalAddr() ServerAddr { + return m.localAddr +} + +func (m *memRPC) EncodeAddr(info *ServerInfo) []byte { + return []byte(info.Addr) +} + +func (m *memRPC) DecodeAddr(bytes []byte) ServerAddr { + return ServerAddr(bytes) +} + +func batchConn(rpc ...*memRPC) { + for _, outer := range rpc { + for _, inner := range rpc { + outer.Connect(inner.LocalAddr(), inner) + inner.Connect(outer.LocalAddr(), outer) + } + } +} diff --git a/memory_log.go b/memory_log.go new file mode 100644 index 0000000..cee1f20 --- /dev/null +++ b/memory_log.go @@ -0,0 +1,163 @@ +package papillon + +import ( + . 
"github.com/fuyao-w/common-util" + "github.com/fuyao-w/deepcopy" + "time" +) + +type memLog struct { + firstIndex, lastIndex uint64 + log map[uint64]*LogEntry +} +type MemorySore struct { + kv *LockItem[map[string]interface{}] // 实现 KVStorage + log *LockItem[memLog] // 实现 LogStore +} + +func newMemoryStore() *MemorySore { + return &MemorySore{ + kv: NewLockItem(map[string]interface{}{}), + log: NewLockItem(memLog{ + log: map[uint64]*LogEntry{}, + }), + } +} + +func (m *MemorySore) GetLogRange(from, to uint64) (logs []*LogEntry, err error) { + m.log.Action(func(t *memLog) { + for i := from; i <= to; i++ { + log := t.log[i] + if log == nil { + continue + } + logs = append(logs, log) + } + }) + return +} + +func (m *MemorySore) Get(key []byte) (val []byte, err error) { + if len(key) == 0 { + return nil, ErrKeyIsNil + } + kv := m.kv.Lock() + defer m.kv.Unlock() + + v, ok := (*kv)[string(key)] + if ok { + return v.([]byte), nil + } + return nil, ErrKeyNotFound +} + +func (m *MemorySore) Set(key []byte, val []byte) (err error) { + if len(key) == 0 { + return ErrKeyIsNil + } + if len(val) == 0 { + return ErrValueIsNil + } + m.kv.Action(func(t *map[string]interface{}) { + (*t)[string(key)] = val + }) + return +} + +func (m *MemorySore) SetUint64(key []byte, val uint64) (err error) { + if len(key) == 0 { + return ErrKeyIsNil + } + m.kv.Action(func(t *map[string]interface{}) { + (*t)[string(key)] = val + }) + return +} + +func (m *MemorySore) GetUint64(key []byte) (uint64, error) { + if len(key) == 0 { + return 0, ErrKeyIsNil + } + kv := m.kv.Lock() + defer m.kv.Unlock() + + v, ok := (*kv)[string(key)] + if ok { + return v.(uint64), nil + } + return 0, ErrKeyNotFound +} +func (m *MemorySore) FirstIndex() (uint64, error) { + var idx uint64 + m.log.Action(func(t *memLog) { + idx = (*t).firstIndex + }) + return idx, nil +} + +func (m *MemorySore) LastIndex() (uint64, error) { + var idx uint64 + m.log.Action(func(t *memLog) { + idx = (*t).lastIndex + }) + return idx, nil +} + +func (m *MemorySore) GetLog(index uint64) (log *LogEntry, err error) { + m.log.Action(func(t *memLog) { + s := *t + l, ok := s.log[index] + if ok { + log = deepcopy.Copy(l).(*LogEntry) + } else { + err = ErrKeyNotFound + } + }) + return +} + +func (m *MemorySore) SetLogs(logs []*LogEntry) (err error) { + m.log.Action(func(t *memLog) { + s := *t + var exists []uint64 + for _, entry := range logs { + if _, ok := s.log[entry.Index]; ok { + exists = append(exists, entry.Index) + } + } + for _, entry := range logs { + s.log[entry.Index] = deepcopy.Copy(entry).(*LogEntry) + s.log[entry.Index].CreatedAt = time.Now() + if t.firstIndex == 0 { + t.firstIndex = entry.Index + } + if entry.Index > t.lastIndex { + t.lastIndex = entry.Index + } + } + }) + return nil +} + +func (m *MemorySore) DeleteRange(min, max uint64) error { + if min > max { + return ErrRange + } + m.log.Action(func(t *memLog) { + s := *t + for i := min; i <= max; i++ { + delete(s.log, i) + } + if min <= s.firstIndex { + s.firstIndex = max + 1 + } + if max >= s.lastIndex { + s.lastIndex = min - 1 + } + if s.firstIndex > s.lastIndex { + s.firstIndex = 0 + s.lastIndex = 0 + } + }) + return nil +} diff --git a/men_fsm.go b/men_fsm.go new file mode 100644 index 0000000..0ed8d5e --- /dev/null +++ b/men_fsm.go @@ -0,0 +1,126 @@ +package papillon + +import ( + "encoding/json" + . 
"github.com/fuyao-w/common-util" + "hash/adler32" + "io" +) + +type logHash struct { + lastHash []byte +} + +type kvSchema [2]string + +func (s kvSchema) encode(k, v string) []byte { + s[0], s[1] = k, v + b, _ := json.Marshal(s) + return b +} +func (s kvSchema) decode(data []byte) (k, v string) { + _ = json.Unmarshal(data, &s) + return s[0], s[1] +} + +func (l *logHash) Add(p []byte) { + hasher := adler32.New() + hasher.Write(l.lastHash) + hasher.Write(p) + l.lastHash = hasher.Sum(nil) +} + +type memFSM struct { + logHash + lastIndex, lastTerm uint64 + kv *LockItem[map[string]string] // 简单提供 kv 功能 + configurations []Configuration +} + +func (m *memFSM) StoreConfiguration(index uint64, configuration Configuration) { + m.configurations = append(m.configurations, configuration) +} + +func newMemFSM() *memFSM { + return &memFSM{ + kv: NewLockItem(map[string]string{}), + } +} + +func (m *memFSM) getVal(key string) (val string) { + kv := m.kv.Lock() + defer m.kv.Unlock() + return (*kv)[key] +} + +type memSnapshotContainer struct { + LastIndex uint64 `json:"last_index"` + LastTerm uint64 `json:"last_term"` + Content []byte `json:"content"` +} + +func (m *memFSM) Persist(sink SnapshotSink) error { + b, _ := json.Marshal(m.kv.Get()) + c, _ := json.Marshal(memSnapshotContainer{ + LastIndex: m.lastIndex, + LastTerm: m.lastTerm, + Content: b, + }) + _, err := sink.Write(c) + return err +} + +func (m *memFSM) Release() { +} + +type applyItem struct { + index uint64 + term uint64 + data []byte +} + +func (m *memFSM) Apply(entry *LogEntry) interface{} { + if entry.Index < m.lastIndex { + panic("index error") + } + if entry.Term < m.lastTerm { + panic("term error") + } + m.lastTerm = entry.Term + m.lastIndex = entry.Index + m.Add(entry.Data) + if k, v := kvSchema.decode(kvSchema{}, entry.Data); len(k) > 0 { + m.kv.Action(func(t *map[string]string) { + (*t)[k] = v + }) + } + return nil +} + +func (m *memFSM) Snapshot() (FsmSnapshot, error) { + return &*m, nil +} + +func (m *memFSM) ReStore(rc io.ReadCloser) error { + defer rc.Close() + var c memSnapshotContainer + buf, err := io.ReadAll(rc) + if err != nil { + return err + } + err = json.Unmarshal(buf, &c) + if err != nil { + return err + } + m.lastTerm = c.LastTerm + m.lastIndex = c.LastIndex + newKv := map[string]string{} + err = json.Unmarshal(c.Content, &newKv) + if err != nil { + return err + } + m.kv.Action(func(t *map[string]string) { + *t = newKv + }) + return nil +} diff --git a/net_protocol.go b/net_protocol.go new file mode 100644 index 0000000..22da370 --- /dev/null +++ b/net_protocol.go @@ -0,0 +1,68 @@ +package papillon + +import ( + "bufio" + "errors" + "strconv" + "strings" +) + +/* +协议:${魔数 1byte 0x3} ${ 请求类型 1 byte} ${包体长度 不固定} \n 包体 +*/ +const ( + delim = '\n' + magic = 0x3 +) + +type DefaultPackageParser struct{} + +var defaultPackageParser = new(DefaultPackageParser) + +func (d *DefaultPackageParser) Encode(writer *bufio.Writer, cmdType rpcType, data []byte) (err error) { + for _, f := range []func() error{ + func() error { return writer.WriteByte(magic) }, // magic + func() error { return writer.WriteByte(byte(cmdType)) }, // 命令类型 + func() error { _, e := writer.WriteString(strconv.Itoa(len(data))); return e }, // 包体长度 + func() error { return writer.WriteByte(delim) }, // 分割符 + func() error { _, e := writer.Write(data); return e }, // 包体 + } { + if err = f(); err != nil { + return + } + } + return err +} + +func (d *DefaultPackageParser) Decode(reader *bufio.Reader) (rpcType, []byte, error) { + _magic, err := reader.ReadByte() + if 
err != nil { + return 0, nil, err + } + + if _magic != magic { + return 0, nil, errors.New("unrecognized request") + } + + // 获取命令类型 + ct, err := reader.ReadByte() + if err != nil { + return 0, nil, err + } + + // 获取包体长度 + pkgLength, err := reader.ReadString(delim) + if err != nil { + return 0, nil, err + } + + // 获取包体 + length, err := strconv.Atoi(strings.TrimRight(pkgLength, string(delim))) + if err != nil { + return 0, nil, err + } + + buf := make([]byte, length) + _, err = reader.Read(buf) + return rpcType(ct), buf, err +} diff --git a/net_transport.go b/net_transport.go new file mode 100644 index 0000000..e21e3b5 --- /dev/null +++ b/net_transport.go @@ -0,0 +1,435 @@ +package papillon + +import ( + "bufio" + "context" + "encoding/json" + "errors" + . "github.com/fuyao-w/common-util" + "io" + "log" + "net" + "sync" + "time" +) + +type ( + netConn struct { + remote ServerAddr + c net.Conn + rw *bufio.ReadWriter + } + ServerAddrProvider interface { + GetAddr(id ServerID) (ServerAddr, error) + } + typConnPool map[ServerAddr][]*netConn + NetTransport struct { + logger Logger + shutDown shutDown + timeout time.Duration + cmdChan chan *RPC + netLayer NetLayer + connPoll *connPool + serverAddrProvider ServerAddrProvider + processor Processor + heartbeatFastPath fastPath + TimeoutScale int64 + ctx *LockItem[ctx] + } + ctx struct { + ctx context.Context + cancel context.CancelFunc + } + connPool struct { + pool *LockItem[typConnPool] + maxSinglePoolNum int + } +) + +func (n *NetTransport) EncodeAddr(info *ServerInfo) []byte { + return []byte(info.Addr) +} + +func (n *NetTransport) DecodeAddr(bytes []byte) ServerAddr { + return ServerAddr(bytes) +} + +func (n *NetTransport) LocalAddr() ServerAddr { + return ServerAddr(n.netLayer.Addr().String()) +} + +func (n *NetTransport) Consumer() <-chan *RPC { + return n.cmdChan +} + +func (n *NetTransport) getServerAddr(info *ServerInfo) ServerAddr { + if n.serverAddrProvider == nil { + return info.Addr + } + addr, err := n.serverAddrProvider.GetAddr(info.ID) + if err != nil { + return info.Addr + } + return addr +} + +func (n *NetTransport) sendRpc(conn *netConn, cmdType rpcType, request interface{}) error { + data, err := defaultCmdConverter.Serialization(request) + if err != nil { + return err + } + if err = defaultPackageParser.Encode(conn.rw.Writer, cmdType, data); err != nil { + return err + } + return conn.rw.Flush() +} +func (n *NetTransport) recvRpc(conn *netConn, resp interface{}) error { + _, data, err := defaultPackageParser.Decode(conn.rw.Reader) + if err != nil { + return err + } + err = defaultCmdConverter.Deserialization(data, resp) + if err != nil { + return err + } + return nil +} +func (n *NetTransport) genericRPC(info *ServerInfo, cmdType rpcType, request, response interface{}) (err error) { + conn, err := n.getConn(info) + if err != nil { + return err + } + if n.timeout > 0 { + conn.c.SetDeadline(time.Now().Add(n.timeout)) + } + defer func() { + if err != nil { + conn.Close() + data, _ := json.Marshal(request) + n.logger.Infof("genericRPC errorx : %s , rpcType :%d , req :%s", err, cmdType, data) + } else { + n.connPoll.PutConn(conn) + } + }() + if err = n.sendRpc(conn, cmdType, request); err != nil { + return + } + + return n.recvRpc(conn, response) +} +func (n *NetTransport) getConn(info *ServerInfo) (*netConn, error) { + addr := n.getServerAddr(info) + if conn := n.connPoll.GetConn(addr); conn != nil { + return conn, nil + } + conn, err := n.netLayer.Dial(info.Addr, n.timeout) + if err != nil { + return nil, err + } + return 
newNetConn(info.Addr, conn), nil +} +func newNetConn(addr ServerAddr, conn net.Conn) *netConn { + return &netConn{ + remote: addr, + c: conn, + rw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + } +} +func (n *NetTransport) VoteRequest(info *ServerInfo, request *VoteRequest) (*VoteResponse, error) { + var resp = new(VoteResponse) + return resp, n.genericRPC(info, CmdVoteRequest, request, resp) +} + +func (n *NetTransport) AppendEntries(info *ServerInfo, request *AppendEntryRequest) (*AppendEntryResponse, error) { + var resp = new(AppendEntryResponse) + return resp, n.genericRPC(info, CmdAppendEntry, request, resp) +} + +func (n *NetTransport) AppendEntryPipeline(info *ServerInfo) (AppendEntryPipeline, error) { + conn, err := n.getConn(info) + if err != nil { + return nil, err + } + return newNetPipeline(n, conn), err +} + +func (n *NetTransport) InstallSnapShot(info *ServerInfo, request *InstallSnapshotRequest, r io.Reader) (*InstallSnapshotResponse, error) { + conn, err := n.getConn(info) + if err != nil { + return nil, err + } + defer conn.Close() + if n.timeout > 0 { + conn.c.SetDeadline(time.Now().Add(Max(n.timeout*time.Duration(request.SnapshotMeta.Size/n.TimeoutScale), n.timeout))) + } + if err = n.sendRpc(conn, CmdInstallSnapshot, request); err != nil { + return nil, err + } + if _, err = io.Copy(conn.rw, r); err != nil { + return nil, err + } + if err = conn.rw.Flush(); err != nil { + n.logger.Errorf("InstallSnapShot|Flush errorx :%s", err) + return nil, err + } + + var resp = new(InstallSnapshotResponse) + if err = n.recvRpc(conn, resp); err != nil { + n.logger.Errorf("InstallSnapShot|recvRpc errorx :%s", err) + return nil, err + } + return resp, nil +} + +func (n *NetTransport) SetHeartbeatFastPath(cb fastPath) { + n.processor.SetFastPath(cb) +} + +func (n *NetTransport) FastTimeout(info *ServerInfo, req *FastTimeoutRequest) (*FastTimeoutResponse, error) { + var resp = new(FastTimeoutResponse) + return resp, n.genericRPC(info, CmdFastTimeout, req, resp) +} +func newConnPool(maxSinglePoolNum int) *connPool { + p := &connPool{ + pool: NewLockItem[typConnPool](map[ServerAddr][]*netConn{}), + maxSinglePoolNum: maxSinglePoolNum, + } + return p +} +func (c *connPool) GetConn(addr ServerAddr) (conn *netConn) { + c.pool.Action(func(t *typConnPool) { + if list, ok := (*t)[addr]; ok { + if len(list) == 0 { + return + } + conn = list[len(list)-1] + list = list[:len(list)-1] + (*t)[addr] = list + } + }) + return +} + +func (c *connPool) PutConn(conn *netConn) { + c.pool.Action(func(t *typConnPool) { + if c.maxSinglePoolNum <= len((*t)[conn.remote]) { + conn.Close() + return + } + (*t)[conn.remote] = append((*t)[conn.remote], conn) + }) +} + +func (n *netConn) Close() { + n.c.Close() +} + +type NetWorkTransportConfig struct { + ServerAddressProvider ServerAddrProvider + + Logger Logger + + NetLayer NetLayer + + MaxPool int + + Timeout time.Duration +} + +func NewNetTransport(conf *NetWorkTransportConfig) *NetTransport { + cmdCh := make(chan *RPC) + logConnCtx, cancel := context.WithCancel(context.Background()) + t := &NetTransport{ + logger: conf.Logger, + timeout: conf.Timeout, + cmdChan: cmdCh, + netLayer: conf.NetLayer, + connPoll: newConnPool(conf.MaxPool), + serverAddrProvider: conf.ServerAddressProvider, + processor: newProcessorProxy(cmdCh), + TimeoutScale: DefaultTimeoutScale, + shutDown: newShutDown(), + ctx: NewLockItem(ctx{ + ctx: logConnCtx, + cancel: cancel, + }), + } + go t.Start() + return t +} + +func genCtx() ctx { + c, cancel := 
context.WithCancel(context.Background()) + return ctx{ + ctx: c, + cancel: cancel, + } +} + +func (n *NetTransport) CloseConnections() { + n.connPoll.pool.Action(func(t *typConnPool) { + for _, connLise := range *t { + for _, conn := range connLise { + conn.Close() + } + } + *t = map[ServerAddr][]*netConn{} + }) + n.ctx.Action(func(t *ctx) { + t.cancel() + *t = genCtx() + }) +} +func (n *NetTransport) Close() error { + n.shutDown.done(func(_ bool) { + n.netLayer.Close() + }) + return nil +} + +func (n *NetTransport) Start() { + var c int64 + for { + conn, err := n.netLayer.Accept() + if err != nil { + if n.processError(err, c) { + return + } + } + c = 0 + go n.handleConn(n.ctx.Get().ctx, newNetConn("", conn)) + } +} + +const baseDelay = 5 * time.Millisecond +const maxDelay = 1 * time.Second + +func (n *NetTransport) processError(err error, count int64) (needEnd bool) { + select { + case <-n.shutDown.C: + log.Printf("server shut down") + return true + default: + } + e, ok := err.(net.Error) + if !ok { + return true + } + switch { + case e.Timeout(): + log.Printf("listener|Accept|errorx %n ", err) + time.Sleep(func() (delay time.Duration) { + delay = time.Duration(count) * baseDelay + if delay > maxDelay { + delay = maxDelay + } + return delay + }() * time.Millisecond) + return false + } + return true +} + +func (n *NetTransport) handleConn(ctx context.Context, conn *netConn) { + defer conn.Close() + for { + select { + case <-ctx.Done(): + return + default: + cmdType, data, err := defaultPackageParser.Decode(conn.rw.Reader) + if err != nil { + if !errors.Is(err, io.EOF) { + n.logger.Errorf("processConnection|Decode errorx : %s\n", err) + } + return + } + respData, err := n.processor.Do(cmdType, data, conn.rw) + if err != nil { + n.logger.Errorf("NetTransport|processor errorx:%s", err) + return + } + if err = defaultPackageParser.Encode(conn.rw.Writer, cmdType, respData.([]byte)); err != nil { + n.logger.Errorf("NetTransport|Encode errorx:%s", err) + return + } + if err := conn.rw.Flush(); err != nil { + n.logger.Errorf("NetTransport|Flush errorx:%s", err) + return + } + } + } +} + +type netPipeline struct { + conn *netConn + trans *NetTransport + doneCh chan AppendEntriesFuture + inProgressCh chan *appendEntriesFuture + shutdownCh chan struct{} + shutDownOnce sync.Once +} + +func (n *netPipeline) AppendEntries(request *AppendEntryRequest) (AppendEntriesFuture, error) { + af := newAppendEntriesFuture(request) + if err := n.trans.sendRpc(n.conn, CmdAppendEntry, af.req); err != nil { + return nil, err + } + select { + case <-n.shutdownCh: + return nil, ErrShutDown + case n.inProgressCh <- af: + } + + return af, nil +} + +func (n *netPipeline) Consumer() <-chan AppendEntriesFuture { + return n.doneCh +} + +func (n *netPipeline) Close() error { + n.shutDownOnce.Do(func() { + n.conn.Close() + close(n.shutdownCh) + }) + return nil +} + +func newNetPipeline(trans *NetTransport, conn *netConn) *netPipeline { + pipeline := &netPipeline{ + conn: conn, + trans: trans, + doneCh: make(chan AppendEntriesFuture, rpcMaxPipeline), + inProgressCh: make(chan *appendEntriesFuture, rpcMaxPipeline), + shutdownCh: make(chan struct{}), + } + go pipeline.decodeResponses() + return pipeline +} + +func (n *netPipeline) decodeResponses() { + timeout := n.trans.timeout + for { + select { + case <-n.shutdownCh: + return + case af := <-n.inProgressCh: + if timeout > 0 { + n.conn.c.SetDeadline(time.Now().Add(timeout)) + } + var resp *AppendEntryResponse + err := n.trans.recvRpc(n.conn, &resp) + if err != nil { + 
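+ // Note: a decode failure is logged but does not abort the loop here; the
+ // error is still handed to af.responded below, so the replication goroutine
+ // waiting on the pipeline consumer sees a completed future instead of
+ // blocking forever. Responses are matched to requests purely by arrival
+ // order: with a single replication goroutine feeding each pipeline, futures
+ // enter inProgressCh in the same order their requests were written to the
+ // connection, and this loop drains them in that same order.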
n.trans.logger.Errorf("decodeResponses|recvRpc errorx :%s", err) + + } + af.responded(resp, err) + select { + case <-n.shutdownCh: + return + case n.doneCh <- af: + } + } + } +} diff --git a/raft.go b/raft.go new file mode 100644 index 0000000..950962b --- /dev/null +++ b/raft.go @@ -0,0 +1,1722 @@ +package papillon + +import ( + "container/list" + "errors" + "fmt" + . "github.com/fuyao-w/common-util" + "github.com/fuyao-w/log" + "golang.org/x/sync/errgroup" + "io" + + "sync" + "sync/atomic" + "time" +) + +type ( + Raft struct { + commitIndex uint64 // 集群已提交的日志,初始化为 0 只有再提交过日志后才可以更新 + lastApplied uint64 // 已提交给状态机的最新 index , 注意:不代表状态机已经应用 + currentTerm uint64 // 当前任期,需要持久化 + configuration *AtomicVal[Configuration] // 集群配置的副本 + conf *AtomicVal[*Config] // 参数配置信息 + confReloadMu sync.Mutex + state State // 节点状态 + run func() // 主线程函数 + lastContact *AtomicVal[time.Time] // 与领导者上次联系的时间 + localInfo ServerInfo // 当前节点地址 + lastEntry *LockItem[lastEntry] // 节点最新的索引、任期 + cluster cluster // 集群配置 + leaderInfo *LockItem[ServerInfo] // 领导人地址 + funcEg *errgroup.Group + shutDown shutDown + logger Logger // 日志 + //-------主线程----------- + rpcCh <-chan *RPC // 处理 RPC 命令 + + commitCh chan struct{} // 日志已提交通知,可以应用到 FSM 可以有 buffer + + leaderState leaderState // 领导人上下文 + heartbeatTimeout <-chan time.Time // 由主线程设置 + electionTimeout <-chan time.Time // 由主线程设置 + candidateFromLeaderTransfer bool // 当前节点在领导权转移过程中 + //-------fsm----------- + fsm FSM // 状态机,日志提交后由此应用 + fsmApplyCh chan []*LogFuture // 状态机线程的日志提交通知 + fsmSnapshotCh chan *fsmSnapshotFuture // 从状态机取快照 + fsmRestoreCh chan *restoreFuture // 通知状态机重新应用快照 + + //-----API-------------------- + //clusterGetCh chan *clusterGetFuture // 获取集群配置 + apiSnapshotBuildCh chan *apiSnapshotFuture // 生成快照 + apiSnapshotRestoreCh chan *userRestoreFuture // 重新应用快照的时候不能接收新的日志,需要从 runState 线程触发 + apiLogApplyCh chan *LogFuture // 日志提交请求,由于需要支持批量提交,所以单独提出来 + commandCh chan *command // 对节点发起的命令,包括领导权验证等 + stateChangeCh chan State + //-----组件------ + rpc RpcInterface // RPC 组件 + kvStore KVStorage // 任期、投票信息持久化组件 + logStore LogStore // 日志持久化组件 + snapshotStore SnapshotStore // 快照组件 + } + lastEntry struct { + // 快照中存的最新日志 + snapshot lastLog + log lastLog + // 已经落到稳定存储的最新日志 + //logIndex uint64 + //logTerm uint64 + } + lastLog struct { + index uint64 + term uint64 + } + // leaderState 领导人上下文 + leaderState struct { + sync.Mutex + commitIndex uint64 // 通过计算副本得出的已提交索引,只能由新日志提交触发更新 + matchIndex map[ServerID]uint64 // 每个跟随者对应的已复制的 index + replicate map[ServerID]*replication // 所有的跟随者 + inflight *list.List // 等待提交并应用到状态机的 LogFuture + startIndex uint64 // 记录任期开始时的最新一条索引,防止在日志提交的时候发生 commit index 回退 + stepDown chan struct{} // 领导人下台通知 + leadershipTransfer int32 // 是否发生领导权转移 1 :是 ,0 :否 + } + // replication 领导人复制时每个跟随者维护的上下文状态 + replication struct { + peer *LockItem[ServerInfo] // 跟随者的 server 信息 + nextIndex uint64 // 待复制给跟随者的下一条日志索引,初始化为领导人最新的日志索引 + heartBeatStop, done, pipelineFinish chan struct{} // 心跳停止、复制线程结束、pipeline 返回结果处理线程结束 + trigger chan *defaultDeferResponse // 强制复制,不需要复制结果可以投递 nil + notifyCh chan struct{} // 强制心跳 + stop chan bool // 复制停止通知,true 代表需要在停机前尽力复制 + lastContact *LockItem[time.Time] // 上次与跟随者联系的时间,用于计算领导权 + notify *LockItem[map[*verifyFuture]struct{}] // VerifyLeader 请求跟踪 + } +) + +const ( + Voter Suffrage = iota + NonVoter +) + +type ( + ServerAddr string + ServerID string + Suffrage int + // ServerInfo 节点的地址信息 + ServerInfo struct { + Suffrage Suffrage + Addr ServerAddr + ID ServerID + } +) + +func (r *Raft) Conf() *Config { + return r.conf.Load() +} +func (r *Raft) 
getCurrentTerm() uint64 { + return atomic.LoadUint64(&r.currentTerm) +} +func (r *Raft) getLastApplied() uint64 { + return atomic.LoadUint64(&r.lastApplied) +} + +func (r *Raft) setLastApplied(index uint64) { + atomic.StoreUint64(&r.lastApplied, index) +} +func (fr *replication) getNextIndex() uint64 { + return atomic.LoadUint64(&fr.nextIndex) +} +func (fr *replication) setNextIndex(newNextIndex uint64) { + atomic.StoreUint64(&fr.nextIndex, newNextIndex) +} +func (fr *replication) notifyAll(leadership bool) { + fr.notify.Action(func(t *map[*verifyFuture]struct{}) { + for v := range *t { + v.vote(leadership) + } + *t = map[*verifyFuture]struct{}{} + }) +} +func (fr *replication) observe(v *verifyFuture) { + fr.notify.Action(func(t *map[*verifyFuture]struct{}) { + (*t)[v] = struct{}{} + }) +} +func (l *leaderState) setupLeadershipTransfer(status bool) (succ bool) { + old, newVal := int32(1), int32(0) + if status { + old = 0 + newVal = 1 + } + return atomic.CompareAndSwapInt32(&l.leadershipTransfer, old, newVal) +} +func (l *leaderState) getLeadershipTransfer() (status bool) { + return atomic.LoadInt32(&l.leadershipTransfer) == 1 +} + +func (s *ServerInfo) isVoter() bool { + return s.Suffrage == Voter +} + +func NewRaft(conf *Config, + fsm FSM, + rpc RpcInterface, + logStore LogStore, + kvStore KVStorage, + snapshotStore SnapshotStore) (*Raft, error) { + var ( + _lastLog lastLog + ) + ok, hint := ValidateConfig(conf) + if !ok { + return nil, fmt.Errorf("config validate err :%s", hint) + } + if conf.Logger == nil { + conf.Logger = log.NewLogger() + } + + lastIndex, err := logStore.LastIndex() + if err != nil { + if !errors.Is(ErrNotFoundLog, err) { + conf.Logger.Errorf("init index error") + return nil, fmt.Errorf("recover last log err :%s", err) + } + } + if lastIndex > 0 { + log, err := logStore.GetLog(lastIndex) + if err != nil { + conf.Logger.Errorf("get lastLog error") + return nil, fmt.Errorf("recover last log err :%s", err) + } + _lastLog = lastLog{ + index: log.Index, + term: log.Term, + } + } + + currentTerm, err := kvStore.GetUint64(KeyCurrentTerm) + if err != nil && !errors.Is(ErrKeyNotFound, err) { + conf.Logger.Errorf("init current term error :%s", err) + return nil, fmt.Errorf("recover current term err :%s", err) + } + + raft := &Raft{ + commitIndex: 0, + lastApplied: 0, + currentTerm: currentTerm, + conf: NewAtomicVal(conf), + configuration: NewAtomicVal[Configuration](), + state: 0, + run: nil, + lastContact: NewAtomicVal[time.Time](), + localInfo: ServerInfo{Addr: rpc.LocalAddr(), ID: ServerID(conf.LocalID)}, + lastEntry: NewLockItem(lastEntry{log: _lastLog}), + cluster: cluster{}, + leaderInfo: NewLockItem[ServerInfo](), + funcEg: new(errgroup.Group), + shutDown: newShutDown(), + rpcCh: rpc.Consumer(), + apiLogApplyCh: make(chan *LogFuture), + commitCh: make(chan struct{}, 16), + logger: conf.Logger, + leaderState: leaderState{}, + stateChangeCh: make(chan State, 1), + commandCh: make(chan *command), + fsm: fsm, + fsmApplyCh: make(chan []*LogFuture), + fsmSnapshotCh: make(chan *fsmSnapshotFuture), + fsmRestoreCh: make(chan *restoreFuture), + apiSnapshotBuildCh: make(chan *apiSnapshotFuture), + apiSnapshotRestoreCh: make(chan *userRestoreFuture), + rpc: rpc, + kvStore: kvStore, + logStore: logStore, + snapshotStore: snapshotStore, + } + + if err = raft.recoverSnapshot(); err != nil { + return nil, fmt.Errorf("recover snap shot|%s", err) + } + if err = raft.recoverCluster(); err != nil { + return nil, fmt.Errorf("recover cluster|%s", err) + } + 
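+ // By this point the node's durable state has been rebuilt: the newest
+ // snapshot (if any) has been applied to the FSM and the cluster configuration
+ // replayed from the log entries recorded after it. What remains is to
+ // register the heartbeat fast path on the transport and start the three
+ // long-running goroutines: the main state loop, the FSM apply loop and the
+ // snapshot loop.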
raft.rpc.SetHeartbeatFastPath(raft.processHeartBeat) + raft.setFollower() + raft.goFunc(raft.runState, raft.runFSM, raft.runSnapshot) + return raft, nil +} +func (r *Raft) setLatestConfiguration(index uint64, configuration Configuration) { + r.cluster.setLatest(index, configuration) + r.configuration.Store(configuration) +} +func (r *Raft) setCommitConfiguration(index uint64, configuration Configuration) { + r.cluster.setCommit(index, configuration) +} +func (r *Raft) getLatestIndex() uint64 { + entry := r.lastEntry.Get() + return Max(entry.log.index, entry.snapshot.index) +} +func (r *Raft) getLatestTerm() uint64 { + entry := r.lastEntry.Get() + return Max(entry.log.term, entry.snapshot.term) +} + +func (r *Raft) getLatestEntry() (term uint64, index uint64) { + entry := r.lastEntry.Get() + return Max(entry.log.term, entry.snapshot.term), Max(entry.log.index, entry.snapshot.index) +} + +// recoverCluster 从日志恢复集群配置 +func (r *Raft) recoverCluster() (err error) { + entry := r.lastEntry.Get() + for i := entry.snapshot.index; i < entry.log.index; i++ { + if i <= 0 { + continue + } + entry, err := r.logStore.GetLog(i) + if err != nil { + r.logger.Errorf("") + return fmt.Errorf("get log err :%s", err) + } + r.saveConfiguration(entry) + } + return nil +} + +func (r *Raft) recoverSnapshotByID(id string) (*SnapShotMeta, error) { + meta, readCloser, err := r.snapshotStore.Open(id) + if err != nil { + r.logger.Errorf("") + return nil, fmt.Errorf("open id %s err:%s", id, err) + } + defer readCloser.Close() + if err = r.fsm.ReStore(readCloser); err != nil { + r.logger.Errorf("") + return nil, fmt.Errorf("restore snapshot err:%s", err) + } + return meta, nil +} + +// recoverSnapshot 将本地最新的快照恢复至状态机,并更新本地的索引、任期状态 +func (r *Raft) recoverSnapshot() (err error) { + metaList, err := r.snapshotStore.List() + if err != nil { + r.Conf().Logger.Errorf("list snapshot error :%s", err) + return fmt.Errorf("list snapshot err :%s", err) + } + if len(metaList) == 0 { + return nil + } + meta, err := r.recoverSnapshotByID(metaList[0].ID) + if err != nil { + return err + } + r.setLatestSnapshot(meta.Term, meta.Index) + r.setLastApplied(meta.Index) + // 生成快照默认用稳定的配置 + r.setCommitConfiguration(meta.ConfigurationIndex, meta.Configuration) + r.setLatestConfiguration(meta.ConfigurationIndex, meta.Configuration) + return nil +} + +func (r *Raft) clearLeaderInfo() { + r.updateLeaderInfo(func(s *ServerInfo) { + *s = ServerInfo{} + }) +} + +func (r *Raft) updateLeaderInfo(act func(s *ServerInfo)) { + r.leaderInfo.Action(act) +} +func (r *Raft) goFunc(funcList ...func()) { + for _, f := range funcList { + f := f + r.funcEg.Go(func() error { + f() + return nil + }) + } +} +func (r *Raft) waitShutDown() { + +} + +func (r *Raft) buildRPCHeader() *RPCHeader { + header := &RPCHeader{ + ID: r.localInfo.ID, + Addr: r.localInfo.Addr, + } + return header +} +func (r *Raft) setLastContact() { + r.lastContact.Store(time.Now()) +} +func (r *Raft) getLastContact() time.Time { + return r.lastContact.Load() +} +func (fr *replication) setLastContact() { + fr.lastContact.Set(time.Now()) +} +func (fr *replication) getLastContact() time.Time { + return fr.lastContact.Get() +} +func (r *Raft) getCommitIndex() uint64 { + return atomic.LoadUint64(&r.commitIndex) +} +func (r *Raft) setCommitIndex(commitIndex uint64) { + atomic.StoreUint64(&r.commitIndex, commitIndex) +} + +// runState 运行主线程 +func (r *Raft) runState() { + for { + select { + case <-r.shutDown.C: + return + default: + } + r.run() + } +} +func shouldApply(log *LogEntry) bool { + switch 
log.Type { + case LogConfiguration, LogCommand, LogBarrier: + return true + } + return false +} + +// applyLogToFsm 将 lastApplied 到 index 的日志应用到状态机 +func (r *Raft) applyLogToFsm(toIndex uint64, ready map[uint64]*LogFuture) { + if toIndex <= r.getLastApplied() { + return + } + var ( + fromIndex = r.getLastApplied() + 1 + maxAppend = r.Conf().MaxAppendEntries + logFutures = make([]*LogFuture, 0, maxAppend) + ) + applyBatch := func(futures []*LogFuture) { + if len(futures) == 0 { + return + } + select { + case r.fsmApplyCh <- futures: + case <-r.shutDown.C: + for _, future := range futures { + future.fail(ErrShutDown) + } + } + } + + for i := fromIndex; i <= toIndex; i++ { + var fu *LogFuture + if fu = ready[i]; fu == nil { + log, err := r.logStore.GetLog(i) + if err != nil { + panic(err) + return + } + fu = &LogFuture{log: log} + fu.init() + } + if !shouldApply(fu.log) { + fu.success() + continue + } + logFutures = append(logFutures, fu) + if len(logFutures) >= maxAppend { + applyBatch(logFutures) + logFutures = make([]*LogFuture, 0, maxAppend) + } + } + applyBatch(logFutures) + + r.setLastApplied(toIndex) +} + +func (r *Raft) shouldBuildSnapshot() bool { + _, snapshotIndex := r.getLatestSnapshot() + logIndex, _ := r.logStore.LastIndex() + return logIndex-snapshotIndex > r.Conf().SnapshotThreshold +} +func (r *Raft) buildSnapshot() (id string, err error) { + fu := new(fsmSnapshotFuture) + fu.init() + select { + case r.fsmSnapshotCh <- fu: + case <-r.shutDown.C: + return "", ErrShutDown + } + resp, err := fu.Response() + if err != nil { + r.logger.Errorf("buildSnapshot err :%s", err) + return "", err + } + + clusterFu := new(clusterGetFuture) + clusterFu.init() + select { + case r.commandCh <- &command{typ: commandClusterGet, item: clusterFu}: + case <-r.shutDown.C: + return "", ErrShutDown + } + cResp, err := clusterFu.Response() + if err != nil { + return "", err + } + commit, commitIndex := cResp.commit, cResp.commitIndex + + if resp.index < commitIndex { // 快照如果没赶上已提交的集群配置 index ,则不能生成快照 + return "", fmt.Errorf("no stable snapshot") + } + sink, err := r.snapshotStore.Create(SnapShotVersionDefault, resp.index, resp.term, + commit, commitIndex, r.rpc) + if err != nil { + r.logger.Errorf("") + return + } + + if err = resp.fsmSnapshot.Persist(sink); err != nil { + r.logger.Errorf("") + sink.Cancel() + return + } + if err = sink.Close(); err != nil { + r.logger.Errorf("") + return "", err + } + r.setLatestSnapshot(resp.term, resp.index) + _ = r.compactLog() + return sink.ID(), nil +} + +func (r *Raft) setFollower() { + r.state.set(Follower) + r.run = r.cycleFollower + overrideNotify(r.stateChangeCh, Follower) +} +func (r *Raft) setCandidate() { + r.state.set(Candidate) + r.run = r.cycleCandidate + overrideNotify(r.stateChangeCh, Candidate) +} +func (r *Raft) setLeader() { + r.state.set(Leader) + r.leaderInfo.Action(func(t *ServerInfo) { + t.ID = r.localInfo.ID + t.Addr = r.localInfo.Addr + }) + r.run = r.cycleLeader + overrideNotify(r.stateChangeCh, Leader) +} + +func (r *Raft) setShutDown() { + r.state.set(ShutDown) + r.run = nil + overrideNotify(r.stateChangeCh, ShutDown) +} + +// onShutDown 关机时处理一些清理逻辑 +func (r *Raft) onShutDown() { + // 清理未处理的 command 请求 + for { + select { + case cmd := <-r.commandCh: + r.processCommand(cmd) + default: + return + } + } +} + +// cycleFollower 更随者主线程 +func (r *Raft) cycleFollower() { + leader := r.leaderInfo.Get() + r.logger.Info("entering follower state", "follower", "leader-address", leader.Addr, "leader-id", leader.ID) + r.heartbeatTimeout = 
randomTimeout(r.Conf().HeartbeatTimeout) + warnOnce := new(sync.Once) + warn := func(args ...interface{}) { + warnOnce.Do(func() { + r.logger.Warn(args...) + }) + } + doFollower := func() (stop bool) { + select { + case <-r.shutDown.C: + r.onShutDown() + return true + case rpc := <-r.rpcCh: + r.processRPC(rpc) + case cmd := <-r.commandCh: + r.processCommand(cmd) + case <-r.heartbeatTimeout: + tm := r.Conf().HeartbeatTimeout + r.heartbeatTimeout = randomTimeout(tm) + if time.Now().Before(r.getLastContact().Add(tm)) { + return + } + oldLeaderInfo := r.leaderInfo.Get() + r.clearLeaderInfo() + switch { + case r.cluster.latestIndex == 0: + warn("no cluster members") + case r.cluster.commitIndex == r.cluster.latestIndex && !r.canVote(r.localInfo.ID): + warn("no part of stable Configuration, aborting election") + case r.canVote(r.localInfo.ID): + warn("heartbeat abortCh reached, starting election", "last-leader-addr", oldLeaderInfo.Addr, "last-leader-id", oldLeaderInfo.ID) + r.setCandidate() + default: + warn("heartbeat abortCh reached, not part of a stable Configuration or a non-voter, not triggering a leader election") + } + } + return + } + + r.cycle(Follower, doFollower) + +} + +// processRPC 处理 rpc 请求 +func (r *Raft) processRPC(cmd *RPC) { + switch req := cmd.Request.(type) { + case *VoteRequest: + r.processVote(req, cmd) + case *AppendEntryRequest: + r.processAppendEntry(req, cmd) + case *FastTimeoutRequest: + r.processFastTimeout(req, cmd) + case *InstallSnapshotRequest: + r.processInstallSnapshot(req, cmd) + } +} + +// processHeartBeat 心跳的 fastPath ,不用经过主线程 +func (r *Raft) processHeartBeat(cmd *RPC) bool { + req, ok := cmd.Request.(*AppendEntryRequest) + if cmd.CmdType == CmdAppendEntry && ok && len(req.Entries) > 0 { + r.processAppendEntry(req, cmd) + return true + } + return false +} + +func (r *Raft) deleteLog(from, to uint64) error { + pageSize := uint64(5) + for i := from; i < to; i += pageSize { + logs, err := r.logStore.GetLogRange(i, i+pageSize) + if err != nil { + return err + } + if err := r.logStore.DeleteRange(i, i+pageSize); err != nil { + return err + } + if len(logs) < 5 { + break + } + } + return nil +} + +// saveConfiguration 从日志里更新集群配置,由于集群变更只允许单次一个节点的变更,所以当有新日志 +// 到来的时候就可以默认为原有的 latest Configuration 已经提交 ,该函数只能用于启东时恢复或者主线程更新 +func (r *Raft) saveConfiguration(entry *LogEntry) { + if entry.Type != LogConfiguration { + return + } + r.setCommitConfiguration(r.cluster.latestIndex, r.cluster.latest) + r.setLatestConfiguration(entry.Index, DecodeConfiguration(entry.Data)) +} + +// storeEntries 持久化日志,需要校验 PrevLogIndex、PrevLogTerm 是否匹配 +func (r *Raft) storeEntries(req *AppendEntryRequest) ([]*LogEntry, error) { + if len(req.Entries) == 0 { + return nil, nil + } + var ( + newEntries []*LogEntry + latestTerm, latestIndex = r.getLatestEntry() + ) + + if req.PrevLogIndex+1 != req.Entries[0].Index { + return nil, errors.New("param errorx") + } + // 校验日志是否匹配 + // 只要有和 prevTerm、prevIndex 匹配的日志就行,不一定是最新的,日志可能会被领导人覆盖,lastIndex 可能回退 + if req.PrevLogIndex > 0 { // 第一个日志 PrevLogIndex 为 0 ,所以这里加判断 + var prevlLogTerm uint64 + if req.PrevLogIndex == latestIndex { + prevlLogTerm = latestTerm + } else { + log, err := r.logStore.GetLog(req.PrevLogIndex) + if err != nil { + if err.Error() != ErrNotFoundLog.Error() { + r.logger.Errorf("") + } + return nil, err + } + prevlLogTerm = log.Term + + } + if prevlLogTerm != req.PrevLogTerm { + return nil, errors.New("prev log not match") + } + } + + for i, entry := range req.Entries { + if entry.Index > latestIndex { // 绝对是最新的了 + newEntries = 
req.Entries[i:] + break + } + log, err := r.logStore.GetLog(entry.Index) + if err != nil { + return nil, err // 小于 latestIndex 的不应该有空洞,所以不判断 ErrNotFoundLog + } + if log.Term != entry.Term { + if err := r.logStore.DeleteRange(entry.Index, latestIndex); err != nil { + return nil, err + } + newEntries = req.Entries[i:] + break + } + + } + if len(newEntries) == 0 { + return nil, nil + } + if err := r.logStore.SetLogs(newEntries); err != nil { + r.logger.Errorf("") + return nil, err + } + return newEntries, nil +} + +// processInstallSnapshot 复制远端的快照到本地,并且引用到状态机,这个方法使用在领导者复制过程中待复制日志在快照中的情况 +// 隐含意义就是跟随者的有效日志比领导者的快照旧 +func (r *Raft) processInstallSnapshot(req *InstallSnapshotRequest, cmd *RPC) { + var ( + succ bool + meta = req.SnapshotMeta + ) + defer func() { + cmd.Response <- &InstallSnapshotResponse{ + RPCHeader: r.buildRPCHeader(), + Success: succ, + Term: r.getCurrentTerm(), + } + }() + if req.Term < r.getCurrentTerm() { + return + } + if req.Term > r.getCurrentTerm() { + r.setFollower() + r.setCurrentTerm(req.Term) + } + if len(req.ID) > 0 { + r.leaderInfo.Set(ServerInfo{ + ID: req.ID, + Addr: req.Addr, + }) + } + + sink, err := r.snapshotStore.Create(SnapShotVersionDefault, meta.Index, meta.Term, meta.Configuration, meta.ConfigurationIndex, r.rpc) + if err != nil { + return + } + reader := newCounterReader(cmd.Reader) + written, err := io.Copy(sink, reader) + if err != nil { + r.logger.Errorf("") + sink.Cancel() + return + } + if written != meta.Size { + r.logger.Errorf("") + sink.Cancel() + return + } + if err = sink.Close(); err != nil { + return + } + fu := &restoreFuture{ID: sink.ID()} + fu.init() + + select { + case <-r.shutDown.C: + return + case r.fsmRestoreCh <- fu: + } + if _, err = fu.Response(); err != nil { + r.logger.Errorf("") + return + } + r.setLatestSnapshot(meta.Term, meta.Index) + r.setLastApplied(meta.Index) + + r.setCommitConfiguration(meta.Index, meta.Configuration) + r.setLatestConfiguration(meta.Index, meta.Configuration) + r.setLastContact() + _ = r.compactLog() + succ = true +} + +// compactLog 压缩日志,至少保留 Config.TrailingLogs 长度日志 +func (r *Raft) compactLog() error { + trailingLogs := r.Conf().TrailingLogs + firstIndex, err := r.logStore.FirstIndex() + if err != nil { + r.logger.Errorf("") + return err + } + _, logIndex := r.getLatestLog() + _, snapshotIndex := r.getLatestSnapshot() + idx := Min(snapshotIndex, logIndex-trailingLogs) + if idx < firstIndex { + return nil + } + + if err = r.logStore.DeleteRange(firstIndex, idx); err != nil { + r.logger.Errorf("") + return err + } + return nil +} + +// processFastTimeout 跟随者快速超时成为候选人 +func (r *Raft) processFastTimeout(req *FastTimeoutRequest, cmd *RPC) { + var succ bool + defer func() { + cmd.Response <- &FastTimeoutResponse{ + RPCHeader: r.buildRPCHeader(), + Success: succ, + } + }() + if req.Term < r.getCurrentTerm() { + return + } + r.setCandidate() + r.candidateFromLeaderTransfer = true + succ = true +} + +// processAppendEntry 处理日志复制,当日志长度为 0 的时候当做心跳使用 +func (r *Raft) processAppendEntry(req *AppendEntryRequest, cmd *RPC) { + var ( + succ bool + term = r.getLatestTerm() + ) + + defer func() { + cmd.Response <- &AppendEntryResponse{ + RPCHeader: r.buildRPCHeader(), + Term: r.getCurrentTerm(), + Succ: succ, + LatestIndex: term, + } + io.Copy(io.Discard, cmd.Reader) + }() + if req.Term < r.getCurrentTerm() { + return + } + + // 如果我们正在进行 leader transfer 则不用退回成跟随者,继续发起选举 + if req.Term > r.getCurrentTerm() || (r.state.Get() != Follower && !r.candidateFromLeaderTransfer) { + r.setCurrentTerm(req.Term) + r.setFollower() + 
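+ // A request carrying a newer term, or an AppendEntries received while this
+ // node is a candidate/leader (and not in the middle of a leadership transfer
+ // it initiated itself), means another node legitimately leads this term:
+ // adopt the request's term and fall back to follower before handling the
+ // entries.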
} + + if len(req.ID) > 0 { + r.leaderInfo.Set(ServerInfo{ + Addr: req.Addr, + ID: req.ID, + }) + } + entries, err := r.storeEntries(req) + if err != nil { + return + } + // 更新本地最近的 index + last := entries[len(entries)-1] + r.setLatestLog(last.Term, last.Index) + // 处理集群配置更新 + for _, entry := range entries { + r.saveConfiguration(entry) + } + // 更新已提交的配置 + if req.LeaderCommit > r.cluster.latestIndex { + r.setCommitConfiguration(r.cluster.latestIndex, r.cluster.latest) + } + if req.LeaderCommit > r.getCommitIndex() { + // 更新 commit index + index := r.getLatestIndex() + r.setCommitIndex(Min(req.LeaderCommit, index)) + // 应用到状态机错误不用返回给 leader + r.applyLogToFsm(r.getCommitIndex(), nil) + } + succ = true + r.setLastContact() +} +func (r *Raft) applyToFsm() { + +} +func (r *Raft) processVote(req *VoteRequest, cmd *RPC) { + var granted bool + defer func() { + cmd.Response <- &VoteResponse{ + RPCHeader: r.buildRPCHeader(), + Term: r.getCurrentTerm(), + Granted: granted, + } + }() + if len(r.getLatestConfiguration()) > 0 { + if r.inConfiguration(req.ID) { + r.logger.Warn("rejecting vote request since node is not in configuration", + "from", req.Addr) + } else if !r.canVote(req.ID) { + r.logger.Warn("rejecting vote request since node is not a voter", "from", req.Addr) + } + return + } + // 如果跟随者确认当前有合法领导人则直接拒绝,保持集群稳定 + if info := r.leaderInfo.Get(); len(info.Addr) != 0 && info.Addr != req.Addr && !req.LeaderTransfer { + r.logger.Warn("rejecting vote request since we have a leader", + "from", req.Addr, + "leader", r.leaderInfo.Get().ID, + "leader-id", string(req.ID)) + return + } + if req.Term < r.getCurrentTerm() { + return + } + if req.Term > r.getCurrentTerm() { + r.logger.Debug("lost leadership because received a requestVote with a newer term") + r.setCurrentTerm(req.Term) + r.setFollower() + } + + candidateId, err := r.kvStore.Get(KeyLastVoteFor) + if err != nil { + r.logger.Errorf("") + return + } + candidateTerm, err := r.kvStore.GetUint64(KeyLastVoteTerm) + if err != nil { + r.logger.Errorf("") + return + } + + // 每个任期只能投一票 + if candidateTerm == req.Term && len(candidateId) != 0 { + return + } + if err := r.kvStore.Set(KeyLastVoteFor, Str2Bytes(string(req.ID))); err != nil { + r.logger.Errorf("") + return + } + if err := r.kvStore.SetUint64(KeyLastVoteTerm, req.Term); err != nil { + r.logger.Errorf("") + return + } + granted = true +} +func (r *Raft) cycle(state State, f func() (stop bool)) { + for r.state.Get() == state && !f() { + } +} + +// setCurrentTerm 更新任期。需要持久化以在重启时恢复 +func (r *Raft) setCurrentTerm(term uint64) { + if err := r.kvStore.Set(KeyCurrentTerm, iToBytes(term)); err != nil { + panic(err) + return + } + atomic.StoreUint64(&r.currentTerm, term) +} + +func (r *Raft) cycleCandidate() { + r.setCurrentTerm(r.getCurrentTerm() + 1) + r.logger.Infof("entering candidate state , term : %d", r.getCurrentTerm()) + var ( + voteCount uint + quorumSize = r.quorumSize() + electionCh = r.launchElection() + ) + defer func() { + r.candidateFromLeaderTransfer = false + }() + r.electionTimeout = randomTimeout(r.Conf().ElectionTimeout) + + doCandidate := func() (stop bool) { + select { + case <-r.shutDown.C: + r.onShutDown() + return true + case rpc := <-r.rpcCh: + r.processRPC(rpc) + case cmd := <-r.commandCh: + r.processCommand(cmd) + case <-r.electionTimeout: + r.logger.Warn("election timeout reached, restarting election") + return true + case result := <-electionCh: + if result.Term > r.getCurrentTerm() { + r.logger.Debug("newer term discovered, fallback to follower", "term", result.Term) + 
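+ // A response with a higher term means a peer has already moved on to a newer
+ // election; adopting that term and stepping back to follower avoids splitting
+ // votes on a stale term. Otherwise granted votes are tallied below and
+ // leadership is won once a strict majority of voters is reached, e.g. with 5
+ // voters quorumSize = 5>>1 + 1 = 3 (the self-vote queued in launchElection
+ // counts toward that tally).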
r.setCurrentTerm(result.Term) + r.setFollower() + return + } + if !result.Granted { + return + } + voteCount++ + r.logger.Debug("vote granted", "from", result.ID, "term", result.Term, "tally", voteCount) + if voteCount >= quorumSize { + r.logger.Info("election won", "term", result.Term, "tally", voteCount, "id", r.localInfo.ID) + r.setLeader() + } + } + return + } + + r.cycle(Candidate, doCandidate) +} + +func (r *Raft) setupLeaderState() { + servers := r.getLatestConfiguration() + index := r.getLatestIndex() + r.leaderState.inflight = list.New() + r.leaderState.startIndex = index + r.leaderState.stepDown = make(chan struct{}) + r.leaderState.replicate = make(map[ServerID]*replication, len(servers)) + r.leaderState.matchIndex = make(map[ServerID]uint64) + for _, info := range r.getLatestConfiguration() { + if info.isVoter() { + r.leaderState.matchIndex[info.ID] = 0 + } + } +} +func (r *Raft) stopReplication() { + for _, replication := range r.leaderState.replicate { + close(replication.stop) + <-replication.done + } +} + +const minHeartBeatInterval = time.Millisecond * 10 + +// recalculate 计算 commit index 必须在 leaderState 的锁里执行 +func (r *Raft) recalculate() uint64 { + list := make([]uint64, len(r.leaderState.matchIndex)) + for _, idx := range r.leaderState.matchIndex { + list = append(list, idx) + } + SortSlice(list) + return list[len(list)>>1] +} +func (l *leaderState) getCommitIndex() uint64 { + return l.commitIndex +} +func (r *Raft) updateMatchIndex(id ServerID, latestIndex uint64) { + r.leaderState.Lock() + defer r.leaderState.Unlock() + // commit index 不允许回退,不允许比新任期提交的第一条日志小 + if idx, ok := r.leaderState.matchIndex[id]; ok && latestIndex > idx { + r.leaderState.matchIndex[id] = latestIndex + r.calcCommitIndex() + } +} +func (r *Raft) calcCommitIndex() { + commitIndex := r.recalculate() + if commitIndex > r.leaderState.getCommitIndex() && commitIndex > r.leaderState.startIndex { + r.leaderState.commitIndex = commitIndex + asyncNotify(r.commitCh) + } +} + +func (r *Raft) onConfigurationUpdate() { + r.leaderState.Lock() + defer r.leaderState.Unlock() + oldMatch := r.leaderState.matchIndex + r.leaderState.matchIndex = make(map[ServerID]uint64) + for _, info := range r.getLatestConfiguration() { + if info.isVoter() { + r.leaderState.matchIndex[info.ID] = oldMatch[info.ID] + } + } + r.calcCommitIndex() +} + +// heartBeat 想跟随者发起心跳,跟随 replicate 关闭 +func (r *Raft) heartBeat(fr *replication) { + var ( + ticker = time.NewTicker(Max(r.Conf().HeartbeatTimeout/10, minHeartBeatInterval)) + ) + defer ticker.Stop() + for { + select { + case <-fr.heartBeatStop: + return + case <-ticker.C: + case <-fr.notifyCh: + } + r.replicateTo(fr, 0) + } +} + +func (r *Raft) buildAppendEntryReq(fr *replication, latestIndex uint64) (*AppendEntryRequest, error) { + var ( + snapshotTerm, snapshotIndex = r.getLatestSnapshot() + entries []*LogEntry + req = &AppendEntryRequest{ + RPCHeader: r.buildRPCHeader(), + Term: r.getCurrentTerm(), + LeaderCommit: r.getCommitIndex(), + } + ) + setupLogEntries := func() (err error) { + if fr.getNextIndex() < latestIndex { + return nil + } + entries, err = r.logStore.GetLogRange(fr.nextIndex, latestIndex) + if err != nil { + return err + } + if uint64(len(entries)) != latestIndex-fr.nextIndex+1 { + return ErrNotFoundLog + } + return nil + } + + setupPrevLog := func() error { + if len(entries) == 0 { + return nil + } + prevIndex := entries[0].Index - 1 + switch { + case prevIndex == 0: // 第一个日志 + case prevIndex == snapshotIndex: // 上一个 index 正好是快照的最后一个日志,避免触发快照安装 + req.PrevLogTerm, 
req.PrevLogIndex = snapshotTerm, snapshotIndex + default: + log, err := r.logStore.GetLog(prevIndex) + if err == nil { + return err + } + req.PrevLogTerm, req.PrevLogIndex = log.Term, log.Index + } + return nil + } + for _, f := range []func() error{ + setupLogEntries, + setupPrevLog, + } { + if err := f(); err != nil { + return nil, err + } + } + return req, nil +} +func (r *Raft) replicateHelper(fr *replication) (stop bool) { + var ( + ticker = time.NewTicker(r.Conf().CommitTimeout) + ) + defer ticker.Stop() + for !stop { + select { + case <-fr.stop: + return true + case <-ticker.C: + stop = r.replicateTo(fr, r.getLatestIndex()) + case fu := <-fr.trigger: + stop = r.replicateTo(fr, r.getLatestIndex()) + if fu == nil { + continue + } + if stop { + fu.success() + } else { + fu.fail(errors.New("replication failed")) + } + + } + } + return +} + +// leaderLease 领导人线程通知自己下台 +func (r *Raft) leaderLease(term uint64) { + r.setFollower() + r.setCurrentTerm(term) + asyncNotify(r.leaderState.stepDown) +} + +// sendLatestSnapshot 发送最新的快照 +func (r *Raft) sendLatestSnapshot(fr *replication) error { + list, err := r.snapshotStore.List() + if err != nil { + return err + } + if len(list) == 0 { + return errors.New("snapshot not exist") + } + latestID := list[0].ID + meta, readCloser, err := r.snapshotStore.Open(latestID) + if err != nil { + return err + } + defer func() { + readCloser.Close() + }() + resp, err := r.rpc.InstallSnapShot(Ptr(fr.peer.Get()), &InstallSnapshotRequest{ + RPCHeader: r.buildRPCHeader(), + SnapshotMeta: meta, + }, readCloser) + if err != nil { + return err + } + if resp.Term > r.getCurrentTerm() { + + r.leaderLease(resp.Term) + return nil + } + fr.setLastContact() + if resp.Success { + + } else { + + } + return nil +} + +// updateLatestCommit 更新最新的提交索引,并且回调 replication 的 verifyFuture 请求 +func (r *Raft) updateLatestCommit(fr *replication, entries []*LogEntry) { + if len(entries) > 0 { + peer := fr.peer.Get() + last := entries[len(entries)-1] + fr.setNextIndex(last.Index + 1) + r.updateMatchIndex(peer.ID, last.Index) + if r.getCommitIndex()-fr.nextIndex < 50 { + r.logger.Infof("peer :%s catch up", peer.ID) + } + } + fr.notifyAll(true) +} + +func (r *Raft) replicateTo(fr *replication, latestIndex uint64) (stop bool) { + hasMore := func() bool { + select { + case <-fr.stop: + return false + default: + return fr.getNextIndex() < latestIndex + } + } + for { + req, err := r.buildAppendEntryReq(fr, latestIndex) + if err != nil { + if errors.Is(ErrNotFoundLog, err) { + _ = r.sendLatestSnapshot(fr) + return + } + } + resp, err := r.rpc.AppendEntries(Ptr(fr.peer.Get()), req) + if err != nil { + r.logger.Errorf("") + return + } + if resp.Term > r.getCurrentTerm() { + r.leaderLease(resp.Term) + return true + } + fr.setLastContact() + if resp.Succ { + r.updateLatestCommit(fr, req.Entries) + } else { + fr.setNextIndex(Max(1, Min(fr.getNextIndex()-1, resp.LatestIndex))) + } + if !hasMore() { + break + } + } + + return +} + +// processPipelineResult 处理 pipeline 的结果 +func (r *Raft) processPipelineResult(fr *replication, pipeline AppendEntryPipeline) { + defer close(fr.pipelineFinish) + for { + select { + case <-fr.done: + return + case <-fr.stop: + return + case fu := <-pipeline.Consumer(): + resp, err := fu.Response() + if err != nil { + r.logger.Errorf("") + continue + } + if resp.Term > r.getCurrentTerm() { + r.leaderLease(resp.Term) + return + } + fr.setLastContact() + if resp.Succ { + r.updateLatestCommit(fr, fu.Request().Entries) + } else { + fr.setNextIndex(Max(1, Min(fr.getNextIndex()-1, 
resp.LatestIndex))) + } + } + + } +} +func (r *Raft) pipelineReplicateTo(fr *replication, pipeline AppendEntryPipeline) (stop bool) { + req, err := r.buildAppendEntryReq(fr, r.getLatestIndex()) + if err != nil { + r.logger.Errorf("") + return true + } + _, err = pipeline.AppendEntries(req) + if err != nil { + r.logger.Errorf("") + return true + } + if n := len(req.Entries); n > 0 { + fr.setNextIndex(req.Entries[n-1].Index + 1) + } + return +} +func (r *Raft) pipelineReplicateHelper(fr *replication) (stop bool) { + var ( + ticker = time.NewTicker(r.Conf().CommitTimeout) + ) + defer ticker.Stop() + pipeline, err := r.rpc.AppendEntryPipeline(Ptr(fr.peer.Get())) + if err != nil { + return + } + r.goFunc(func() { + r.processPipelineResult(fr, pipeline) + }) +END: + for { + select { + case <-fr.stop: + stop = r.pipelineReplicateTo(fr, pipeline) + break END + case <-ticker.C: + case <-fr.trigger: + } + if stop = r.pipelineReplicateTo(fr, pipeline); stop { + break END + } + } + close(fr.done) + select { + case <-fr.pipelineFinish: + case <-r.shutDown.C: + } + return +} + +// replicate 复制到制定的跟随者,先短连接(可以发送快照),后长链接 +func (r *Raft) replicate(fr *replication) { + defer func() { close(fr.heartBeatStop) }() + for stop := r.replicateHelper(fr); !stop; stop = r.replicateHelper(fr) { + stop = r.pipelineReplicateHelper(fr) + } +} + +// reloadReplication 重新加载跟随者的复制、心跳线程 +func (r *Raft) reloadReplication() { + var ( + set = map[ServerID]bool{} + index = r.getLatestIndex() + ) + index++ + // 开启新的跟随者线程 + for _, server := range r.getLatestConfiguration() { + set[server.ID] = true + if server.ID == r.localInfo.ID { + continue + } + if fr := r.leaderState.replicate[server.ID]; fr != nil { + if fr.peer.Get().Addr != server.Addr { + r.logger.Info("updating peer", "peer", server.ID) + } + fr.peer.Set(server) + continue + } + fr := &replication{ + nextIndex: index, + peer: NewLockItem(server), + stop: make(chan bool), + heartBeatStop: make(chan struct{}), + notifyCh: make(chan struct{}), + done: make(chan struct{}), + pipelineFinish: make(chan struct{}), + trigger: make(chan *defaultDeferResponse), + lastContact: NewLockItem[time.Time](), + notify: NewLockItem(map[*verifyFuture]struct{}{}), + } + r.leaderState.replicate[server.ID] = fr + r.goFunc(func() { + r.heartBeat(fr) + }, func() { + r.replicate(fr) + }) + } + // 删除已经不在集群的跟随者线程 + for _, rep := range r.leaderState.replicate { + id := rep.peer.Get().ID + if set[id] { + continue + } + rep.notifyAll(false) + // 删除 + delete(r.leaderState.replicate, id) + rep.stop <- true // 尽力复制 + close(rep.stop) + } +} + +func (r *Raft) clearLeaderState() { + for e := r.leaderState.inflight.Front(); e != nil; e = e.Next() { + e.Value.(*LogFuture).fail(ErrNotLeader) + r.leaderState.inflight.Remove(e) + } + r.leaderState.inflight = nil + r.leaderState.replicate = nil + r.leaderState.startIndex = 0 + r.leaderState.matchIndex = nil + r.leaderState.commitIndex = 0 + r.leaderState.stepDown = nil +} + +// checkLeadership 计算领导权 +func (r *Raft) checkLeadership() (leader bool, maxDiff time.Duration) { + var ( + now = time.Now() + quorumSize = r.quorumSize() + leaseTimeout = r.Conf().LeaderLeaseTimeout + ) + for _, peer := range r.getLatestConfiguration() { + if !peer.isVoter() { + continue + } + if peer.ID == r.localInfo.ID { + quorumSize-- + continue + } + diff := now.Sub(r.leaderState.replicate[peer.ID].lastContact.Get()) + if diff <= leaseTimeout { + quorumSize-- + maxDiff = Max(diff, maxDiff) + } else { + if diff < 3*leaseTimeout { + r.logger.Warn("failed to contact", "server-id", peer.ID, 
"time", diff) + } else { + r.logger.Debug("failed to contact", "server-id", peer.ID, "time", diff) + } + } + + } + return quorumSize <= 0, maxDiff +} + +// broadcastReplicate 通知所有副本强制进行复制 +func (r *Raft) broadcastReplicate() { + for _, repl := range r.leaderState.replicate { + asyncNotify(repl.trigger) + } +} + +// applyLog 想本地提交日志然后通知复制到跟随者 +func (r *Raft) applyLog(future []*LogFuture) { + var ( + index = r.getLatestIndex() + term = r.getCurrentTerm() + logs = make([]*LogEntry, 0, len(future)) + now = time.Now() + ) + for _, fu := range future { + index++ + fu.log.Index = index + fu.log.Term = term + fu.log.CreatedAt = now + log := fu.log + logs = append(logs, log) + } + err := r.logStore.SetLogs(logs) + if err != nil { + r.logger.Errorf("applyLog|SetLogs err :%s", err) + for _, fu := range future { + fu.fail(err) + } + // 如果本地存储出问题了直接放弃领导权 + r.setFollower() + return + } + for _, fu := range future { + r.leaderState.inflight.PushBack(fu) + } + // 更新本地最新 index + r.setLatestLog(term, index) + // 更新 commitIndex + r.updateMatchIndex(r.localInfo.ID, index) + // 通知复制 + r.broadcastReplicate() +} + +// leaderCommit 只在 cycleLeader 中调用,将所有可提交的日志提交到状态机 +func (r *Raft) leaderCommit() (stepDown bool) { + var ( + oldCommitIndex = r.getCommitIndex() + newCommitIndex = r.leaderState.getCommitIndex() + readyFuture = map[uint64]*LogFuture{} + readElem []*list.Element + ) + + r.logger.Infof("leader commit ori :%d,cur:%d", oldCommitIndex, newCommitIndex) + r.setCommitIndex(newCommitIndex) + if r.cluster.latestIndex > oldCommitIndex && r.cluster.latestIndex <= newCommitIndex { + r.setCommitConfiguration(r.cluster.latestIndex, r.cluster.latest) + if !r.canVote(r.localInfo.ID) { // 集群配置提交后,如果当前领导人不在集群内,则下台 + stepDown = true + } + } + + for e := r.leaderState.inflight.Front(); e != nil; e = e.Next() { + if future := e.Value.(*LogFuture); future.Index() <= newCommitIndex { + readyFuture[future.Index()] = future + readElem = append(readElem, e) + } + } + if len(readElem) > 0 { + r.applyLogToFsm(newCommitIndex, readyFuture) + } + for _, element := range readElem { + r.leaderState.inflight.Remove(element) + } + + return false +} + +// apiProcessRestore 从本地应用快照 +func (r *Raft) apiProcessRestore(fu *restoreFuture) { + +} + +// hasExistTerm 判断节点是否干净 +func (r *Raft) hasExistTerm() (exist bool, err error) { + for _, f := range []func() (bool, error){ + func() (bool, error) { + existTerm, err := r.kvStore.GetUint64(KeyCurrentTerm) + if errors.Is(ErrKeyNotFound, err) { + err = nil + } + return existTerm > 0, err + }, + func() (bool, error) { + lastIndex, err := r.logStore.LastIndex() + return lastIndex > 0, err + }, + func() (bool, error) { + snapshots, err := r.snapshotStore.List() + return len(snapshots) > 0, err + }, + } { + if exist, err = f(); exist || err != nil { + if err != nil { + r.logger.Errorf("hasExistTerm err :%s", err) + } + return + } + } + return +} + +func (r *Raft) clacNewConfiguration(req *configurationChangeRequest) (newConfiguration Configuration, err error) { + + switch req.command { + case updateServer: + found := false + for _, server := range r.getLatestConfiguration() { + if server.ID == req.peer.ID { + newConfiguration.Servers = append(newConfiguration.Servers, req.peer) + } else { + newConfiguration.Servers = append(newConfiguration.Servers, server) + } + } + if !found { + return Configuration{}, errors.New("not found") + } + case addServer: + for _, server := range r.getLatestConfiguration() { + if server.ID == req.peer.ID { + return Configuration{}, errors.New("peer duplicate") + } + 
newConfiguration.Servers = append(newConfiguration.Servers, server) + } + newConfiguration.Servers = append(newConfiguration.Servers, req.peer) + case removeServer: + found := false + for _, server := range r.getLatestConfiguration() { + if server.ID == req.peer.ID { + continue + } else { + newConfiguration.Servers = append(newConfiguration.Servers, server) + } + } + if !found { + return Configuration{}, errors.New("not found") + } + + } + + return +} + +func (r *Raft) cycleLeader() { + r.logger.Debug("cycle leader ", r.leaderInfo.Get().ID) + r.setupLeaderState() + r.reloadReplication() + defer func() { + r.logger.Info("leave leader") + r.stopReplication() + r.clearLeaderState() + // TODO 注意下,是否需要调用 setLastContact 避免跟随者错误处理逻辑 + }() + // 提交一个空日志,用于确认 commitIndex + future := &LogFuture{log: &LogEntry{Type: LogNoop}} + future.init() + r.applyLog([]*LogFuture{future}) + leaderLeaseTimeout := time.After(r.Conf().LeaderLeaseTimeout) + + leaderLoop := func() (stop bool) { + select { + case <-r.shutDown.C: + r.onShutDown() + r.logger.Debug("shut down") + return true + case rpc := <-r.rpcCh: + r.processRPC(rpc) + case cmd := <-r.commandCh: + r.processCommand(cmd) + case <-r.commitCh: + if stepDown := r.leaderCommit(); stepDown { + r.setShutDown() + } + case <-leaderLeaseTimeout: + if leader, maxDiff := r.checkLeadership(); leader { + leaderLeaseTimeout = time.After(Max(r.Conf().LeaderLeaseTimeout-maxDiff, minCheckInterval)) + } else { + r.logger.Infof("leader ship check fail, term :%d ,step down", r.getCurrentTerm()) + r.setFollower() + } + case <-r.leaderState.stepDown: + r.logger.Infof("leader ship step down, term :%d", r.getCurrentTerm()) + r.setFollower() + } + return + } + r.cycle(Leader, leaderLoop) +} + +func (r *Raft) setLatestLog(term, index uint64) { + r.lastEntry.Action(func(t *lastEntry) { + t.log = lastLog{index: index, term: term} + }) +} +func (r *Raft) getLatestLog() (term, index uint64) { + entry := r.lastEntry.Get() + return entry.log.term, entry.log.index +} +func (r *Raft) getLatestSnapshot() (term, index uint64) { + entry := r.lastEntry.Get() + return entry.snapshot.term, entry.snapshot.index +} +func (r *Raft) setLatestSnapshot(term, index uint64) { + r.lastEntry.Action(func(t *lastEntry) { + t.snapshot = lastLog{index: index, term: term} + }) +} +func (r *Raft) getLatestConfiguration() []ServerInfo { + return r.cluster.latest.Servers +} +func (r *Raft) inConfiguration(id ServerID) bool { + for _, server := range r.cluster.latest.Servers { + if server.ID == id { + return true + } + } + return false +} + +func (r *Raft) canVote(id ServerID) bool { + for _, serverInfo := range r.getLatestConfiguration() { + if serverInfo.ID == id { + return serverInfo.isVoter() + } + } + return false +} + +func (r *Raft) quorumSize() (c uint) { + for _, server := range r.getLatestConfiguration() { + if server.Suffrage == Voter { + c++ + } + } + return c>>1 + 1 +} + +func (r *Raft) launchElection() (result chan *voteResult) { + var ( + list = r.getLatestConfiguration() + lastTerm, lastIndex = r.getLatestEntry() + header = r.buildRPCHeader() + currentTerm = r.getCurrentTerm() + ) + result = make(chan *voteResult, len(list)) + for _, info := range list { + if info.ID == r.localInfo.ID { + if err := r.kvStore.Set(KeyLastVoteFor, Str2Bytes(string(info.ID))); err != nil { + r.logger.Errorf("launchElection vote for self errorx :%s", err) + continue + } + result <- &voteResult{ + VoteResponse: &VoteResponse{ + RPCHeader: header, + Term: currentTerm, + Granted: true, + }, + ServerID: info.ID, + } + } 
else { + info := info + r.goFunc(func() { + resp, err := r.rpc.VoteRequest(&info, &VoteRequest{ + RPCHeader: header, + Term: currentTerm, + CandidateID: info.ID, + LastLogIndex: lastIndex, + LastLogTerm: lastTerm, + LeaderTransfer: r.candidateFromLeaderTransfer, + }) + if err != nil { + return + } + result <- &voteResult{ + VoteResponse: resp, + ServerID: info.ID, + } + }) + + } + } + return +} diff --git a/raft_test.go b/raft_test.go new file mode 100644 index 0000000..46547d7 --- /dev/null +++ b/raft_test.go @@ -0,0 +1,242 @@ +package papillon + +import ( + "encoding/json" + "fmt" + "github.com/spf13/cast" + "io" + "net/http" + "testing" + "time" +) + +func TestRaft(t *testing.T) { + raft, rpc := buildRaft("1", nil, nil) + _ = rpc + _ = raft + go func() { + time.Sleep(time.Second) + raft.BootstrapCluster(Configuration{Servers: []ServerInfo{{Voter, "1", "1"}}}) + }() + http.Handle("/get", &getHandle{raftList: []*Raft{raft}}) + http.Handle("/set", &setHandle{raftList: []*Raft{raft}}) + http.Handle("/verify", &verifyHandle{raft}) + http.Handle("/config", &configGetHandle{raft}) + http.Handle("/get_log", &getLogHandle{raftList: []*Raft{raft}}) + http.Handle("/leader_transfer", &leaderTransferHandle{raftList: []*Raft{raft}}) + http.Handle("/snapshot", &snapshotHandle{raftList: []*Raft{raft}}) + http.Handle("/restore", &userRestoreSnapshotHandle{raftList: []*Raft{raft}}) + http.Handle("/add_peer", &addPeerHandle{raftList: []*Raft{raft}}) + http.ListenAndServe("localhost:8080", nil) +} + +func buildRaft(localID string, rpc RpcInterface, store interface { + LogStore + KVStorage +}) (*Raft, *memFSM) { + + conf := &Config{ + LocalID: localID, + //HeartBeatCycle: time.Second * 2, + //MemberList: nil, + HeartbeatTimeout: time.Second * 2, + SnapshotInterval: time.Second * 15, + ElectionTimeout: time.Second * 5, + CommitTimeout: time.Second, + LeaderLeaseTimeout: time.Second * 1, + MaxAppendEntries: 10, + SnapshotThreshold: 100, + TrailingLogs: 1000, + ApplyBatch: 1, + LeadershipCatchUpRounds: 500, + //ShutdownOnRemove: false, + } + if store == nil { + store = newMemoryStore() + } + if rpc == nil { + rpc = newMemRpc(localID) + } + fsm := newMemFSM() + fileSnapshot, err := NewFileSnapshot("./testsnapshot/snapshot"+localID, false, 3) + if err != nil { + panic(err) + } + raft, err := NewRaft(conf, fsm, rpc, store, store, fileSnapshot) + if err != nil { + panic(err) + } + return raft, fsm +} + +func getLeader(rafts ...*Raft) *Raft { + for _, raft := range rafts { + _, err := raft.VerifyLeader().Response() + if err != nil { + continue + } + return raft + } + return nil +} + +type ( + getHandle struct { + raftList []*Raft + } + setHandle struct { + raftList []*Raft + } + verifyHandle struct { + *Raft + } + configGetHandle struct { + *Raft + } + getLogHandle struct { + raftList []*Raft + } + leaderTransferHandle struct { + raftList []*Raft + } + snapshotHandle struct { + raftList []*Raft + } + userRestoreSnapshotHandle struct { + raftList []*Raft + } + addPeerHandle struct { + raftList []*Raft + } +) + +func (g *leaderTransferHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + raft := getLeader(g.raftList...) + fu := raft.LeaderTransfer("", "") + _, err := fu.Response() + if err == nil { + writer.Write([]byte("succ")) + } else { + + writer.Write([]byte(err.Error())) + } +} +func (g *getHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + key := request.URL.Query().Get("key") + raft := getLeader(g.raftList...) 
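+ // getLeader returns nil when VerifyLeader fails on every node in raftList;
+ // the single-node test always has a leader, but a defensive guard
+ // (illustrative sketch only, assuming the handler should degrade gracefully
+ // during elections) avoids a nil dereference:
+ if raft == nil {
+ writer.WriteHeader(http.StatusServiceUnavailable)
+ writer.Write([]byte("no leader"))
+ return
+ }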
+ fmt.Println("getleader", raft.localInfo.ID) + fsm := raft.fsm.(*memFSM) + val := fsm.getVal(key) + if len(val) > 0 { + writer.Write([]byte(val)) + } else { + writer.Write([]byte("not found")) + } + writer.WriteHeader(200) +} +func (s *setHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + key := request.URL.Query().Get("key") + value := request.URL.Query().Get("value") + raft := getLeader(s.raftList...) + fu := raft.Apply(kvSchema{}.encode(key, value), time.Second) + _, err := fu.Response() + if err != nil { + writer.Write([]byte("fail" + err.Error())) + } else { + writer.Write([]byte("succ")) + } + +} +func (s *verifyHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + fu := s.Raft.VerifyLeader() + _, err := fu.Response() + + if err != nil { + writer.Write([]byte("fail" + err.Error())) + } else { + writer.Write([]byte("succ")) + } + +} +func (s *configGetHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + + cluster := s.Raft.GetConfiguration() + b, _ := json.Marshal(cluster) + writer.Write([]byte(b)) + +} +func (s *getLogHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + from := cast.ToUint64(request.URL.Query().Get("from")) + to := cast.ToUint64(request.URL.Query().Get("to")) + idx := cast.ToInt(request.URL.Query().Get("idx")) + if idx < 0 || idx > len(s.raftList)-1 { + writer.Write([]byte("params error")) + return + } + raft := s.raftList[idx] + logs, err := raft.logStore.GetLogRange(from, to) + if err != nil { + writer.Write([]byte("fail" + err.Error())) + } else { + b, _ := json.MarshalIndent(logs, "", " ") + writer.Write([]byte(b)) + } + +} +func (s *snapshotHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + raft := getLeader(s.raftList...) + fu := raft.SnapShot() + open, err := fu.Response() + if err != nil { + writer.Write([]byte(err.Error())) + return + } + meta, reader, err := open() + if err != nil { + writer.Write([]byte(err.Error())) + return + } + defer reader.Close() + io.Copy(writer, reader) + b, _ := json.MarshalIndent(&meta, "", " ") + writer.Write(b) +} +func (s *addPeerHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + //addr := request.URL.Query().Get("addr") + //id := request.URL.Query().Get("id") + // + //leader := getLeader(s.raftList...) + //fu := leader.AddVoter(ServerID(id), ServerAddr(addr), 0, time.Second) + //_, err := fu.Response() + //if err != nil { + // writer.Write([]byte(err.Error())) + //} else { + // writer.Write([]byte("succ")) + //} +} +func (s *userRestoreSnapshotHandle) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + //raft := getLeader(s.raftList...) 
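+ // The restore flow below is left disabled: when enabled it would pick the
+ // current leader, list its snapshots, open the newest one and feed it to
+ // raft.RestoreSnapshot, writing "succ" or the error back to the HTTP caller.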
+	//snapshot := raft.snapShotStore
+	//list, err := snapshot.List()
+	//if err != nil {
+	//	writer.Write([]byte(err.Error()))
+	//	return
+	//}
+	//if len(list) == 0 {
+	//	writer.Write([]byte("have no snapshot"))
+	//	return
+	//}
+	//meta, rc, err := snapshot.Open(list[0].ID)
+	//if err != nil {
+	//	writer.Write([]byte("open err: " + err.Error()))
+	//	return
+	//}
+	//defer rc.Close()
+	//
+	//err = raft.RestoreSnapshot(meta, rc)
+	//if err != nil {
+	//	writer.Write([]byte(err.Error()))
+	//} else {
+	//	writer.Write([]byte("succ"))
+	//}
+}
diff --git a/rpc.go b/rpc.go
new file mode 100644
index 0000000..9ae3d53
--- /dev/null
+++ b/rpc.go
@@ -0,0 +1,107 @@
+package papillon
+
+import (
+	"io"
+)
+
+type (
+
+	// AppendEntryRequest is the log replication (append entries) request
+	AppendEntryRequest struct {
+		*RPCHeader
+		Term         uint64
+		PrevLogIndex uint64
+		PrevLogTerm  uint64
+		Entries      []*LogEntry
+		LeaderCommit uint64
+	}
+	AppendEntryResponse struct {
+		*RPCHeader
+		Term        uint64
+		Succ        bool
+		LatestIndex uint64 // latest log index currently stored on the peer, used by new nodes to quickly locate nextIndex
+	}
+	// VoteRequest is the election vote request
+	VoteRequest struct {
+		*RPCHeader
+		Term           uint64
+		CandidateID    ServerID
+		LastLogIndex   uint64
+		LastLogTerm    uint64
+		LeaderTransfer bool
+	}
+	VoteResponse struct {
+		*RPCHeader
+		Term    uint64
+		Granted bool
+	}
+	// InstallSnapshotRequest asks a peer to install a snapshot
+	InstallSnapshotRequest struct {
+		*RPCHeader
+		SnapshotMeta *SnapShotMeta
+		Term         uint64
+	}
+	InstallSnapshotResponse struct {
+		*RPCHeader
+		Term    uint64
+		Success bool
+	}
+	// FastTimeoutRequest asks the leader to time out immediately (used for leadership transfer)
+	FastTimeoutRequest struct {
+		*RPCHeader
+		Term               uint64
+		LeaderShipTransfer bool
+	}
+	FastTimeoutResponse struct {
+		*RPCHeader
+		Success bool
+	}
+	// voteResult carries a peer's vote response together with its ID
+	voteResult struct {
+		*VoteResponse
+		ServerID ServerID
+	}
+
+	RPCHeader struct {
+		ID   ServerID
+		Addr ServerAddr
+	}
+)
+
+type (
+	// rpcType identifies the RPC command type
+	rpcType uint64
+	// RPC wraps a single RPC request
+	RPC struct {
+		CmdType  rpcType
+		Request  any
+		Response chan any
+		Reader   io.Reader // connection read interface, used when installing a snapshot
+	}
+	RpcInterface interface {
+		// Consumer returns a channel of incoming RPCs to consume
+		Consumer() <-chan *RPC
+		// VoteRequest sends a vote request to a peer
+		VoteRequest(*ServerInfo, *VoteRequest) (*VoteResponse, error)
+		// AppendEntries replicates log entries to a peer
+		AppendEntries(*ServerInfo, *AppendEntryRequest) (*AppendEntryResponse, error)
+		// AppendEntryPipeline replicates log entries in pipelined form
+		AppendEntryPipeline(*ServerInfo) (AppendEntryPipeline, error)
+		// InstallSnapShot installs a snapshot on a peer
+		InstallSnapShot(*ServerInfo, *InstallSnapshotRequest, io.Reader) (*InstallSnapshotResponse, error)
+		// SetHeartbeatFastPath registers a fast path for heartbeats that bypasses the main loop; optional to support
+		SetHeartbeatFastPath(cb fastPath)
+		// FastTimeout asks a peer to time out immediately and become a candidate
+		FastTimeout(*ServerInfo, *FastTimeoutRequest) (*FastTimeoutResponse, error)
+
+		LocalAddr() ServerAddr
+		EncodeAddr(info *ServerInfo) []byte
+		DecodeAddr([]byte) ServerAddr
+	}
+
+	AppendEntryPipeline interface {
+		AppendEntries(*AppendEntryRequest) (AppendEntriesFuture, error)
+		Consumer() <-chan AppendEntriesFuture
+		Close() error
+	}
+)
diff --git a/snapshot.go b/snapshot.go
new file mode 100644
index 0000000..3b8fc5c
--- /dev/null
+++ b/snapshot.go
@@ -0,0 +1,67 @@
+package papillon
+
+import (
+	"fmt"
+	"io"
+	"time"
+)
+
+type (
+	SnapshotStore interface {
+		Open(id string) (*SnapShotMeta, io.ReadCloser, error)
+		List() ([]*SnapShotMeta, error)
+		Create(version SnapShotVersion, index, term uint64, configuration Configuration, configurationIndex uint64, rpc RpcInterface) (SnapshotSink, error)
+	}
+	SnapshotSink interface {
+		io.WriteCloser
+		ID() string
+		Cancel() error
+	}
+	// SnapShotVersion identifies the snapshot format version; it will be used when the snapshot layout changes in the future
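+	// Only SnapShotVersionDefault is currently defined (see the const block below).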
+	SnapShotVersion uint64
+	SnapShotMeta struct {
+		Version            SnapShotVersion
+		ID                 string
+		Index              uint64
+		Term               uint64
+		Configuration      Configuration
+		ConfigurationIndex uint64
+		Size               int64
+	}
+)
+
+const (
+	SnapShotVersionDefault SnapShotVersion = iota + 1
+)
+
+func snapshotName(term, index uint64) string {
+	now := time.Now()
+	msec := now.UnixNano() / int64(time.Millisecond)
+	return fmt.Sprintf("%d-%d-%d", term, index, msec)
+}
+
+// runSnapshot is the snapshot worker goroutine
+func (r *Raft) runSnapshot() {
+	ticker := time.NewTicker(r.conf.Load().SnapshotInterval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-r.shutDown.C:
+			return
+		case fu := <-r.apiSnapshotBuildCh:
+			id, err := r.buildSnapshot()
+			fn := func() (*SnapShotMeta, io.ReadCloser, error) {
+				return r.snapshotStore.Open(id)
+			}
+			if err != nil {
+				fn = nil
+			}
+			fu.responded(fn, err)
+		case <-ticker.C:
+			if !r.shouldBuildSnapshot() {
+				continue
+			}
+			r.buildSnapshot()
+		}
+	}
+}
diff --git a/state.go b/state.go
new file mode 100644
index 0000000..94ee745
--- /dev/null
+++ b/state.go
@@ -0,0 +1,43 @@
+package papillon
+
+import "sync/atomic"
+
+type (
+	State uint64
+)
+
+func (s *State) String() string {
+	switch *s {
+	case Follower:
+		return "Follower"
+	case Candidate:
+		return "Candidate"
+	case Leader:
+		return "Leader"
+	case ShutDown:
+		return "ShutDown"
+	default:
+		return "Unknown"
+	}
+}
+
+const (
+	Follower State = iota
+	Candidate
+	Leader
+	ShutDown
+)
+
+func newState() *State {
+	state := new(State)
+	state.set(Follower)
+	return state
+}
+
+func (s *State) set(newState State) {
+	atomic.StoreUint64((*uint64)(s), uint64(newState))
+}
+
+func (s *State) Get() State {
+	return State(atomic.LoadUint64((*uint64)(s)))
+}
diff --git a/store.go b/store.go
new file mode 100644
index 0000000..78903c4
--- /dev/null
+++ b/store.go
@@ -0,0 +1,41 @@
+package papillon
+
+var (
+	KeyCurrentTerm  = []byte("CurrentTerm")
+	KeyLastVoteFor  = []byte("LastVoteFor")
+	KeyLastVoteTerm = []byte("LastVoteTerm")
+)
+
+// LogStore abstracts the log storage operations
+type LogStore interface {
+	// FirstIndex returns the first written index, -1 if there is none
+	FirstIndex() (uint64, error)
+	// LastIndex returns the last written index, -1 if there is none
+	LastIndex() (uint64, error)
+	// GetLog returns the log entry at the given index
+	GetLog(index uint64) (log *LogEntry, err error)
+	// GetLogRange returns the log entries in the given range, inclusive on both ends
+	GetLogRange(from, to uint64) (log []*LogEntry, err error)
+	// SetLogs appends log entries
+	SetLogs(logs []*LogEntry) error
+	// DeleteRange deletes all entries in the given range, used when building snapshots
+	DeleteRange(from, to uint64) error
+}
+
+// KVStorage abstracts stable key/value storage
+type KVStorage interface {
+	// Get returns the value stored under key
+	Get(key []byte) (val []byte, err error)
+	// Set stores a value under key
+	Set(key, val []byte) error
+
+	// SetUint64 stores a uint64 value, e.g. the current term
+	SetUint64(key []byte, val uint64) error
+	// GetUint64 returns a uint64 value, e.g. the current term
+	GetUint64(key []byte) (uint64, error)
+}
+
+type ConfigurationStorage interface {
+	KVStorage
+	SetConfiguration(index uint64, configuration Configuration) error
+}
diff --git a/t_test.go b/t_test.go
new file mode 100644
index 0000000..33e3188
--- /dev/null
+++ b/t_test.go
@@ -0,0 +1 @@
+package papillon
diff --git a/tcp_transport.go b/tcp_transport.go
new file mode 100644
index 0000000..c310313
--- /dev/null
+++ b/tcp_transport.go
@@ -0,0 +1,67 @@
+package papillon
+
+import (
+	"errors"
+	"github.com/fuyao-w/log"
+	"net"
+	"time"
+)
+
+type TcpLayer struct {
+	listener  net.Listener
+	advertise net.Addr
+}
+
+func NewTCPTransport(bindAddr string, maxPool int, timeout time.Duration) (*NetTransport, error) {
+	return newTcpTransport(bindAddr, func(layer NetLayer) *NetTransport {
+		return NewNetTransport(&NetWorkTransportConfig{
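+			// ServerAddressProvider is left nil here; judging by the field name it would
+			// let the transport resolve a ServerID to a dial address at connection time,
+			// so the plain TCP transport does not need one.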
+			ServerAddressProvider: nil,
+			Logger:                log.NewLogger(),
+			NetLayer:              layer,
+			MaxPool:               maxPool,
+			Timeout:               timeout,
+		})
+	})
+}
+func newTcpTransport(bindAddr string, transportCreator func(layer NetLayer) *NetTransport) (*NetTransport, error) {
+	listener, err := net.Listen("tcp", bindAddr)
+	if err != nil {
+		return nil, err
+	}
+	layer := NewTcpLayer(listener, nil)
+
+	addr, ok := layer.Addr().(*net.TCPAddr)
+	if !ok {
+		listener.Close()
+		return nil, errors.New("local address is not a TCP address")
+	}
+	if addr.IP == nil || addr.IP.IsUnspecified() {
+		listener.Close()
+		return nil, errors.New("local address is not advertisable")
+	}
+	return transportCreator(layer), nil
+}
+func NewTcpLayer(l net.Listener, advertise net.Addr) NetLayer {
+	return &TcpLayer{
+		listener:  l,
+		advertise: advertise,
+	}
+}
+func (t *TcpLayer) Accept() (net.Conn, error) {
+	return t.listener.Accept()
+}
+
+func (t *TcpLayer) Close() error {
+	return t.listener.Close()
+}
+
+func (t *TcpLayer) Addr() net.Addr {
+	if t.advertise != nil {
+		return t.advertise
+	}
+	return t.listener.Addr()
+}
+
+func (t *TcpLayer) Dial(peer ServerAddr, timeout time.Duration) (net.Conn, error) {
+	return net.DialTimeout("tcp", string(peer), timeout)
+}
diff --git a/transport.go b/transport.go
new file mode 100644
index 0000000..01062b8
--- /dev/null
+++ b/transport.go
@@ -0,0 +1,77 @@
+package papillon
+
+import (
+	"bufio"
+	"encoding/json"
+	"net"
+	"time"
+)
+
+func (c rpcType) String() string {
+	switch c {
+	case CmdVoteRequest:
+		return "VoteRequest"
+	case CmdAppendEntryPipeline:
+		return "AppendEntryPipeline"
+	case CmdAppendEntry:
+		return "AppendEntry"
+	case CmdInstallSnapshot:
+		return "InstallSnapshot"
+	case CmdFastTimeout:
+		return "FastTimeout"
+	default:
+		return "UNKNOWN"
+	}
+}
+
+const (
+	CmdVoteRequest rpcType = iota + 1
+	CmdAppendEntry
+	CmdAppendEntryPipeline
+	CmdInstallSnapshot
+	CmdFastTimeout
+)
+const (
+	rpcMaxPipeline = 128
+	// DefaultTimeoutScale is the default TimeoutScale in a NetworkTransport.
+	DefaultTimeoutScale = 256 * 1024 // 256KB
+)
+
+// NetLayer abstracts the network layer
+type NetLayer interface {
+	net.Listener
+	// Dial is used to create a new outgoing connection
+	Dial(peer ServerAddr, timeout time.Duration) (net.Conn, error)
+}
+
+type (
+	WithPeers interface {
+		Connect(addr ServerAddr, rpc RpcInterface)
+		Disconnect(addr ServerAddr)
+		DisconnectAll()
+	}
+	fastPath func(cb *RPC) bool
+
+	PackageParser interface {
+		Encode(writer *bufio.Writer, cmdType rpcType, data []byte) (err error)
+		Decode(reader *bufio.Reader) (rpcType, []byte, error)
+	}
+
+	CmdConvert interface {
+		Deserialization(data []byte, i interface{}) error
+		Serialization(i interface{}) (bytes []byte, err error)
+	}
+
+	// JsonCmdHandler provides JSON serialization and deserialization
+	JsonCmdHandler struct{}
+)
+
+var defaultCmdConverter = new(JsonCmdHandler)
+
+func (j *JsonCmdHandler) Deserialization(data []byte, i interface{}) error {
+	return json.Unmarshal(data, i)
+}
+
+func (j JsonCmdHandler) Serialization(i interface{}) (bytes []byte, err error) {
+	return json.Marshal(i)
+}
diff --git a/util.go b/util.go
new file mode 100644
index 0000000..918afd8
--- /dev/null
+++ b/util.go
@@ -0,0 +1,198 @@
+package papillon
+
+import (
"github.com/fuyao-w/common-util" + "github.com/spf13/cast" + "io" + "sync/atomic" + + crand "crypto/rand" + "fmt" + "math" + "math/big" + "math/rand" + "os" + "os/signal" + + "syscall" + "time" +) + +var rnd = rand.New(rand.NewSource(newSeed())) + +func newSeed() int64 { + r, err := crand.Int(crand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + panic(fmt.Errorf("failed to read random bytes :%s", err)) + } + return r.Int64() +} + +// shutDown 处理关机逻辑,并提供回调信息 +type shutDown struct { + dataBus DataBus + state *LockItem[bool] + C chan struct{} +} + +func newShutDown() shutDown { + return shutDown{ + dataBus: DataBus{}, + state: NewLockItem[bool](), + C: make(chan struct{}), + } +} + +func (s *shutDown) done(act func(oldState bool)) { + s.state.Action(func(t *bool) { + old := *t + *t = true + if act != nil { + act(old) + } + close(s.C) + s.dataBus.Publish(0, nil) + }) +} + +// WaitForShutDown 阻塞直到关机 +func (s *shutDown) WaitForShutDown() { + notify := make(chan os.Signal, 1) + // kill 默认会发送 syscall.SIGTERM 信号 + // kill -2 发送 syscall.SIGINT 信号,我们常用的Ctrl+C就是触发系统SIGINT信号 + // kill -9 发送 syscall.SIGKILL 信号,但是不能被捕获,所以不需要添加它 + // signal.Notify把收到的 syscall.SIGINT或syscall.SIGTERM 信号转发给quit + signal.Notify(notify, syscall.SIGINT, syscall.SIGTERM) // 此处不会阻塞 + <-notify + s.done(nil) +} + +func (s *shutDown) AddCallback(obs observer) { + s.dataBus.AddObserver(obs) +} + +type ( + // observer 回调函数类型 + observer func(event int, param interface{}) + // DataBus 提供发布、订阅功能 + DataBus struct { + observers []observer + } +) + +// AddObserver 添加订阅者 +func (d *DataBus) AddObserver(obs observer) { + d.observers = append(d.observers, obs) +} + +// Publish 触发事件 +func (d *DataBus) Publish(event int, param interface{}) { + for _, obs := range d.observers { + obs(event, param) + } +} + +// randomTimeout 返回 t 到 2 x t 时间的随机时间 +func randomTimeout(t time.Duration) <-chan time.Time { + if t == 0 { + return nil + } + return time.After(t + time.Duration(rnd.Int63())%t) +} + +func generateUUID() string { + var buf = make([]byte, 1<<4) + if _, err := crand.Read(buf); err != nil { + panic(fmt.Errorf("failed to read random bytes :%s", err)) + } + return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", + buf[:4], + buf[4:6], + buf[6:8], + buf[8:10], + buf[10:], + ) +} + +// asyncNotify 不阻塞的给 chan 发送一个信号,并返回是否发送成功 +func asyncNotify[T any](ch chan T) bool { + select { + case ch <- Zero[T](): + return true + default: + return false + } +} + +// overrideNotify 通知一个 T 类型的 channel, 如果 channel 中已经有值了,则会覆盖 +// channel 长度必须为 1 ,如果并发访问会 panic +func overrideNotify[T any](ch chan T, v T) { + for i := 0; i < 2; i++ { + select { + case ch <- v: + // 发送成功 + return + case <-ch: + // 上次投递的没人消费 + } + } + // 如果循环两次说明有其他线程在并发投递 + panic("race:channel was send concurrently") +} + +type AtomicVal[T any] struct { + v atomic.Value +} + +func NewAtomicVal[T any](val ...T) *AtomicVal[T] { + v := &AtomicVal[T]{} + if len(val) > 0 { + v.Store(val[0]) + } + return v +} +func (a *AtomicVal[T]) Load() T { + val, ok := a.v.Load().(T) + if ok { + return val + } + return Zero[T]() +} +func (a *AtomicVal[T]) Store(t T) { + a.v.Store(t) +} + +type Logger interface { + Infof(format string, v ...any) + Info(v ...any) + Errorf(format string, v ...any) + Error(v ...any) + Warnf(format string, v ...any) + Warn(v ...any) + Debugf(format string, v ...any) + Debug(v ...any) +} + +func newCounterReader(r io.Reader) *countingReader { + return &countingReader{reader: r} +} + +// countingReader 支持随时查询读取长度 +type countingReader struct { + reader io.Reader + count int64 +} + +func (r 
+func (r *countingReader) Read(p []byte) (n int, err error) {
+	n, err = r.reader.Read(p)
+	atomic.AddInt64(&r.count, int64(n))
+	return
+}
+
+func (r *countingReader) Count() int64 {
+	return atomic.LoadInt64(&r.count)
+}
+
+func iToBytes(i uint64) []byte {
+	return Str2Bytes(cast.ToString(i))
+}
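For reference, here is a minimal end-to-end sketch of how the pieces in this patch fit together, in the spirit of TestRaft above. It assumes the in-package helpers newMemoryStore/newMemFSM and the NewTCPTransport/NewFileSnapshot/NewRaft constructors from the diff; it is an illustration only, not part of the patch, and error handling is elided.

	func ExampleNewRaft() {
		conf := &Config{
			LocalID:            "1",
			HeartbeatTimeout:   time.Second * 2,
			ElectionTimeout:    time.Second * 5,
			CommitTimeout:      time.Second,
			LeaderLeaseTimeout: time.Second,
			MaxAppendEntries:   10,
			SnapshotInterval:   time.Second * 15,
			SnapshotThreshold:  100,
			TrailingLogs:       1000,
			ApplyBatch:         1,
		}
		trans, _ := NewTCPTransport("127.0.0.1:0", 3, time.Second)      // TCP transport from tcp_transport.go
		snaps, _ := NewFileSnapshot("./testsnapshot/example", false, 3) // file-based snapshot store
		store := newMemoryStore()                                       // in-memory LogStore + KVStorage
		r, _ := NewRaft(conf, newMemFSM(), trans, store, store, snaps)
		// Single-node bootstrap, mirroring TestRaft.
		r.BootstrapCluster(Configuration{Servers: []ServerInfo{{Voter, "1", "1"}}})
		// Once this node has won the election, commands can be replicated.
		fu := r.Apply([]byte("hello"), time.Second)
		_, _ = fu.Response()
	}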