lnd/lnrpc/routerrpc/router_server_test.go
Jesse de Wit 0266ab77ab
routing+routerrpc: test stream cancellation
Test stream cancellation of the TrackPayments rpc call. In order to achieve
this, ControlTowerSubscriber is converted to an interface, to avoid trying to
close a null channel when closing the subscription. By returning a mock
implementation of the ControlTowerSubscriber in the test that problem is
avoided.
2022-09-10 21:53:01 +02:00

217 lines
5.1 KiB
Go

package routerrpc
import (
"context"
"testing"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/routing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type streamMock struct {
grpc.ServerStream
ctx context.Context
sentFromServer chan *lnrpc.Payment
}
func makeStreamMock(ctx context.Context) *streamMock {
return &streamMock{
ctx: ctx,
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
func (m *streamMock) Context() context.Context {
return m.ctx
}
func (m *streamMock) Send(p *lnrpc.Payment) error {
m.sentFromServer <- p
return nil
}
type controlTowerSubscriberMock struct {
updates <-chan interface{}
}
func (s controlTowerSubscriberMock) Updates() <-chan interface{} {
return s.updates
}
func (s controlTowerSubscriberMock) Close() {
}
type controlTowerMock struct {
queue *queue.ConcurrentQueue
routing.ControlTower
}
func makeControlTowerMock() *controlTowerMock {
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
return towerMock
}
func (t *controlTowerMock) SubscribeAllPayments() (
routing.ControlTowerSubscriber, error) {
return &controlTowerSubscriberMock{
updates: t.queue.ChanOut(),
}, nil
}
// TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns
// when the stream context is cancelled.
func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Cancel stream immediately
cancelStream()
// Make sure the call returns.
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}
// TestTrackPaymentsInflightUpdate tests whether all updates from the control
// tower are propagated to the client.
func TestTrackPaymentsInflightUpdates(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusInFlight,
}
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusSucceeded,
}
// Wait until there's 2 updates or the deadline is exceeded.
deadline := time.Now().Add(1 * time.Second)
for {
if len(stream.sentFromServer) == 2 {
break
}
if time.Now().After(deadline) {
require.FailNow(t, "deadline exceeded.")
}
}
// Both updates should be sent to the client.
require.Len(t, stream.sentFromServer, 2)
// The updates should be in the right order.
payment := <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_IN_FLIGHT, payment.Status)
payment = <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
}
// TestTrackPaymentsInflightUpdate tests whether only final updates from the
// control tower are propagated to the client when noInflightUpdates = true.
func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: true,
}
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusInFlight,
}
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusSucceeded,
}
// Wait until there's 1 update or the deadline is exceeded.
deadline := time.Now().Add(1 * time.Second)
for {
if len(stream.sentFromServer) == 1 {
break
}
if time.Now().After(deadline) {
require.FailNow(t, "deadline exceeded.")
}
}
// Only 1 update should be sent to the client.
require.Len(t, stream.sentFromServer, 1)
// Only the final states should be sent to the client.
payment := <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
}