routing+routerrpc: test stream cancellation

Test stream cancellation of the TrackPayments rpc call. In order to achieve
this, ControlTowerSubscriber is converted to an interface, to avoid trying to
close a null channel when closing the subscription. By returning a mock
implementation of the ControlTowerSubscriber in the test that problem is
avoided.
This commit is contained in:
Jesse de Wit 2022-09-05 13:20:38 +02:00
parent 4bc3007668
commit 0266ab77ab
No known key found for this signature in database
GPG key ID: 78A9DCCE385AE6B4
5 changed files with 131 additions and 70 deletions

View file

@ -791,7 +791,7 @@ func (s *Server) TrackPayments(request *TrackPaymentsRequest,
// trackPaymentStream streams payment updates to the client.
func (s *Server) trackPaymentStream(context context.Context,
subscription *routing.ControlTowerSubscriber, noInflightUpdates bool,
subscription routing.ControlTowerSubscriber, noInflightUpdates bool,
send func(*lnrpc.Payment) error) error {
defer subscription.Close()
@ -799,7 +799,7 @@ func (s *Server) trackPaymentStream(context context.Context,
// Stream updates back to the client.
for {
select {
case item, ok := <-subscription.Updates:
case item, ok := <-subscription.Updates():
if !ok {
// No more payment updates.
return nil

View file

@ -13,21 +13,46 @@ import (
"google.golang.org/grpc"
)
func makeStreamMock() *StreamMock {
return &StreamMock{
ctx: context.Background(),
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
type StreamMock struct {
type streamMock struct {
grpc.ServerStream
ctx context.Context
sentFromServer chan *lnrpc.Payment
}
func makeControlTowerMock() *ControlTowerMock {
towerMock := &ControlTowerMock{
func makeStreamMock(ctx context.Context) *streamMock {
return &streamMock{
ctx: ctx,
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
func (m *streamMock) Context() context.Context {
return m.ctx
}
func (m *streamMock) Send(p *lnrpc.Payment) error {
m.sentFromServer <- p
return nil
}
type controlTowerSubscriberMock struct {
updates <-chan interface{}
}
func (s controlTowerSubscriberMock) Updates() <-chan interface{} {
return s.updates
}
func (s controlTowerSubscriberMock) Close() {
}
type controlTowerMock struct {
queue *queue.ConcurrentQueue
routing.ControlTower
}
func makeControlTowerMock() *controlTowerMock {
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
@ -35,26 +60,40 @@ func makeControlTowerMock() *ControlTowerMock {
return towerMock
}
type ControlTowerMock struct {
queue *queue.ConcurrentQueue
routing.ControlTower
}
func (t *controlTowerMock) SubscribeAllPayments() (
routing.ControlTowerSubscriber, error) {
func (t *ControlTowerMock) SubscribeAllPayments() (
*routing.ControlTowerSubscriber, error) {
return &routing.ControlTowerSubscriber{
Updates: t.queue.ChanOut(),
return &controlTowerSubscriberMock{
updates: t.queue.ChanOut(),
}, nil
}
func (m *StreamMock) Context() context.Context {
return m.ctx
}
// TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns
// when the stream context is cancelled.
func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
func (m *StreamMock) Send(p *lnrpc.Payment) error {
m.sentFromServer <- p
return nil
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Cancel stream immediately
cancelStream()
// Make sure the call returns.
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}
// TestTrackPaymentsInflightUpdate tests whether all updates from the control
@ -65,7 +104,11 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) {
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
stream := makeStreamMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
@ -77,7 +120,7 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) {
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.NoError(t, err)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.
@ -119,11 +162,15 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
request := &TrackPaymentsRequest{
NoInflightUpdates: true,
}
towerMock := &ControlTowerMock{
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
stream := makeStreamMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
@ -135,7 +182,7 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.NoError(t, err)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.

View file

@ -60,40 +60,48 @@ type ControlTower interface {
// SubscribePayment subscribes to updates for the payment with the given
// hash. A first update with the current state of the payment is always
// sent out immediately.
SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber,
SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber,
error)
// SubscribeAllPayments subscribes to updates for all payments. A first
// update with the current state of every inflight payment is always
// sent out immediately.
SubscribeAllPayments() (*ControlTowerSubscriber, error)
SubscribeAllPayments() (ControlTowerSubscriber, error)
}
// ControlTowerSubscriber contains the state for a payment update subscriber.
type ControlTowerSubscriber struct {
type ControlTowerSubscriber interface {
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
Updates <-chan interface{}
Updates() <-chan interface{}
queue *queue.ConcurrentQueue
quit chan struct{}
// Close signals that the subscriber is no longer interested in updates.
Close()
}
// ControlTowerSubscriberImpl contains the state for a payment update
// subscriber.
type controlTowerSubscriberImpl struct {
updates <-chan interface{}
queue *queue.ConcurrentQueue
quit chan struct{}
}
// newControlTowerSubscriber instantiates a new subscriber state object.
func newControlTowerSubscriber() *ControlTowerSubscriber {
func newControlTowerSubscriber() *controlTowerSubscriberImpl {
// Create a queue for payment updates.
queue := queue.NewConcurrentQueue(20)
queue.Start()
return &ControlTowerSubscriber{
Updates: queue.ChanOut(),
return &controlTowerSubscriberImpl{
updates: queue.ChanOut(),
queue: queue,
quit: make(chan struct{}),
}
}
// Close signals that the subscriber is no longer interested in updates.
func (s *ControlTowerSubscriber) Close() {
func (s *controlTowerSubscriberImpl) Close() {
// Close quit channel so that any pending writes to the queue are
// cancelled.
close(s.quit)
@ -102,6 +110,12 @@ func (s *ControlTowerSubscriber) Close() {
s.queue.Stop()
}
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} {
return s.updates
}
// controlTower is persistent implementation of ControlTower to restrict
// double payment sending.
type controlTower struct {
@ -111,8 +125,8 @@ type controlTower struct {
// to all payments. This is used to easily remove the subscriber when
// necessary.
subscriberIndex uint64
subscribersAllPayments map[uint64]*ControlTowerSubscriber
subscribers map[lntypes.Hash][]*ControlTowerSubscriber
subscribersAllPayments map[uint64]*controlTowerSubscriberImpl
subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl
subscribersMtx sync.Mutex
// paymentsMtx provides synchronization on the payment level to ensure
@ -126,9 +140,9 @@ func NewControlTower(db *channeldb.PaymentControl) ControlTower {
return &controlTower{
db: db,
subscribersAllPayments: make(
map[uint64]*ControlTowerSubscriber,
map[uint64]*controlTowerSubscriberImpl,
),
subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber),
subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
paymentsMtx: multimutex.NewHashMutex(),
}
}
@ -245,7 +259,7 @@ func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
// first update with the current state of the payment is always sent out
// immediately.
func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
// Take lock before querying the db to prevent missing or duplicating an
// update.
@ -286,7 +300,7 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
// of the payment stream could produce out-of-order and/or duplicate events. In
// order to get updates for every in-flight payment attempt make sure to
// subscribe to this method before initiating any payments.
func (p *controlTower) SubscribeAllPayments() (*ControlTowerSubscriber, error) {
func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) {
subscriber := newControlTowerSubscriber()
// Add the subscriber to the list before fetching in-flight payments, so
@ -337,7 +351,7 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
// Copy subscribers to all payments locally while holding the lock in
// order to avoid concurrency issues while reading/writing the map.
subscribersAllPayments := make(map[uint64]*ControlTowerSubscriber)
subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl)
for k, v := range p.subscribersAllPayments {
subscribersAllPayments[k] = v
}

View file

@ -115,7 +115,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// We expect all subscribers to now report the final outcome followed by
// no other events.
subscribers := []*ControlTowerSubscriber{
subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2, subscriber3,
}
@ -123,7 +123,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight {
select {
case item := <-s.Updates:
case item := <-s.Updates():
result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result")
@ -149,7 +149,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// After the final event, we expect the channel to be closed.
select {
case _, ok := <-s.Updates:
case _, ok := <-s.Updates():
if ok {
t.Fatal("expected channel to be closed")
}
@ -248,7 +248,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
// After exactly 5 updates both payments will/should have completed.
for i := 0; i < 5; i++ {
select {
case item := <-subscription.Updates:
case item := <-subscription.Updates():
id := item.(*channeldb.MPPayment).Info.PaymentIdentifier
results[id] = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
@ -320,13 +320,13 @@ func TestPaymentControlSubscribeAllImmediate(t *testing.T) {
// Assert the new subscription receives the old update.
select {
case update := <-subscription.Updates:
case update := <-subscription.Updates():
require.NotNil(t, update)
require.Equal(
t, info.PaymentIdentifier,
update.(*channeldb.MPPayment).Info.PaymentIdentifier,
)
require.Len(t, subscription.Updates, 0)
require.Len(t, subscription.Updates(), 0)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
@ -361,14 +361,14 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
// Assert all subscriptions receive the update.
select {
case update1 := <-subscription1.Updates:
case update1 := <-subscription1.Updates():
require.NotNil(t, update1)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
select {
case update2 := <-subscription2.Updates:
case update2 := <-subscription2.Updates():
require.NotNil(t, update2)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
@ -388,13 +388,13 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
// Assert only subscription 2 receives the update.
select {
case update2 := <-subscription2.Updates:
case update2 := <-subscription2.Updates():
require.NotNil(t, update2)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
require.Len(t, subscription1.Updates, 0)
require.Len(t, subscription1.Updates(), 0)
// Close the second subscription.
subscription2.Close()
@ -404,8 +404,8 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
require.NoError(t, err)
// Assert no subscriptions receive the update.
require.Len(t, subscription1.Updates, 0)
require.Len(t, subscription2.Updates, 0)
require.Len(t, subscription1.Updates(), 0)
require.Len(t, subscription2.Updates(), 0)
}
func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
@ -467,7 +467,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// We expect both subscribers to now report the final outcome followed
// by no other events.
subscribers := []*ControlTowerSubscriber{
subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2,
}
@ -475,7 +475,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight {
select {
case item := <-s.Updates:
case item := <-s.Updates():
result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result")
@ -513,7 +513,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// After the final event, we expect the channel to be closed.
select {
case _, ok := <-s.Updates:
case _, ok := <-s.Updates():
if ok {
t.Fatal("expected channel to be closed")
}

View file

@ -552,13 +552,13 @@ func (m *mockControlTowerOld) FetchInFlightPayments() (
}
func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
return nil, errors.New("not implemented")
}
func (m *mockControlTowerOld) SubscribeAllPayments() (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
return nil, errors.New("not implemented")
}
@ -774,17 +774,17 @@ func (m *mockControlTower) FetchInFlightPayments() (
}
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
args := m.Called(paymentHash)
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
func (m *mockControlTower) SubscribeAllPayments() (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
args := m.Called()
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
type mockLink struct {