From 0266ab77ab6dd5235a65c66dd9b646fb199b11d7 Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Mon, 5 Sep 2022 13:20:38 +0200 Subject: [PATCH] routing+routerrpc: test stream cancellation Test stream cancellation of the TrackPayments rpc call. In order to achieve this, ControlTowerSubscriber is converted to an interface, to avoid trying to close a null channel when closing the subscription. By returning a mock implementation of the ControlTowerSubscriber in the test that problem is avoided. --- lnrpc/routerrpc/router_server.go | 4 +- lnrpc/routerrpc/router_server_test.go | 107 ++++++++++++++++++-------- routing/control_tower.go | 48 ++++++++---- routing/control_tower_test.go | 30 ++++---- routing/mock_test.go | 12 +-- 5 files changed, 131 insertions(+), 70 deletions(-) diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index c5e7a17f7..da4321eed 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -791,7 +791,7 @@ func (s *Server) TrackPayments(request *TrackPaymentsRequest, // trackPaymentStream streams payment updates to the client. func (s *Server) trackPaymentStream(context context.Context, - subscription *routing.ControlTowerSubscriber, noInflightUpdates bool, + subscription routing.ControlTowerSubscriber, noInflightUpdates bool, send func(*lnrpc.Payment) error) error { defer subscription.Close() @@ -799,7 +799,7 @@ func (s *Server) trackPaymentStream(context context.Context, // Stream updates back to the client. for { select { - case item, ok := <-subscription.Updates: + case item, ok := <-subscription.Updates(): if !ok { // No more payment updates. return nil diff --git a/lnrpc/routerrpc/router_server_test.go b/lnrpc/routerrpc/router_server_test.go index c7d854924..1cdb4b586 100644 --- a/lnrpc/routerrpc/router_server_test.go +++ b/lnrpc/routerrpc/router_server_test.go @@ -13,21 +13,46 @@ import ( "google.golang.org/grpc" ) -func makeStreamMock() *StreamMock { - return &StreamMock{ - ctx: context.Background(), - sentFromServer: make(chan *lnrpc.Payment, 10), - } -} - -type StreamMock struct { +type streamMock struct { grpc.ServerStream ctx context.Context sentFromServer chan *lnrpc.Payment } -func makeControlTowerMock() *ControlTowerMock { - towerMock := &ControlTowerMock{ +func makeStreamMock(ctx context.Context) *streamMock { + return &streamMock{ + ctx: ctx, + sentFromServer: make(chan *lnrpc.Payment, 10), + } +} + +func (m *streamMock) Context() context.Context { + return m.ctx +} + +func (m *streamMock) Send(p *lnrpc.Payment) error { + m.sentFromServer <- p + return nil +} + +type controlTowerSubscriberMock struct { + updates <-chan interface{} +} + +func (s controlTowerSubscriberMock) Updates() <-chan interface{} { + return s.updates +} + +func (s controlTowerSubscriberMock) Close() { +} + +type controlTowerMock struct { + queue *queue.ConcurrentQueue + routing.ControlTower +} + +func makeControlTowerMock() *controlTowerMock { + towerMock := &controlTowerMock{ queue: queue.NewConcurrentQueue(20), } towerMock.queue.Start() @@ -35,26 +60,40 @@ func makeControlTowerMock() *ControlTowerMock { return towerMock } -type ControlTowerMock struct { - queue *queue.ConcurrentQueue - routing.ControlTower -} +func (t *controlTowerMock) SubscribeAllPayments() ( + routing.ControlTowerSubscriber, error) { -func (t *ControlTowerMock) SubscribeAllPayments() ( - *routing.ControlTowerSubscriber, error) { - - return &routing.ControlTowerSubscriber{ - Updates: t.queue.ChanOut(), + return &controlTowerSubscriberMock{ + updates: t.queue.ChanOut(), }, nil } -func (m *StreamMock) Context() context.Context { - return m.ctx -} +// TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns +// when the stream context is cancelled. +func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) { + // Setup mocks and request. + request := &TrackPaymentsRequest{ + NoInflightUpdates: false, + } + towerMock := makeControlTowerMock() -func (m *StreamMock) Send(p *lnrpc.Payment) error { - m.sentFromServer <- p - return nil + streamCtx, cancelStream := context.WithCancel(context.Background()) + stream := makeStreamMock(streamCtx) + + server := &Server{ + cfg: &Config{ + RouterBackend: &RouterBackend{ + Tower: towerMock, + }, + }, + } + + // Cancel stream immediately + cancelStream() + + // Make sure the call returns. + err := server.TrackPayments(request, stream) + require.Equal(t, context.Canceled, err) } // TestTrackPaymentsInflightUpdate tests whether all updates from the control @@ -65,7 +104,11 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) { NoInflightUpdates: false, } towerMock := makeControlTowerMock() - stream := makeStreamMock() + + streamCtx, cancelStream := context.WithCancel(context.Background()) + stream := makeStreamMock(streamCtx) + defer cancelStream() + server := &Server{ cfg: &Config{ RouterBackend: &RouterBackend{ @@ -77,7 +120,7 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) { // Listen to payment updates in a goroutine. go func() { err := server.TrackPayments(request, stream) - require.NoError(t, err) + require.Equal(t, context.Canceled, err) }() // Enqueue some payment updates on the mock. @@ -119,11 +162,15 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) { request := &TrackPaymentsRequest{ NoInflightUpdates: true, } - towerMock := &ControlTowerMock{ + towerMock := &controlTowerMock{ queue: queue.NewConcurrentQueue(20), } towerMock.queue.Start() - stream := makeStreamMock() + + streamCtx, cancelStream := context.WithCancel(context.Background()) + stream := makeStreamMock(streamCtx) + defer cancelStream() + server := &Server{ cfg: &Config{ RouterBackend: &RouterBackend{ @@ -135,7 +182,7 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) { // Listen to payment updates in a goroutine. go func() { err := server.TrackPayments(request, stream) - require.NoError(t, err) + require.Equal(t, context.Canceled, err) }() // Enqueue some payment updates on the mock. diff --git a/routing/control_tower.go b/routing/control_tower.go index 5b9577e40..9f95bb252 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -60,40 +60,48 @@ type ControlTower interface { // SubscribePayment subscribes to updates for the payment with the given // hash. A first update with the current state of the payment is always // sent out immediately. - SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber, + SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber, error) // SubscribeAllPayments subscribes to updates for all payments. A first // update with the current state of every inflight payment is always // sent out immediately. - SubscribeAllPayments() (*ControlTowerSubscriber, error) + SubscribeAllPayments() (ControlTowerSubscriber, error) } // ControlTowerSubscriber contains the state for a payment update subscriber. -type ControlTowerSubscriber struct { +type ControlTowerSubscriber interface { // Updates is the channel over which *channeldb.MPPayment updates can be // received. - Updates <-chan interface{} + Updates() <-chan interface{} - queue *queue.ConcurrentQueue - quit chan struct{} + // Close signals that the subscriber is no longer interested in updates. + Close() +} + +// ControlTowerSubscriberImpl contains the state for a payment update +// subscriber. +type controlTowerSubscriberImpl struct { + updates <-chan interface{} + queue *queue.ConcurrentQueue + quit chan struct{} } // newControlTowerSubscriber instantiates a new subscriber state object. -func newControlTowerSubscriber() *ControlTowerSubscriber { +func newControlTowerSubscriber() *controlTowerSubscriberImpl { // Create a queue for payment updates. queue := queue.NewConcurrentQueue(20) queue.Start() - return &ControlTowerSubscriber{ - Updates: queue.ChanOut(), + return &controlTowerSubscriberImpl{ + updates: queue.ChanOut(), queue: queue, quit: make(chan struct{}), } } // Close signals that the subscriber is no longer interested in updates. -func (s *ControlTowerSubscriber) Close() { +func (s *controlTowerSubscriberImpl) Close() { // Close quit channel so that any pending writes to the queue are // cancelled. close(s.quit) @@ -102,6 +110,12 @@ func (s *ControlTowerSubscriber) Close() { s.queue.Stop() } +// Updates is the channel over which *channeldb.MPPayment updates can be +// received. +func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} { + return s.updates +} + // controlTower is persistent implementation of ControlTower to restrict // double payment sending. type controlTower struct { @@ -111,8 +125,8 @@ type controlTower struct { // to all payments. This is used to easily remove the subscriber when // necessary. subscriberIndex uint64 - subscribersAllPayments map[uint64]*ControlTowerSubscriber - subscribers map[lntypes.Hash][]*ControlTowerSubscriber + subscribersAllPayments map[uint64]*controlTowerSubscriberImpl + subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl subscribersMtx sync.Mutex // paymentsMtx provides synchronization on the payment level to ensure @@ -126,9 +140,9 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower { return &controlTower{ db: db, subscribersAllPayments: make( - map[uint64]*ControlTowerSubscriber, + map[uint64]*controlTowerSubscriberImpl, ), - subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber), + subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl), paymentsMtx: multimutex.NewHashMutex(), } } @@ -245,7 +259,7 @@ func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) { // first update with the current state of the payment is always sent out // immediately. func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( - *ControlTowerSubscriber, error) { + ControlTowerSubscriber, error) { // Take lock before querying the db to prevent missing or duplicating an // update. @@ -286,7 +300,7 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( // of the payment stream could produce out-of-order and/or duplicate events. In // order to get updates for every in-flight payment attempt make sure to // subscribe to this method before initiating any payments. -func (p *controlTower) SubscribeAllPayments() (*ControlTowerSubscriber, error) { +func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) { subscriber := newControlTowerSubscriber() // Add the subscriber to the list before fetching in-flight payments, so @@ -337,7 +351,7 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, // Copy subscribers to all payments locally while holding the lock in // order to avoid concurrency issues while reading/writing the map. - subscribersAllPayments := make(map[uint64]*ControlTowerSubscriber) + subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl) for k, v := range p.subscribersAllPayments { subscribersAllPayments[k] = v } diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 19e4465c2..3dd2c5234 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -115,7 +115,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { // We expect all subscribers to now report the final outcome followed by // no other events. - subscribers := []*ControlTowerSubscriber{ + subscribers := []ControlTowerSubscriber{ subscriber1, subscriber2, subscriber3, } @@ -123,7 +123,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { var result *channeldb.MPPayment for result == nil || result.Status == channeldb.StatusInFlight { select { - case item := <-s.Updates: + case item := <-s.Updates(): result = item.(*channeldb.MPPayment) case <-time.After(testTimeout): t.Fatal("timeout waiting for payment result") @@ -149,7 +149,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { // After the final event, we expect the channel to be closed. select { - case _, ok := <-s.Updates: + case _, ok := <-s.Updates(): if ok { t.Fatal("expected channel to be closed") } @@ -248,7 +248,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { // After exactly 5 updates both payments will/should have completed. for i := 0; i < 5; i++ { select { - case item := <-subscription.Updates: + case item := <-subscription.Updates(): id := item.(*channeldb.MPPayment).Info.PaymentIdentifier results[id] = item.(*channeldb.MPPayment) case <-time.After(testTimeout): @@ -320,13 +320,13 @@ func TestPaymentControlSubscribeAllImmediate(t *testing.T) { // Assert the new subscription receives the old update. select { - case update := <-subscription.Updates: + case update := <-subscription.Updates(): require.NotNil(t, update) require.Equal( t, info.PaymentIdentifier, update.(*channeldb.MPPayment).Info.PaymentIdentifier, ) - require.Len(t, subscription.Updates, 0) + require.Len(t, subscription.Updates(), 0) case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") } @@ -361,14 +361,14 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { // Assert all subscriptions receive the update. select { - case update1 := <-subscription1.Updates: + case update1 := <-subscription1.Updates(): require.NotNil(t, update1) case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") } select { - case update2 := <-subscription2.Updates: + case update2 := <-subscription2.Updates(): require.NotNil(t, update2) case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") @@ -388,13 +388,13 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { // Assert only subscription 2 receives the update. select { - case update2 := <-subscription2.Updates: + case update2 := <-subscription2.Updates(): require.NotNil(t, update2) case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") } - require.Len(t, subscription1.Updates, 0) + require.Len(t, subscription1.Updates(), 0) // Close the second subscription. subscription2.Close() @@ -404,8 +404,8 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { require.NoError(t, err) // Assert no subscriptions receive the update. - require.Len(t, subscription1.Updates, 0) - require.Len(t, subscription2.Updates, 0) + require.Len(t, subscription1.Updates(), 0) + require.Len(t, subscription2.Updates(), 0) } func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, @@ -467,7 +467,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, // We expect both subscribers to now report the final outcome followed // by no other events. - subscribers := []*ControlTowerSubscriber{ + subscribers := []ControlTowerSubscriber{ subscriber1, subscriber2, } @@ -475,7 +475,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, var result *channeldb.MPPayment for result == nil || result.Status == channeldb.StatusInFlight { select { - case item := <-s.Updates: + case item := <-s.Updates(): result = item.(*channeldb.MPPayment) case <-time.After(testTimeout): t.Fatal("timeout waiting for payment result") @@ -513,7 +513,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, // After the final event, we expect the channel to be closed. select { - case _, ok := <-s.Updates: + case _, ok := <-s.Updates(): if ok { t.Fatal("expected channel to be closed") } diff --git a/routing/mock_test.go b/routing/mock_test.go index ad7881cd3..b0a9ba4ff 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -552,13 +552,13 @@ func (m *mockControlTowerOld) FetchInFlightPayments() ( } func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( - *ControlTowerSubscriber, error) { + ControlTowerSubscriber, error) { return nil, errors.New("not implemented") } func (m *mockControlTowerOld) SubscribeAllPayments() ( - *ControlTowerSubscriber, error) { + ControlTowerSubscriber, error) { return nil, errors.New("not implemented") } @@ -774,17 +774,17 @@ func (m *mockControlTower) FetchInFlightPayments() ( } func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( - *ControlTowerSubscriber, error) { + ControlTowerSubscriber, error) { args := m.Called(paymentHash) - return args.Get(0).(*ControlTowerSubscriber), args.Error(1) + return args.Get(0).(ControlTowerSubscriber), args.Error(1) } func (m *mockControlTower) SubscribeAllPayments() ( - *ControlTowerSubscriber, error) { + ControlTowerSubscriber, error) { args := m.Called() - return args.Get(0).(*ControlTowerSubscriber), args.Error(1) + return args.Get(0).(ControlTowerSubscriber), args.Error(1) } type mockLink struct {