routing+routerrpc: test stream cancellation

Test stream cancellation of the TrackPayments rpc call. In order to achieve
this, ControlTowerSubscriber is converted to an interface, to avoid trying to
close a null channel when closing the subscription. By returning a mock
implementation of the ControlTowerSubscriber in the test that problem is
avoided.
This commit is contained in:
Jesse de Wit 2022-09-05 13:20:38 +02:00
parent 4bc3007668
commit 0266ab77ab
No known key found for this signature in database
GPG key ID: 78A9DCCE385AE6B4
5 changed files with 131 additions and 70 deletions

View file

@ -791,7 +791,7 @@ func (s *Server) TrackPayments(request *TrackPaymentsRequest,
// trackPaymentStream streams payment updates to the client. // trackPaymentStream streams payment updates to the client.
func (s *Server) trackPaymentStream(context context.Context, func (s *Server) trackPaymentStream(context context.Context,
subscription *routing.ControlTowerSubscriber, noInflightUpdates bool, subscription routing.ControlTowerSubscriber, noInflightUpdates bool,
send func(*lnrpc.Payment) error) error { send func(*lnrpc.Payment) error) error {
defer subscription.Close() defer subscription.Close()
@ -799,7 +799,7 @@ func (s *Server) trackPaymentStream(context context.Context,
// Stream updates back to the client. // Stream updates back to the client.
for { for {
select { select {
case item, ok := <-subscription.Updates: case item, ok := <-subscription.Updates():
if !ok { if !ok {
// No more payment updates. // No more payment updates.
return nil return nil

View file

@ -13,21 +13,46 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
func makeStreamMock() *StreamMock { type streamMock struct {
return &StreamMock{
ctx: context.Background(),
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
type StreamMock struct {
grpc.ServerStream grpc.ServerStream
ctx context.Context ctx context.Context
sentFromServer chan *lnrpc.Payment sentFromServer chan *lnrpc.Payment
} }
func makeControlTowerMock() *ControlTowerMock { func makeStreamMock(ctx context.Context) *streamMock {
towerMock := &ControlTowerMock{ return &streamMock{
ctx: ctx,
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
func (m *streamMock) Context() context.Context {
return m.ctx
}
func (m *streamMock) Send(p *lnrpc.Payment) error {
m.sentFromServer <- p
return nil
}
type controlTowerSubscriberMock struct {
updates <-chan interface{}
}
func (s controlTowerSubscriberMock) Updates() <-chan interface{} {
return s.updates
}
func (s controlTowerSubscriberMock) Close() {
}
type controlTowerMock struct {
queue *queue.ConcurrentQueue
routing.ControlTower
}
func makeControlTowerMock() *controlTowerMock {
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20), queue: queue.NewConcurrentQueue(20),
} }
towerMock.queue.Start() towerMock.queue.Start()
@ -35,26 +60,40 @@ func makeControlTowerMock() *ControlTowerMock {
return towerMock return towerMock
} }
type ControlTowerMock struct { func (t *controlTowerMock) SubscribeAllPayments() (
queue *queue.ConcurrentQueue routing.ControlTowerSubscriber, error) {
routing.ControlTower
}
func (t *ControlTowerMock) SubscribeAllPayments() ( return &controlTowerSubscriberMock{
*routing.ControlTowerSubscriber, error) { updates: t.queue.ChanOut(),
return &routing.ControlTowerSubscriber{
Updates: t.queue.ChanOut(),
}, nil }, nil
} }
func (m *StreamMock) Context() context.Context { // TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns
return m.ctx // when the stream context is cancelled.
} func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
func (m *StreamMock) Send(p *lnrpc.Payment) error { streamCtx, cancelStream := context.WithCancel(context.Background())
m.sentFromServer <- p stream := makeStreamMock(streamCtx)
return nil
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Cancel stream immediately
cancelStream()
// Make sure the call returns.
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
} }
// TestTrackPaymentsInflightUpdate tests whether all updates from the control // TestTrackPaymentsInflightUpdate tests whether all updates from the control
@ -65,7 +104,11 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) {
NoInflightUpdates: false, NoInflightUpdates: false,
} }
towerMock := makeControlTowerMock() towerMock := makeControlTowerMock()
stream := makeStreamMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{ server := &Server{
cfg: &Config{ cfg: &Config{
RouterBackend: &RouterBackend{ RouterBackend: &RouterBackend{
@ -77,7 +120,7 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) {
// Listen to payment updates in a goroutine. // Listen to payment updates in a goroutine.
go func() { go func() {
err := server.TrackPayments(request, stream) err := server.TrackPayments(request, stream)
require.NoError(t, err) require.Equal(t, context.Canceled, err)
}() }()
// Enqueue some payment updates on the mock. // Enqueue some payment updates on the mock.
@ -119,11 +162,15 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
request := &TrackPaymentsRequest{ request := &TrackPaymentsRequest{
NoInflightUpdates: true, NoInflightUpdates: true,
} }
towerMock := &ControlTowerMock{ towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20), queue: queue.NewConcurrentQueue(20),
} }
towerMock.queue.Start() towerMock.queue.Start()
stream := makeStreamMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{ server := &Server{
cfg: &Config{ cfg: &Config{
RouterBackend: &RouterBackend{ RouterBackend: &RouterBackend{
@ -135,7 +182,7 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
// Listen to payment updates in a goroutine. // Listen to payment updates in a goroutine.
go func() { go func() {
err := server.TrackPayments(request, stream) err := server.TrackPayments(request, stream)
require.NoError(t, err) require.Equal(t, context.Canceled, err)
}() }()
// Enqueue some payment updates on the mock. // Enqueue some payment updates on the mock.

View file

@ -60,40 +60,48 @@ type ControlTower interface {
// SubscribePayment subscribes to updates for the payment with the given // SubscribePayment subscribes to updates for the payment with the given
// hash. A first update with the current state of the payment is always // hash. A first update with the current state of the payment is always
// sent out immediately. // sent out immediately.
SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber, SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber,
error) error)
// SubscribeAllPayments subscribes to updates for all payments. A first // SubscribeAllPayments subscribes to updates for all payments. A first
// update with the current state of every inflight payment is always // update with the current state of every inflight payment is always
// sent out immediately. // sent out immediately.
SubscribeAllPayments() (*ControlTowerSubscriber, error) SubscribeAllPayments() (ControlTowerSubscriber, error)
} }
// ControlTowerSubscriber contains the state for a payment update subscriber. // ControlTowerSubscriber contains the state for a payment update subscriber.
type ControlTowerSubscriber struct { type ControlTowerSubscriber interface {
// Updates is the channel over which *channeldb.MPPayment updates can be // Updates is the channel over which *channeldb.MPPayment updates can be
// received. // received.
Updates <-chan interface{} Updates() <-chan interface{}
// Close signals that the subscriber is no longer interested in updates.
Close()
}
// ControlTowerSubscriberImpl contains the state for a payment update
// subscriber.
type controlTowerSubscriberImpl struct {
updates <-chan interface{}
queue *queue.ConcurrentQueue queue *queue.ConcurrentQueue
quit chan struct{} quit chan struct{}
} }
// newControlTowerSubscriber instantiates a new subscriber state object. // newControlTowerSubscriber instantiates a new subscriber state object.
func newControlTowerSubscriber() *ControlTowerSubscriber { func newControlTowerSubscriber() *controlTowerSubscriberImpl {
// Create a queue for payment updates. // Create a queue for payment updates.
queue := queue.NewConcurrentQueue(20) queue := queue.NewConcurrentQueue(20)
queue.Start() queue.Start()
return &ControlTowerSubscriber{ return &controlTowerSubscriberImpl{
Updates: queue.ChanOut(), updates: queue.ChanOut(),
queue: queue, queue: queue,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
} }
// Close signals that the subscriber is no longer interested in updates. // Close signals that the subscriber is no longer interested in updates.
func (s *ControlTowerSubscriber) Close() { func (s *controlTowerSubscriberImpl) Close() {
// Close quit channel so that any pending writes to the queue are // Close quit channel so that any pending writes to the queue are
// cancelled. // cancelled.
close(s.quit) close(s.quit)
@ -102,6 +110,12 @@ func (s *ControlTowerSubscriber) Close() {
s.queue.Stop() s.queue.Stop()
} }
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} {
return s.updates
}
// controlTower is persistent implementation of ControlTower to restrict // controlTower is persistent implementation of ControlTower to restrict
// double payment sending. // double payment sending.
type controlTower struct { type controlTower struct {
@ -111,8 +125,8 @@ type controlTower struct {
// to all payments. This is used to easily remove the subscriber when // to all payments. This is used to easily remove the subscriber when
// necessary. // necessary.
subscriberIndex uint64 subscriberIndex uint64
subscribersAllPayments map[uint64]*ControlTowerSubscriber subscribersAllPayments map[uint64]*controlTowerSubscriberImpl
subscribers map[lntypes.Hash][]*ControlTowerSubscriber subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl
subscribersMtx sync.Mutex subscribersMtx sync.Mutex
// paymentsMtx provides synchronization on the payment level to ensure // paymentsMtx provides synchronization on the payment level to ensure
@ -126,9 +140,9 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower {
return &controlTower{ return &controlTower{
db: db, db: db,
subscribersAllPayments: make( subscribersAllPayments: make(
map[uint64]*ControlTowerSubscriber, map[uint64]*controlTowerSubscriberImpl,
), ),
subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber), subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
paymentsMtx: multimutex.NewHashMutex(), paymentsMtx: multimutex.NewHashMutex(),
} }
} }
@ -245,7 +259,7 @@ func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
// first update with the current state of the payment is always sent out // first update with the current state of the payment is always sent out
// immediately. // immediately.
func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) { ControlTowerSubscriber, error) {
// Take lock before querying the db to prevent missing or duplicating an // Take lock before querying the db to prevent missing or duplicating an
// update. // update.
@ -286,7 +300,7 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
// of the payment stream could produce out-of-order and/or duplicate events. In // of the payment stream could produce out-of-order and/or duplicate events. In
// order to get updates for every in-flight payment attempt make sure to // order to get updates for every in-flight payment attempt make sure to
// subscribe to this method before initiating any payments. // subscribe to this method before initiating any payments.
func (p *controlTower) SubscribeAllPayments() (*ControlTowerSubscriber, error) { func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) {
subscriber := newControlTowerSubscriber() subscriber := newControlTowerSubscriber()
// Add the subscriber to the list before fetching in-flight payments, so // Add the subscriber to the list before fetching in-flight payments, so
@ -337,7 +351,7 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
// Copy subscribers to all payments locally while holding the lock in // Copy subscribers to all payments locally while holding the lock in
// order to avoid concurrency issues while reading/writing the map. // order to avoid concurrency issues while reading/writing the map.
subscribersAllPayments := make(map[uint64]*ControlTowerSubscriber) subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl)
for k, v := range p.subscribersAllPayments { for k, v := range p.subscribersAllPayments {
subscribersAllPayments[k] = v subscribersAllPayments[k] = v
} }

View file

@ -115,7 +115,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// We expect all subscribers to now report the final outcome followed by // We expect all subscribers to now report the final outcome followed by
// no other events. // no other events.
subscribers := []*ControlTowerSubscriber{ subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2, subscriber3, subscriber1, subscriber2, subscriber3,
} }
@ -123,7 +123,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
var result *channeldb.MPPayment var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight { for result == nil || result.Status == channeldb.StatusInFlight {
select { select {
case item := <-s.Updates: case item := <-s.Updates():
result = item.(*channeldb.MPPayment) result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout): case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result") t.Fatal("timeout waiting for payment result")
@ -149,7 +149,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// After the final event, we expect the channel to be closed. // After the final event, we expect the channel to be closed.
select { select {
case _, ok := <-s.Updates: case _, ok := <-s.Updates():
if ok { if ok {
t.Fatal("expected channel to be closed") t.Fatal("expected channel to be closed")
} }
@ -248,7 +248,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
// After exactly 5 updates both payments will/should have completed. // After exactly 5 updates both payments will/should have completed.
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
select { select {
case item := <-subscription.Updates: case item := <-subscription.Updates():
id := item.(*channeldb.MPPayment).Info.PaymentIdentifier id := item.(*channeldb.MPPayment).Info.PaymentIdentifier
results[id] = item.(*channeldb.MPPayment) results[id] = item.(*channeldb.MPPayment)
case <-time.After(testTimeout): case <-time.After(testTimeout):
@ -320,13 +320,13 @@ func TestPaymentControlSubscribeAllImmediate(t *testing.T) {
// Assert the new subscription receives the old update. // Assert the new subscription receives the old update.
select { select {
case update := <-subscription.Updates: case update := <-subscription.Updates():
require.NotNil(t, update) require.NotNil(t, update)
require.Equal( require.Equal(
t, info.PaymentIdentifier, t, info.PaymentIdentifier,
update.(*channeldb.MPPayment).Info.PaymentIdentifier, update.(*channeldb.MPPayment).Info.PaymentIdentifier,
) )
require.Len(t, subscription.Updates, 0) require.Len(t, subscription.Updates(), 0)
case <-time.After(testTimeout): case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result") require.Fail(t, "timeout waiting for payment result")
} }
@ -361,14 +361,14 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
// Assert all subscriptions receive the update. // Assert all subscriptions receive the update.
select { select {
case update1 := <-subscription1.Updates: case update1 := <-subscription1.Updates():
require.NotNil(t, update1) require.NotNil(t, update1)
case <-time.After(testTimeout): case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result") require.Fail(t, "timeout waiting for payment result")
} }
select { select {
case update2 := <-subscription2.Updates: case update2 := <-subscription2.Updates():
require.NotNil(t, update2) require.NotNil(t, update2)
case <-time.After(testTimeout): case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result") require.Fail(t, "timeout waiting for payment result")
@ -388,13 +388,13 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
// Assert only subscription 2 receives the update. // Assert only subscription 2 receives the update.
select { select {
case update2 := <-subscription2.Updates: case update2 := <-subscription2.Updates():
require.NotNil(t, update2) require.NotNil(t, update2)
case <-time.After(testTimeout): case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result") require.Fail(t, "timeout waiting for payment result")
} }
require.Len(t, subscription1.Updates, 0) require.Len(t, subscription1.Updates(), 0)
// Close the second subscription. // Close the second subscription.
subscription2.Close() subscription2.Close()
@ -404,8 +404,8 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Assert no subscriptions receive the update. // Assert no subscriptions receive the update.
require.Len(t, subscription1.Updates, 0) require.Len(t, subscription1.Updates(), 0)
require.Len(t, subscription2.Updates, 0) require.Len(t, subscription2.Updates(), 0)
} }
func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
@ -467,7 +467,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// We expect both subscribers to now report the final outcome followed // We expect both subscribers to now report the final outcome followed
// by no other events. // by no other events.
subscribers := []*ControlTowerSubscriber{ subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2, subscriber1, subscriber2,
} }
@ -475,7 +475,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
var result *channeldb.MPPayment var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight { for result == nil || result.Status == channeldb.StatusInFlight {
select { select {
case item := <-s.Updates: case item := <-s.Updates():
result = item.(*channeldb.MPPayment) result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout): case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result") t.Fatal("timeout waiting for payment result")
@ -513,7 +513,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// After the final event, we expect the channel to be closed. // After the final event, we expect the channel to be closed.
select { select {
case _, ok := <-s.Updates: case _, ok := <-s.Updates():
if ok { if ok {
t.Fatal("expected channel to be closed") t.Fatal("expected channel to be closed")
} }

View file

@ -552,13 +552,13 @@ func (m *mockControlTowerOld) FetchInFlightPayments() (
} }
func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) { ControlTowerSubscriber, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (m *mockControlTowerOld) SubscribeAllPayments() ( func (m *mockControlTowerOld) SubscribeAllPayments() (
*ControlTowerSubscriber, error) { ControlTowerSubscriber, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@ -774,17 +774,17 @@ func (m *mockControlTower) FetchInFlightPayments() (
} }
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) { ControlTowerSubscriber, error) {
args := m.Called(paymentHash) args := m.Called(paymentHash)
return args.Get(0).(*ControlTowerSubscriber), args.Error(1) return args.Get(0).(ControlTowerSubscriber), args.Error(1)
} }
func (m *mockControlTower) SubscribeAllPayments() ( func (m *mockControlTower) SubscribeAllPayments() (
*ControlTowerSubscriber, error) { ControlTowerSubscriber, error) {
args := m.Called() args := m.Called()
return args.Get(0).(*ControlTowerSubscriber), args.Error(1) return args.Get(0).(ControlTowerSubscriber), args.Error(1)
} }
type mockLink struct { type mockLink struct {