From 6fd3902ad19b256d2181b7a2274ac151ff57050e Mon Sep 17 00:00:00 2001 From: Michael Boeckli Date: Tue, 19 Mar 2024 14:21:51 +0100 Subject: [PATCH] Graceful shutdown of read handler go routine --- Peer.go | 258 ++++++++++++++++++++++++-------------------- PeerManager_test.go | 4 + Peer_Mock.go | 2 + interface.go | 1 + 4 files changed, 146 insertions(+), 119 deletions(-) diff --git a/Peer.go b/Peer.go index b15a4e3..9be8b2b 100644 --- a/Peer.go +++ b/Peer.go @@ -71,6 +71,7 @@ type Peer struct { dataBatcher *batcher.Batcher[chainhash.Hash] maximumMessageSize int64 isHealthy bool + quitReadHandler chan struct{} } // NewPeer returns a new bitcoin peer for the provided address and configuration. @@ -192,7 +193,7 @@ func (p *Peer) connect() error { p.readConn = conn } - go p.readHandler() + p.startReadHandler() // write version message to our peer directly and not through the write channel, // write channel is not ready to send message until the VERACK handshake is done @@ -278,150 +279,163 @@ func (p *Peer) readRetry(r io.Reader, pver uint32, bsvnet wire.BitcoinNet) (wire return msg, nil } -func (p *Peer) readHandler() { - readConn := p.readConn +func (p *Peer) startReadHandler() { + p.quitReadHandler = make(chan struct{}) - if readConn == nil { - p.logger.Error("no connection") - return - } + go func() { - reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize}) - for { - msg, err := p.readRetry(reader, wire.ProtocolVersion, p.network) - if err != nil { - p.logger.Error("Retrying to read failed", slog.String(errKey, err.Error())) + readConn := p.readConn - // by disconnecting ensure that peer will try to reconnect - p.disconnect() + if readConn == nil { + p.logger.Error("no connection") return } - commandLogger := p.logger.With(slog.String(commandKey, strings.ToUpper(msg.Command()))) + reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize}) + for { + select { + case <-p.quitReadHandler: + return + default: + msg, err := p.readRetry(reader, wire.ProtocolVersion, p.network) + if err != nil { + p.logger.Error("Retrying to read failed", slog.String(errKey, err.Error())) - // we could check this based on type (switch msg.(type)) but that would not allow - // us to override the default behaviour for a specific message type - switch msg.Command() { - case wire.CmdVersion: - commandLogger.Debug(receivedMsg) - if p.sentVerAck.Load() { - commandLogger.Warn("Received version message after sending verack") - continue - } + p.disconnect() - verackMsg := wire.NewMsgVerAck() - if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil { - commandLogger.Error("failed to write message", slog.String(errKey, err.Error())) - } - commandLogger.Debug(sentMsg, slog.String(commandKey, strings.ToUpper(verackMsg.Command()))) - p.sentVerAck.Store(true) + p.mu.Lock() + p.quitReadHandler = nil + p.mu.Unlock() - case wire.CmdPing: - commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPing))) - p.pingPongAlive <- struct{}{} + return + } - pingMsg, ok := msg.(*wire.MsgPing) - if !ok { - continue - } - p.writeChan <- wire.NewMsgPong(pingMsg.Nonce) + commandLogger := p.logger.With(slog.String(commandKey, strings.ToUpper(msg.Command()))) - case wire.CmdInv: - invMsg, ok := msg.(*wire.MsgInv) - if !ok { - continue - } - for _, inv := range invMsg.InvList { - commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String())) - } + // we could check this based on type (switch msg.(type)) but that would not allow + // us to override the default behaviour for a specific message type + switch msg.Command() { + case wire.CmdVersion: + commandLogger.Debug(receivedMsg) + if p.sentVerAck.Load() { + commandLogger.Warn("Received version message after sending verack") + continue + } - go func(invList []*wire.InvVect, routineLogger *slog.Logger) { - for _, invVect := range invList { - switch invVect.Type { - case wire.InvTypeTx: - if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil { - commandLogger.Error("Unable to process tx", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error())) - } - case wire.InvTypeBlock: - if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil { - commandLogger.Error("Unable to process block", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error())) - } + verackMsg := wire.NewMsgVerAck() + if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil { + commandLogger.Error("failed to write message", slog.String(errKey, err.Error())) } - } - }(invMsg.InvList, commandLogger) + commandLogger.Debug(sentMsg, slog.String(commandKey, strings.ToUpper(verackMsg.Command()))) + p.sentVerAck.Store(true) - case wire.CmdGetData: - dataMsg, ok := msg.(*wire.MsgGetData) - if !ok { - continue - } - for _, inv := range dataMsg.InvList { - commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String())) - } - p.handleGetDataMsg(dataMsg, commandLogger) + case wire.CmdPing: + commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPing))) + p.pingPongAlive <- struct{}{} - case wire.CmdTx: - txMsg, ok := msg.(*wire.MsgTx) - if !ok { - continue - } - commandLogger.Debug(receivedMsg, slog.String(hashKey, txMsg.TxHash().String()), slog.Int("size", txMsg.SerializeSize())) - if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil { - commandLogger.Error("Unable to process tx", slog.String(hashKey, txMsg.TxHash().String()), slog.String(errKey, err.Error())) - } + pingMsg, ok := msg.(*wire.MsgPing) + if !ok { + continue + } + p.writeChan <- wire.NewMsgPong(pingMsg.Nonce) - case wire.CmdBlock: - msgBlock, ok := msg.(*wire.MsgBlock) - if ok { - commandLogger.Info(receivedMsg, slog.String(hashKey, msgBlock.Header.BlockHash().String())) + case wire.CmdInv: + invMsg, ok := msg.(*wire.MsgInv) + if !ok { + continue + } + for _, inv := range invMsg.InvList { + commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String())) + } - err = p.peerHandler.HandleBlock(msgBlock, p) - if err != nil { - commandLogger.Error("Unable to process block", slog.String(hashKey, msgBlock.Header.BlockHash().String()), slog.String(errKey, err.Error())) - } - continue - } + go func(invList []*wire.InvVect, routineLogger *slog.Logger) { + for _, invVect := range invList { + switch invVect.Type { + case wire.InvTypeTx: + if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil { + commandLogger.Error("Unable to process tx", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error())) + } + case wire.InvTypeBlock: + if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil { + commandLogger.Error("Unable to process block", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error())) + } + } + } + }(invMsg.InvList, commandLogger) - // Please note that this is the BlockMessage, not the wire.MsgBlock - blockMsg, ok := msg.(*BlockMessage) - if !ok { - commandLogger.Error("Unable to cast block message, calling with generic wire.Message") - err = p.peerHandler.HandleBlock(msg, p) - if err != nil { - commandLogger.Error("Unable to process block message", slog.String(errKey, err.Error())) - } - continue - } + case wire.CmdGetData: + dataMsg, ok := msg.(*wire.MsgGetData) + if !ok { + continue + } + for _, inv := range dataMsg.InvList { + commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String())) + } + p.handleGetDataMsg(dataMsg, commandLogger) - commandLogger.Info(receivedMsg, slog.String(hashKey, blockMsg.Header.BlockHash().String())) + case wire.CmdTx: + txMsg, ok := msg.(*wire.MsgTx) + if !ok { + continue + } + commandLogger.Debug(receivedMsg, slog.String(hashKey, txMsg.TxHash().String()), slog.Int("size", txMsg.SerializeSize())) + if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil { + commandLogger.Error("Unable to process tx", slog.String(hashKey, txMsg.TxHash().String()), slog.String(errKey, err.Error())) + } - err = p.peerHandler.HandleBlock(blockMsg, p) - if err != nil { - commandLogger.Error("Unable to process block", slog.String(hashKey, blockMsg.Header.BlockHash().String()), slog.String(errKey, err.Error())) - } + case wire.CmdBlock: + msgBlock, ok := msg.(*wire.MsgBlock) + if ok { + commandLogger.Info(receivedMsg, slog.String(hashKey, msgBlock.Header.BlockHash().String())) - case wire.CmdReject: - rejMsg, ok := msg.(*wire.MsgReject) - if !ok { - continue - } - if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil { - commandLogger.Error("Unable to process block", slog.String(hashKey, rejMsg.Hash.String()), slog.String(errKey, err.Error())) - } + err = p.peerHandler.HandleBlock(msgBlock, p) + if err != nil { + commandLogger.Error("Unable to process block", slog.String(hashKey, msgBlock.Header.BlockHash().String()), slog.String(errKey, err.Error())) + } + continue + } - case wire.CmdVerAck: - commandLogger.Debug(receivedMsg) - p.receivedVerAck.Store(true) + // Please note that this is the BlockMessage, not the wire.MsgBlock + blockMsg, ok := msg.(*BlockMessage) + if !ok { + commandLogger.Error("Unable to cast block message, calling with generic wire.Message") + err = p.peerHandler.HandleBlock(msg, p) + if err != nil { + commandLogger.Error("Unable to process block message", slog.String(errKey, err.Error())) + } + continue + } - case wire.CmdPong: - commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPong))) - p.pingPongAlive <- struct{}{} + commandLogger.Info(receivedMsg, slog.String(hashKey, blockMsg.Header.BlockHash().String())) - default: + err = p.peerHandler.HandleBlock(blockMsg, p) + if err != nil { + commandLogger.Error("Unable to process block", slog.String(hashKey, blockMsg.Header.BlockHash().String()), slog.String(errKey, err.Error())) + } + + case wire.CmdReject: + rejMsg, ok := msg.(*wire.MsgReject) + if !ok { + continue + } + if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil { + commandLogger.Error("Unable to process block", slog.String(hashKey, rejMsg.Hash.String()), slog.String(errKey, err.Error())) + } + + case wire.CmdVerAck: + commandLogger.Debug(receivedMsg) + p.receivedVerAck.Store(true) - commandLogger.Debug("command ignored") + case wire.CmdPong: + commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPong))) + p.pingPongAlive <- struct{}{} + + default: + commandLogger.Debug("command ignored") + } + } } - } + }() } func (p *Peer) handleGetDataMsg(dataMsg *wire.MsgGetData, logger *slog.Logger) { @@ -665,3 +679,9 @@ func (p *Peer) IsHealthy() bool { return p.isHealthy } + +func (p *Peer) Shutdown() { + if p.quitReadHandler != nil { + p.quitReadHandler <- struct{}{} + } +} diff --git a/PeerManager_test.go b/PeerManager_test.go index 4e599a5..23faf62 100644 --- a/PeerManager_test.go +++ b/PeerManager_test.go @@ -34,6 +34,7 @@ func TestNewPeerManager(t *testing.T) { peerHandler := NewMockPeerHandler() peer, err := NewPeer(logger, "localhost:18333", peerHandler, wire.TestNet) + defer peer.Shutdown() require.NoError(t, err) err = pm.AddPeer(peer) @@ -57,6 +58,7 @@ func TestNewPeerManager(t *testing.T) { for _, peerAddress := range peers { peer, _ := NewPeer(logger, peerAddress, peerHandler, wire.TestNet) _ = pm.AddPeer(peer) + defer peer.Shutdown() } assert.Len(t, pm.GetPeers(), 4) @@ -73,6 +75,8 @@ func TestAnnounceNewTransaction(t *testing.T) { peer, _ := NewPeerMock("localhost:18333", peerHandler, wire.TestNet) err := pm.AddPeer(peer) + defer peer.Shutdown() + require.NoError(t, err) pm.AnnounceTransaction(tx1Hash, nil) diff --git a/Peer_Mock.go b/Peer_Mock.go index 53b54c2..dc2fa35 100644 --- a/Peer_Mock.go +++ b/Peer_Mock.go @@ -47,6 +47,8 @@ func (p *PeerMock) IsHealthy() bool { return true } +func (p *PeerMock) Shutdown() {} + func (p *PeerMock) Connected() bool { return true } diff --git a/interface.go b/interface.go index 6d98c9d..0862a14 100644 --- a/interface.go +++ b/interface.go @@ -30,6 +30,7 @@ type PeerI interface { RequestBlock(blockHash *chainhash.Hash) Network() wire.BitcoinNet IsHealthy() bool + Shutdown() } type PeerHandlerI interface {