diff --git a/discovery_test.go b/discovery_test.go index 59e51812..0444c542 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -108,6 +108,21 @@ func (d *mockDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ... return d.server.FindPeers(ns, options.Limit) } +type dummyDiscovery struct{} + +func (d *dummyDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + return time.Hour, nil +} + +func (d *dummyDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + retCh := make(chan peer.AddrInfo) + go func() { + time.Sleep(time.Second) + close(retCh) + }() + return retCh, nil +} + func TestSimpleDiscovery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pubsub.go b/pubsub.go index cffb83be..a769a68c 100644 --- a/pubsub.go +++ b/pubsub.go @@ -789,9 +789,13 @@ func (p *PubSub) tryJoin(topic string, opts ...TopicOpt) (*Topic, bool, error) { } resp := make(chan *Topic, 1) - t.p.addTopic <- &addTopicReq{ + select { + case t.p.addTopic <- &addTopicReq{ topic: t, resp: resp, + }: + case <-t.p.ctx.Done(): + return nil, false, t.p.ctx.Err() } returnedTopic := <-resp @@ -848,7 +852,11 @@ type topicReq struct { // GetTopics returns the topics this node is subscribed to. func (p *PubSub) GetTopics() []string { out := make(chan []string, 1) - p.getTopics <- &topicReq{resp: out} + select { + case p.getTopics <- &topicReq{resp: out}: + case <-p.ctx.Done(): + return nil + } return <-out } @@ -880,16 +888,23 @@ type listPeerReq struct { // ListPeers returns a list of peers we are connected to in the given topic. func (p *PubSub) ListPeers(topic string) []peer.ID { out := make(chan []peer.ID) - p.getPeers <- &listPeerReq{ + select { + case p.getPeers <- &listPeerReq{ resp: out, topic: topic, + }: + case <-p.ctx.Done(): + return nil } return <-out } // BlacklistPeer blacklists a peer; all messages from this peer will be unconditionally dropped. func (p *PubSub) BlacklistPeer(pid peer.ID) { - p.blacklistPeer <- pid + select { + case p.blacklistPeer <- pid: + case <-p.ctx.Done(): + } } // RegisterTopicValidator registers a validator for topic. @@ -910,7 +925,11 @@ func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...Val } } - p.addVal <- addVal + select { + case p.addVal <- addVal: + case <-p.ctx.Done(): + return p.ctx.Err() + } return <-addVal.resp } @@ -922,6 +941,10 @@ func (p *PubSub) UnregisterTopicValidator(topic string) error { resp: make(chan error, 1), } - p.rmVal <- rmVal + select { + case p.rmVal <- rmVal: + case <-p.ctx.Done(): + return p.ctx.Err() + } return <-rmVal.resp } diff --git a/topic.go b/topic.go index 2ccc4a44..54347914 100644 --- a/topic.go +++ b/topic.go @@ -2,6 +2,7 @@ package pubsub import ( "context" + "errors" "fmt" "sync" @@ -10,6 +11,9 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) +// ErrTopicClosed is returned if a Topic is utilized after it has been closed +var ErrTopicClosed = errors.New("this Topic is closed, try opening a new one") + // Topic is the handle for a pubsub topic type Topic struct { p *PubSub @@ -17,13 +21,23 @@ type Topic struct { evtHandlerMux sync.RWMutex evtHandlers map[*TopicEventHandler]struct{} + + mux sync.RWMutex + closed bool } // EventHandler creates a handle for topic specific events // Multiple event handlers may be created and will operate independently of each other func (t *Topic) EventHandler(opts ...TopicEventHandlerOpt) (*TopicEventHandler, error) { + t.mux.RLock() + defer t.mux.RUnlock() + if t.closed { + return nil, ErrTopicClosed + } + h := &TopicEventHandler{ - err: nil, + topic: t, + err: nil, evtLog: make(map[peer.ID]EventType), evtLogCh: make(chan struct{}, 1), @@ -37,7 +51,9 @@ func (t *Topic) EventHandler(opts ...TopicEventHandlerOpt) (*TopicEventHandler, } done := make(chan struct{}, 1) - t.p.eval <- func() { + + select { + case t.p.eval <- func() { tmap := t.p.topics[t.topic] for p := range tmap { h.evtLog[p] = PeerJoin @@ -47,6 +63,9 @@ func (t *Topic) EventHandler(opts ...TopicEventHandlerOpt) (*TopicEventHandler, t.evtHandlers[h] = struct{}{} t.evtHandlerMux.Unlock() done <- struct{}{} + }: + case <-t.p.ctx.Done(): + return nil, t.p.ctx.Err() } <-done @@ -67,6 +86,12 @@ func (t *Topic) sendNotification(evt PeerEvent) { // Note that subscription is not an instanteneous operation. It may take some time // before the subscription is processed by the pubsub main loop and propagated to our peers. func (t *Topic) Subscribe(opts ...SubOpt) (*Subscription, error) { + t.mux.RLock() + defer t.mux.RUnlock() + if t.closed { + return nil, ErrTopicClosed + } + sub := &Subscription{ topic: t.topic, ch: make(chan *Message, 32), @@ -84,9 +109,13 @@ func (t *Topic) Subscribe(opts ...SubOpt) (*Subscription, error) { t.p.disc.Discover(sub.topic) - t.p.addSub <- &addSubReq{ + select { + case t.p.addSub <- &addSubReq{ sub: sub, resp: out, + }: + case <-t.p.ctx.Done(): + return nil, t.p.ctx.Err() } return <-out, nil @@ -103,6 +132,12 @@ type PubOpt func(pub *PublishOptions) error // Publish publishes data to topic. func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error { + t.mux.RLock() + defer t.mux.RUnlock() + if t.closed { + return ErrTopicClosed + } + seqno := t.p.nextSeqno() id := t.p.host.ID() m := &pb.Message{ @@ -131,7 +166,11 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error t.p.disc.Bootstrap(ctx, t.topic, pub.ready) } - t.p.publish <- &Message{m, id} + select { + case t.p.publish <- &Message{m, id}: + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + } return nil } @@ -148,13 +187,37 @@ func WithReadiness(ready RouterReady) PubOpt { // Close closes down the topic. Will return an error unless there are no active event handlers or subscriptions. // Does not error if the topic is already closed. func (t *Topic) Close() error { + t.mux.Lock() + defer t.mux.Unlock() + if t.closed { + return nil + } + req := &rmTopicReq{t, make(chan error, 1)} - t.p.rmTopic <- req - return <-req.resp + + select { + case t.p.rmTopic <- req: + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + } + + err := <-req.resp + + if err == nil { + t.closed = true + } + + return err } // ListPeers returns a list of peers we are connected to in the given topic. func (t *Topic) ListPeers() []peer.ID { + t.mux.RLock() + defer t.mux.RUnlock() + if t.closed { + return []peer.ID{} + } + return t.p.ListPeers(t.topic) } diff --git a/topic_test.go b/topic_test.go index 8ea6d1ee..13618179 100644 --- a/topic_test.go +++ b/topic_test.go @@ -1,6 +1,7 @@ package pubsub import ( + "bytes" "context" "fmt" "sync" @@ -38,7 +39,39 @@ func getTopicEvts(topics []*Topic, opts ...TopicEventHandlerOpt) []*TopicEventHa return handlers } -func TestTopicClose(t *testing.T) { +func TestTopicCloseWithOpenSubscription(t *testing.T) { + var sub *Subscription + var err error + testTopicCloseWithOpenResource(t, + func(topic *Topic) { + sub, err = topic.Subscribe() + if err != nil { + t.Fatal(err) + } + }, + func() { + sub.Cancel() + }, + ) +} + +func TestTopicCloseWithOpenEventHandler(t *testing.T) { + var evts *TopicEventHandler + var err error + testTopicCloseWithOpenResource(t, + func(topic *Topic) { + evts, err = topic.EventHandler() + if err != nil { + t.Fatal(err) + } + }, + func() { + evts.Cancel() + }, + ) +} + +func testTopicCloseWithOpenResource(t *testing.T, openResource func(topic *Topic), closeResource func()) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -57,23 +90,20 @@ func TestTopicClose(t *testing.T) { t.Fatal(err) } - // Try create and cancel topic while there's an outstanding subscription + // Try create and cancel topic while there's an outstanding subscription/event handler topic, err = ps.Join(topicID) if err != nil { t.Fatal(err) } - sub, err := topic.Subscribe() - if err != nil { - t.Fatal(err) - } + openResource(topic) if err := topic.Close(); err == nil { - t.Fatal("expected an error closing a topic with an open subscription") + t.Fatal("expected an error closing a topic with an open resource") } - // Check if the topic closes properly after canceling the outstanding subscription - sub.Cancel() + // Check if the topic closes properly after closing the resource + closeResource() time.Sleep(time.Millisecond * 100) if err := topic.Close(); err != nil { @@ -81,6 +111,132 @@ func TestTopicClose(t *testing.T) { } } +func TestTopicReuse(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numHosts = 2 + topicID := "foobar" + hosts := getNetHosts(t, ctx, numHosts) + + sender := getPubsub(ctx, hosts[0], WithDiscovery(&dummyDiscovery{})) + receiver := getPubsub(ctx, hosts[1]) + + connectAll(t, hosts) + + // Sender creates topic + sendTopic, err := sender.Join(topicID) + if err != nil { + t.Fatal(err) + } + + // Receiver creates and subscribes to the topic + receiveTopic, err := receiver.Join(topicID) + if err != nil { + t.Fatal(err) + } + + sub, err := receiveTopic.Subscribe() + if err != nil { + t.Fatal(err) + } + + firstMsg := []byte("1") + if err := sendTopic.Publish(ctx, firstMsg, WithReadiness(MinTopicSize(1))); err != nil { + t.Fatal(err) + } + + msg, err := sub.Next(ctx) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(msg.GetData(), firstMsg) != 0 { + t.Fatal("received incorrect message") + } + + if err := sendTopic.Close(); err != nil { + t.Fatal(err) + } + + // Recreate the same topic + newSendTopic, err := sender.Join(topicID) + if err != nil { + t.Fatal(err) + } + + // Try sending data with original topic + illegalSend := []byte("illegal") + if err := sendTopic.Publish(ctx, illegalSend); err != ErrTopicClosed { + t.Fatal(err) + } + + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, time.Second*2) + defer timeoutCancel() + msg, err = sub.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if err != nil { + t.Fatal(err) + } + if bytes.Compare(msg.GetData(), illegalSend) != 0 { + t.Fatal("received incorrect message from illegal topic") + } + t.Fatal("received message sent by illegal topic") + } + timeoutCancel() + + // Try cancelling the new topic by using the original topic + if err := sendTopic.Close(); err != nil { + t.Fatal(err) + } + + secondMsg := []byte("2") + if err := newSendTopic.Publish(ctx, secondMsg); err != nil { + t.Fatal(err) + } + + timeoutCtx, timeoutCancel = context.WithTimeout(ctx, time.Second*2) + defer timeoutCancel() + msg, err = sub.Next(ctx) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(msg.GetData(), secondMsg) != 0 { + t.Fatal("received incorrect message") + } +} + +func TestTopicEventHandlerCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numHosts = 5 + topicID := "foobar" + hosts := getNetHosts(t, ctx, numHosts) + ps := getPubsub(ctx, hosts[0]) + + // Try create and cancel topic + topic, err := ps.Join(topicID) + if err != nil { + t.Fatal(err) + } + + evts, err := topic.EventHandler() + if err != nil { + t.Fatal(err) + } + evts.Cancel() + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, time.Second*2) + defer timeoutCancel() + connectAll(t, hosts) + _, err = evts.NextPeerEvent(timeoutCtx) + if err != context.DeadlineExceeded { + if err != nil { + t.Fatal(err) + } + t.Fatal("received event after cancel") + } +} + func TestSubscriptionJoinNotification(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()