diff --git a/libs/pubsub/example_test.go b/libs/pubsub/example_test.go index f7ed17c88..888368b9e 100644 --- a/libs/pubsub/example_test.go +++ b/libs/pubsub/example_test.go @@ -19,7 +19,7 @@ func TestExample(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription, err := s.Subscribe(ctx, "example-client", query.MustParse("abci.account.name='John'"), 1) + subscription, err := s.Subscribe(ctx, "example-client", query.MustParse("abci.account.name='John'")) require.NoError(t, err) err = s.PublishWithTags(ctx, "Tombstone", pubsub.NewTagMap(map[string]string{"abci.account.name": "John"})) require.NoError(t, err) diff --git a/libs/pubsub/pubsub.go b/libs/pubsub/pubsub.go index c118df0e0..774b11f23 100644 --- a/libs/pubsub/pubsub.go +++ b/libs/pubsub/pubsub.go @@ -18,7 +18,7 @@ // } // ctx, cancel := context.WithTimeout(context.Background(), 1 * time.Second) // defer cancel() -// subscription, err := pubsub.Subscribe(ctx, "johns-transactions", q, 1) +// subscription, err := pubsub.Subscribe(ctx, "johns-transactions", q) // if err != nil { // return err // } @@ -133,9 +133,28 @@ func (s *Server) BufferCapacity() int { // Subscribe creates a subscription for the given client. An error will be // returned to the caller if the context is canceled or if subscription already -// exist for pair clientID and query. outCapacity will be used to set a -// capacity for Subscription#Out channel. -func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) { +// exist for pair clientID and query. outCapacity can be used to set a +// capacity for Subscription#Out channel (1 by default). +func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, outCapacity ...int) (*Subscription, error) { + outCap := 1 + if len(outCapacity) > 0 { + if outCapacity[0] <= 0 { + panic("Negative or zero capacity. Use SubscribeUnbuffered if you want an unbuffered channel") + } + outCap = outCapacity[0] + } + + return s.subscribe(ctx, clientID, query, outCap) +} + +// SubscribeUnbuffered does the same as Subscribe, except it returns a +// subscription with unbuffered channel. Use with caution as it can freeze the +// server. +func (s *Server) SubscribeUnbuffered(ctx context.Context, clientID string, query Query) (*Subscription, error) { + return s.subscribe(ctx, clientID, query, 0) +} + +func (s *Server) subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) { s.mtx.RLock() clientSubscriptions, ok := s.subscriptions[clientID] if ok { @@ -263,7 +282,7 @@ type queryPlusRefCount struct { func (s *Server) OnStart() error { go s.loop(state{ subscriptions: make(map[string]map[string]*Subscription), - queries: make(map[string]*queryPlusRefCount), + queries: make(map[string]*queryPlusRefCount), }) return nil } diff --git a/libs/pubsub/pubsub_test.go b/libs/pubsub/pubsub_test.go index 9c6ef3b30..7963be509 100644 --- a/libs/pubsub/pubsub_test.go +++ b/libs/pubsub/pubsub_test.go @@ -27,7 +27,7 @@ func TestSubscribe(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription, err := s.Subscribe(ctx, clientID, query.Empty{}, 1) + subscription, err := s.Subscribe(ctx, clientID, query.Empty{}) require.NoError(t, err) err = s.Publish(ctx, "Ka-Zar") require.NoError(t, err) @@ -38,6 +38,40 @@ func TestSubscribe(t *testing.T) { assertReceive(t, "Quicksilver", subscription.Out()) } +func TestSubscribeWithOutCapacity(t *testing.T) { + s := pubsub.NewServer() + s.SetLogger(log.TestingLogger()) + s.Start() + defer s.Stop() + + ctx := context.Background() + assert.Panics(t, func() { + s.Subscribe(ctx, clientID, query.Empty{}, -1) + s.Subscribe(ctx, clientID, query.Empty{}, 0) + }) + subscription, err := s.Subscribe(ctx, clientID, query.Empty{}, 1) + require.NoError(t, err) + err = s.Publish(ctx, "Aggamon") + require.NoError(t, err) + assertReceive(t, "Aggamon", subscription.Out()) +} + +func TestSubscribeUnbuffered(t *testing.T) { + s := pubsub.NewServer() + s.SetLogger(log.TestingLogger()) + s.Start() + defer s.Stop() + + ctx := context.Background() + subscription, err := s.SubscribeUnbuffered(ctx, clientID, query.Empty{}) + require.NoError(t, err) + go func() { + err = s.Publish(ctx, "Ultron") + require.NoError(t, err) + }() + assertReceive(t, "Ultron", subscription.Out()) +} + func TestDifferentClients(t *testing.T) { s := pubsub.NewServer() s.SetLogger(log.TestingLogger()) @@ -45,20 +79,20 @@ func TestDifferentClients(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription1, err := s.Subscribe(ctx, "client-1", query.MustParse("tm.events.type='NewBlock'"), 1) + subscription1, err := s.Subscribe(ctx, "client-1", query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) err = s.PublishWithTags(ctx, "Iceman", pubsub.NewTagMap(map[string]string{"tm.events.type": "NewBlock"})) require.NoError(t, err) assertReceive(t, "Iceman", subscription1.Out()) - subscription2, err := s.Subscribe(ctx, "client-2", query.MustParse("tm.events.type='NewBlock' AND abci.account.name='Igor'"), 1) + subscription2, err := s.Subscribe(ctx, "client-2", query.MustParse("tm.events.type='NewBlock' AND abci.account.name='Igor'")) require.NoError(t, err) err = s.PublishWithTags(ctx, "Ultimo", pubsub.NewTagMap(map[string]string{"tm.events.type": "NewBlock", "abci.account.name": "Igor"})) require.NoError(t, err) assertReceive(t, "Ultimo", subscription1.Out()) assertReceive(t, "Ultimo", subscription2.Out()) - subscription3, err := s.Subscribe(ctx, "client-3", query.MustParse("tm.events.type='NewRoundStep' AND abci.account.name='Igor' AND abci.invoice.number = 10"), 1) + subscription3, err := s.Subscribe(ctx, "client-3", query.MustParse("tm.events.type='NewRoundStep' AND abci.account.name='Igor' AND abci.invoice.number = 10")) require.NoError(t, err) err = s.PublishWithTags(ctx, "Valeria Richards", pubsub.NewTagMap(map[string]string{"tm.events.type": "NewRoundStep"})) require.NoError(t, err) @@ -74,13 +108,13 @@ func TestClientSubscribesTwice(t *testing.T) { ctx := context.Background() q := query.MustParse("tm.events.type='NewBlock'") - subscription1, err := s.Subscribe(ctx, clientID, q, 1) + subscription1, err := s.Subscribe(ctx, clientID, q) require.NoError(t, err) err = s.PublishWithTags(ctx, "Goblin Queen", pubsub.NewTagMap(map[string]string{"tm.events.type": "NewBlock"})) require.NoError(t, err) assertReceive(t, "Goblin Queen", subscription1.Out()) - subscription2, err := s.Subscribe(ctx, clientID, q, 1) + subscription2, err := s.Subscribe(ctx, clientID, q) require.Error(t, err) require.Nil(t, subscription2) @@ -96,7 +130,7 @@ func TestUnsubscribe(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), 1) + subscription, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) @@ -118,7 +152,7 @@ func TestClientUnsubscribesTwice(t *testing.T) { defer s.Stop() ctx := context.Background() - _, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), 1) + _, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) @@ -136,11 +170,11 @@ func TestResubscribe(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription, err := s.Subscribe(ctx, clientID, query.Empty{}, 1) + subscription, err := s.Subscribe(ctx, clientID, query.Empty{}) require.NoError(t, err) err = s.Unsubscribe(ctx, clientID, query.Empty{}) require.NoError(t, err) - subscription, err = s.Subscribe(ctx, clientID, query.Empty{}, 1) + subscription, err = s.Subscribe(ctx, clientID, query.Empty{}) require.NoError(t, err) err = s.Publish(ctx, "Cable") @@ -155,9 +189,9 @@ func TestUnsubscribeAll(t *testing.T) { defer s.Stop() ctx := context.Background() - subscription1, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), 1) + subscription1, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) - subscription2, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'"), 1) + subscription2, err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'")) require.NoError(t, err) err = s.UnsubscribeAll(ctx, clientID) @@ -211,7 +245,7 @@ func benchmarkNClients(n int, b *testing.B) { ctx := context.Background() for i := 0; i < n; i++ { - subscription, err := s.Subscribe(ctx, clientID, query.MustParse(fmt.Sprintf("abci.Account.Owner = 'Ivan' AND abci.Invoices.Number = %d", i)), 1) + subscription, err := s.Subscribe(ctx, clientID, query.MustParse(fmt.Sprintf("abci.Account.Owner = 'Ivan' AND abci.Invoices.Number = %d", i))) if err != nil { b.Fatal(err) } @@ -242,7 +276,7 @@ func benchmarkNClientsOneQuery(n int, b *testing.B) { ctx := context.Background() q := query.MustParse("abci.Account.Owner = 'Ivan' AND abci.Invoices.Number = 1") for i := 0; i < n; i++ { - subscription, err := s.Subscribe(ctx, clientID, q, 1) + subscription, err := s.Subscribe(ctx, clientID, q) if err != nil { b.Fatal(err) }