Merge pull request #6335 from JssDWt/feature/subscribe-paymentattempts

routerrpc: TrackPayments
This commit is contained in:
Oliver Gugger 2022-09-12 12:45:35 +02:00 committed by GitHub
commit bd69e79f84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1722 additions and 712 deletions

View file

@ -23,6 +23,10 @@ transaction](https://github.com/lightningnetwork/lnd/pull/6730).
* [Add list addresses RPC](https://github.com/lightningnetwork/lnd/pull/6596).
* Add [TrackPayments](https://github.com/lightningnetwork/lnd/pull/6335)
method to the RPC to allow subscribing to updates from any inflight payment.
Similar to TrackPaymentV2, but for any inflight payment.
## Wallet
* [Allows Taproot public keys and tap scripts to be imported as watch-only
@ -88,6 +92,7 @@ minimum version needed to build the project.
* Elle Mouton
* ErikEk
* hieblmi
* Jesse de Wit
* Olaoluwa Osuntokun
* Oliver Gugger
* Priyansh Rastogi

File diff suppressed because it is too large Load diff

View file

@ -101,6 +101,34 @@ func request_Router_TrackPaymentV2_0(ctx context.Context, marshaler runtime.Mars
}
var (
filter_Router_TrackPayments_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
)
func request_Router_TrackPayments_0(ctx context.Context, marshaler runtime.Marshaler, client RouterClient, req *http.Request, pathParams map[string]string) (Router_TrackPaymentsClient, runtime.ServerMetadata, error) {
var protoReq TrackPaymentsRequest
var metadata runtime.ServerMetadata
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_Router_TrackPayments_0); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
stream, err := client.TrackPayments(ctx, &protoReq)
if err != nil {
return nil, metadata, err
}
header, err := stream.Header()
if err != nil {
return nil, metadata, err
}
metadata.HeaderMD = header
return stream, metadata, nil
}
func request_Router_EstimateRouteFee_0(ctx context.Context, marshaler runtime.Marshaler, client RouterClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq RouteFeeRequest
var metadata runtime.ServerMetadata
@ -556,6 +584,13 @@ func RegisterRouterHandlerServer(ctx context.Context, mux *runtime.ServeMux, ser
return
})
mux.Handle("GET", pattern_Router_TrackPayments_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")
_, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
})
mux.Handle("POST", pattern_Router_EstimateRouteFee_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@ -881,6 +916,26 @@ func RegisterRouterHandlerClient(ctx context.Context, mux *runtime.ServeMux, cli
})
mux.Handle("GET", pattern_Router_TrackPayments_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
rctx, err := runtime.AnnotateContext(ctx, mux, req, "/routerrpc.Router/TrackPayments", runtime.WithHTTPPathPattern("/v2/router/payments"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_Router_TrackPayments_0(rctx, inboundMarshaler, client, req, pathParams)
ctx = runtime.NewServerMetadataContext(ctx, md)
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
forward_Router_TrackPayments_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...)
})
mux.Handle("POST", pattern_Router_EstimateRouteFee_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@ -1129,6 +1184,8 @@ var (
pattern_Router_TrackPaymentV2_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v2", "router", "track", "payment_hash"}, ""))
pattern_Router_TrackPayments_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "router", "payments"}, ""))
pattern_Router_EstimateRouteFee_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v2", "router", "route", "estimatefee"}, ""))
pattern_Router_SendToRouteV2_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v2", "router", "route", "send"}, ""))
@ -1159,6 +1216,8 @@ var (
forward_Router_TrackPaymentV2_0 = runtime.ForwardResponseStream
forward_Router_TrackPayments_0 = runtime.ForwardResponseStream
forward_Router_EstimateRouteFee_0 = runtime.ForwardResponseMessage
forward_Router_SendToRouteV2_0 = runtime.ForwardResponseMessage

View file

@ -107,6 +107,48 @@ func RegisterRouterJSONCallbacks(registry map[string]func(ctx context.Context,
}()
}
registry["routerrpc.Router.TrackPayments"] = func(ctx context.Context,
conn *grpc.ClientConn, reqJSON string, callback func(string, error)) {
req := &TrackPaymentsRequest{}
err := marshaler.Unmarshal([]byte(reqJSON), req)
if err != nil {
callback("", err)
return
}
client := NewRouterClient(conn)
stream, err := client.TrackPayments(ctx, req)
if err != nil {
callback("", err)
return
}
go func() {
for {
select {
case <-stream.Context().Done():
callback("", stream.Context().Err())
return
default:
}
resp, err := stream.Recv()
if err != nil {
callback("", err)
return
}
respBytes, err := marshaler.Marshal(resp)
if err != nil {
callback("", err)
return
}
callback(string(respBytes), nil)
}
}()
}
registry["routerrpc.Router.EstimateRouteFee"] = func(ctx context.Context,
conn *grpc.ClientConn, reqJSON string, callback func(string, error)) {

View file

@ -22,6 +22,16 @@ service Router {
*/
rpc TrackPaymentV2 (TrackPaymentRequest) returns (stream lnrpc.Payment);
/*
TrackPayments returns an update stream for every payment that is not in a
terminal state. Note that if payments are in-flight while starting a new
subscription, the start of the payment stream could produce out-of-order
and/or duplicate events. In order to get updates for every in-flight
payment attempt make sure to subscribe to this method before initiating any
payments.
*/
rpc TrackPayments (TrackPaymentsRequest) returns (stream lnrpc.Payment);
/*
EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
may cost to send an HTLC to the target end destination.
@ -303,6 +313,14 @@ message TrackPaymentRequest {
bool no_inflight_updates = 2;
}
message TrackPaymentsRequest {
/*
If set, only the final payment updates are streamed back. Intermediate
updates that show which htlcs are still in flight are suppressed.
*/
bool no_inflight_updates = 1;
}
message RouteFeeRequest {
/*
The destination once wishes to obtain a routing fee quote to.

View file

@ -250,6 +250,47 @@
]
}
},
"/v2/router/payments": {
"get": {
"summary": "TrackPayments returns an update stream for every payment that is not in a\nterminal state. Note that if payments are in-flight while starting a new\nsubscription, the start of the payment stream could produce out-of-order\nand/or duplicate events. In order to get updates for every in-flight\npayment attempt make sure to subscribe to this method before initiating any\npayments.",
"operationId": "Router_TrackPayments",
"responses": {
"200": {
"description": "A successful response.(streaming responses)",
"schema": {
"type": "object",
"properties": {
"result": {
"$ref": "#/definitions/lnrpcPayment"
},
"error": {
"$ref": "#/definitions/rpcStatus"
}
},
"title": "Stream result of lnrpcPayment"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "no_inflight_updates",
"description": "If set, only the final payment updates are streamed back. Intermediate\nupdates that show which htlcs are still in flight are suppressed.",
"in": "query",
"required": false,
"type": "boolean"
}
],
"tags": [
"Router"
]
}
},
"/v2/router/route": {
"post": {
"summary": "BuildRoute builds a fully specified route based on a list of hop public\nkeys. It retrieves the relevant channel policies from the graph in order to\ncalculate the correct fees and time locks.",

View file

@ -8,6 +8,8 @@ http:
body: "*"
- selector: routerrpc.Router.TrackPaymentV2
get: "/v2/router/track/{payment_hash}"
- selector: routerrpc.Router.TrackPayments
get: "/v2/router/payments"
- selector: routerrpc.Router.EstimateRouteFee
post: "/v2/router/route/estimatefee"
body: "*"

View file

@ -26,6 +26,13 @@ type RouterClient interface {
// TrackPaymentV2 returns an update stream for the payment identified by the
// payment hash.
TrackPaymentV2(ctx context.Context, in *TrackPaymentRequest, opts ...grpc.CallOption) (Router_TrackPaymentV2Client, error)
// TrackPayments returns an update stream for every payment that is not in a
// terminal state. Note that if payments are in-flight while starting a new
// subscription, the start of the payment stream could produce out-of-order
// and/or duplicate events. In order to get updates for every in-flight
// payment attempt make sure to subscribe to this method before initiating any
// payments.
TrackPayments(ctx context.Context, in *TrackPaymentsRequest, opts ...grpc.CallOption) (Router_TrackPaymentsClient, error)
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
// may cost to send an HTLC to the target end destination.
EstimateRouteFee(ctx context.Context, in *RouteFeeRequest, opts ...grpc.CallOption) (*RouteFeeResponse, error)
@ -165,6 +172,38 @@ func (x *routerTrackPaymentV2Client) Recv() (*lnrpc.Payment, error) {
return m, nil
}
func (c *routerClient) TrackPayments(ctx context.Context, in *TrackPaymentsRequest, opts ...grpc.CallOption) (Router_TrackPaymentsClient, error) {
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[2], "/routerrpc.Router/TrackPayments", opts...)
if err != nil {
return nil, err
}
x := &routerTrackPaymentsClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type Router_TrackPaymentsClient interface {
Recv() (*lnrpc.Payment, error)
grpc.ClientStream
}
type routerTrackPaymentsClient struct {
grpc.ClientStream
}
func (x *routerTrackPaymentsClient) Recv() (*lnrpc.Payment, error) {
m := new(lnrpc.Payment)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *routerClient) EstimateRouteFee(ctx context.Context, in *RouteFeeRequest, opts ...grpc.CallOption) (*RouteFeeResponse, error) {
out := new(RouteFeeResponse)
err := c.cc.Invoke(ctx, "/routerrpc.Router/EstimateRouteFee", in, out, opts...)
@ -257,7 +296,7 @@ func (c *routerClient) BuildRoute(ctx context.Context, in *BuildRouteRequest, op
}
func (c *routerClient) SubscribeHtlcEvents(ctx context.Context, in *SubscribeHtlcEventsRequest, opts ...grpc.CallOption) (Router_SubscribeHtlcEventsClient, error) {
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[2], "/routerrpc.Router/SubscribeHtlcEvents", opts...)
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[3], "/routerrpc.Router/SubscribeHtlcEvents", opts...)
if err != nil {
return nil, err
}
@ -290,7 +329,7 @@ func (x *routerSubscribeHtlcEventsClient) Recv() (*HtlcEvent, error) {
// Deprecated: Do not use.
func (c *routerClient) SendPayment(ctx context.Context, in *SendPaymentRequest, opts ...grpc.CallOption) (Router_SendPaymentClient, error) {
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[3], "/routerrpc.Router/SendPayment", opts...)
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[4], "/routerrpc.Router/SendPayment", opts...)
if err != nil {
return nil, err
}
@ -323,7 +362,7 @@ func (x *routerSendPaymentClient) Recv() (*PaymentStatus, error) {
// Deprecated: Do not use.
func (c *routerClient) TrackPayment(ctx context.Context, in *TrackPaymentRequest, opts ...grpc.CallOption) (Router_TrackPaymentClient, error) {
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[4], "/routerrpc.Router/TrackPayment", opts...)
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[5], "/routerrpc.Router/TrackPayment", opts...)
if err != nil {
return nil, err
}
@ -355,7 +394,7 @@ func (x *routerTrackPaymentClient) Recv() (*PaymentStatus, error) {
}
func (c *routerClient) HtlcInterceptor(ctx context.Context, opts ...grpc.CallOption) (Router_HtlcInterceptorClient, error) {
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[5], "/routerrpc.Router/HtlcInterceptor", opts...)
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[6], "/routerrpc.Router/HtlcInterceptor", opts...)
if err != nil {
return nil, err
}
@ -405,6 +444,13 @@ type RouterServer interface {
// TrackPaymentV2 returns an update stream for the payment identified by the
// payment hash.
TrackPaymentV2(*TrackPaymentRequest, Router_TrackPaymentV2Server) error
// TrackPayments returns an update stream for every payment that is not in a
// terminal state. Note that if payments are in-flight while starting a new
// subscription, the start of the payment stream could produce out-of-order
// and/or duplicate events. In order to get updates for every in-flight
// payment attempt make sure to subscribe to this method before initiating any
// payments.
TrackPayments(*TrackPaymentsRequest, Router_TrackPaymentsServer) error
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
// may cost to send an HTLC to the target end destination.
EstimateRouteFee(context.Context, *RouteFeeRequest) (*RouteFeeResponse, error)
@ -483,6 +529,9 @@ func (UnimplementedRouterServer) SendPaymentV2(*SendPaymentRequest, Router_SendP
func (UnimplementedRouterServer) TrackPaymentV2(*TrackPaymentRequest, Router_TrackPaymentV2Server) error {
return status.Errorf(codes.Unimplemented, "method TrackPaymentV2 not implemented")
}
func (UnimplementedRouterServer) TrackPayments(*TrackPaymentsRequest, Router_TrackPaymentsServer) error {
return status.Errorf(codes.Unimplemented, "method TrackPayments not implemented")
}
func (UnimplementedRouterServer) EstimateRouteFee(context.Context, *RouteFeeRequest) (*RouteFeeResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method EstimateRouteFee not implemented")
}
@ -583,6 +632,27 @@ func (x *routerTrackPaymentV2Server) Send(m *lnrpc.Payment) error {
return x.ServerStream.SendMsg(m)
}
func _Router_TrackPayments_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(TrackPaymentsRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(RouterServer).TrackPayments(m, &routerTrackPaymentsServer{stream})
}
type Router_TrackPaymentsServer interface {
Send(*lnrpc.Payment) error
grpc.ServerStream
}
type routerTrackPaymentsServer struct {
grpc.ServerStream
}
func (x *routerTrackPaymentsServer) Send(m *lnrpc.Payment) error {
return x.ServerStream.SendMsg(m)
}
func _Router_EstimateRouteFee_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RouteFeeRequest)
if err := dec(in); err != nil {
@ -933,6 +1003,11 @@ var Router_ServiceDesc = grpc.ServiceDesc{
Handler: _Router_TrackPaymentV2_Handler,
ServerStreams: true,
},
{
StreamName: "TrackPayments",
Handler: _Router_TrackPayments_Handler,
ServerStreams: true,
},
{
StreamName: "SubscribeHtlcEvents",
Handler: _Router_SubscribeHtlcEvents_Handler,

View file

@ -72,6 +72,10 @@ var (
Entity: "offchain",
Action: "read",
}},
"/routerrpc.Router/TrackPayments": {{
Entity: "offchain",
Action: "read",
}},
"/routerrpc.Router/EstimateRouteFee": {{
Entity: "offchain",
Action: "read",
@ -737,22 +741,65 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
router := s.cfg.RouterBackend
// Subscribe to the outcome of this payment.
subscription, err := router.Tower.SubscribePayment(
identifier,
)
subscription, err := router.Tower.SubscribePayment(identifier)
switch {
case err == channeldb.ErrPaymentNotInitiated:
return status.Error(codes.NotFound, err.Error())
case err != nil:
return err
}
// Stream updates to the client.
err = s.trackPaymentStream(
stream.Context(), subscription, noInflightUpdates, stream.Send,
)
if errors.Is(err, context.Canceled) {
log.Debugf("Payment stream %v canceled", identifier)
}
return err
}
// TrackPayments returns a stream of payment state updates.
func (s *Server) TrackPayments(request *TrackPaymentsRequest,
stream Router_TrackPaymentsServer) error {
log.Debug("TrackPayments called")
router := s.cfg.RouterBackend
// Subscribe to payments.
subscription, err := router.Tower.SubscribeAllPayments()
if err != nil {
return err
}
// Stream updates to the client.
err = s.trackPaymentStream(
stream.Context(), subscription, request.NoInflightUpdates,
stream.Send,
)
if errors.Is(err, context.Canceled) {
log.Debugf("TrackPayments payment stream canceled.")
}
return err
}
// trackPaymentStream streams payment updates to the client.
func (s *Server) trackPaymentStream(context context.Context,
subscription routing.ControlTowerSubscriber, noInflightUpdates bool,
send func(*lnrpc.Payment) error) error {
defer subscription.Close()
// Stream updates back to the client. The first update is always the
// current state of the payment.
// Stream updates back to the client.
for {
select {
case item, ok := <-subscription.Updates:
case item, ok := <-subscription.Updates():
if !ok {
// No more payment updates.
return nil
@ -766,13 +813,15 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
continue
}
rpcPayment, err := router.MarshallPayment(result)
rpcPayment, err := s.cfg.RouterBackend.MarshallPayment(
result,
)
if err != nil {
return err
}
// Send event to the client.
err = stream.Send(rpcPayment)
err = send(rpcPayment)
if err != nil {
return err
}
@ -780,9 +829,8 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
case <-s.quit:
return errServerShuttingDown
case <-stream.Context().Done():
log.Debugf("Payment status stream %v canceled", identifier)
return stream.Context().Err()
case <-context.Done():
return context.Err()
}
}
}

View file

@ -0,0 +1,216 @@
package routerrpc
import (
"context"
"testing"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/routing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type streamMock struct {
grpc.ServerStream
ctx context.Context
sentFromServer chan *lnrpc.Payment
}
func makeStreamMock(ctx context.Context) *streamMock {
return &streamMock{
ctx: ctx,
sentFromServer: make(chan *lnrpc.Payment, 10),
}
}
func (m *streamMock) Context() context.Context {
return m.ctx
}
func (m *streamMock) Send(p *lnrpc.Payment) error {
m.sentFromServer <- p
return nil
}
type controlTowerSubscriberMock struct {
updates <-chan interface{}
}
func (s controlTowerSubscriberMock) Updates() <-chan interface{} {
return s.updates
}
func (s controlTowerSubscriberMock) Close() {
}
type controlTowerMock struct {
queue *queue.ConcurrentQueue
routing.ControlTower
}
func makeControlTowerMock() *controlTowerMock {
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
return towerMock
}
func (t *controlTowerMock) SubscribeAllPayments() (
routing.ControlTowerSubscriber, error) {
return &controlTowerSubscriberMock{
updates: t.queue.ChanOut(),
}, nil
}
// TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns
// when the stream context is cancelled.
func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Cancel stream immediately
cancelStream()
// Make sure the call returns.
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}
// TestTrackPaymentsInflightUpdate tests whether all updates from the control
// tower are propagated to the client.
func TestTrackPaymentsInflightUpdates(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: false,
}
towerMock := makeControlTowerMock()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusInFlight,
}
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusSucceeded,
}
// Wait until there's 2 updates or the deadline is exceeded.
deadline := time.Now().Add(1 * time.Second)
for {
if len(stream.sentFromServer) == 2 {
break
}
if time.Now().After(deadline) {
require.FailNow(t, "deadline exceeded.")
}
}
// Both updates should be sent to the client.
require.Len(t, stream.sentFromServer, 2)
// The updates should be in the right order.
payment := <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_IN_FLIGHT, payment.Status)
payment = <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
}
// TestTrackPaymentsInflightUpdate tests whether only final updates from the
// control tower are propagated to the client when noInflightUpdates = true.
func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
// Setup mocks and request.
request := &TrackPaymentsRequest{
NoInflightUpdates: true,
}
towerMock := &controlTowerMock{
queue: queue.NewConcurrentQueue(20),
}
towerMock.queue.Start()
streamCtx, cancelStream := context.WithCancel(context.Background())
stream := makeStreamMock(streamCtx)
defer cancelStream()
server := &Server{
cfg: &Config{
RouterBackend: &RouterBackend{
Tower: towerMock,
},
},
}
// Listen to payment updates in a goroutine.
go func() {
err := server.TrackPayments(request, stream)
require.Equal(t, context.Canceled, err)
}()
// Enqueue some payment updates on the mock.
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusInFlight,
}
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{},
Status: channeldb.StatusSucceeded,
}
// Wait until there's 1 update or the deadline is exceeded.
deadline := time.Now().Add(1 * time.Second)
for {
if len(stream.sentFromServer) == 1 {
break
}
if time.Now().After(deadline) {
require.FailNow(t, "deadline exceeded.")
}
}
// Only 1 update should be sent to the client.
require.Len(t, stream.sentFromServer, 1)
// Only the final states should be sent to the client.
payment := <-stream.sentFromServer
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
}

View file

@ -439,4 +439,8 @@ var allTestCases = []*testCase{
name: "taproot coop close",
test: testTaprootCoopClose,
},
{
name: "trackpayments",
test: testTrackPayments,
},
}

View file

@ -0,0 +1,102 @@
package itest
import (
"context"
"encoding/hex"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/stretchr/testify/require"
)
// testTrackPayments tests whether a client that calls the TrackPayments api
// receives payment updates.
func testTrackPayments(net *lntest.NetworkHarness, t *harnessTest) {
// Open a channel between alice and bob.
net.EnsureConnected(t.t, net.Alice, net.Bob)
channel := openChannelAndAssert(
t, net, net.Alice, net.Bob,
lntest.OpenChannelParams{
Amt: btcutil.Amount(300000),
},
)
defer closeChannelAndAssert(t, net, net.Alice, channel, true)
err := net.Alice.WaitForNetworkChannelOpen(channel)
require.NoError(t.t, err, "unable to wait for channel to open")
ctxb := context.Background()
ctxt, cancelTracker := context.WithCancel(ctxb)
defer cancelTracker()
// Call the TrackPayments api to listen for payment updates.
tracker, err := net.Alice.RouterClient.TrackPayments(
ctxt,
&routerrpc.TrackPaymentsRequest{
NoInflightUpdates: false,
},
)
require.NoError(t.t, err, "failed to call TrackPayments successfully.")
// Create an invoice from bob.
var amountMsat int64 = 1000
invoiceResp, err := net.Bob.AddInvoice(
ctxb,
&lnrpc.Invoice{
ValueMsat: amountMsat,
},
)
require.NoError(t.t, err, "unable to add invoice.")
invoice, err := net.Bob.LookupInvoice(
ctxb,
&lnrpc.PaymentHash{
RHashStr: hex.EncodeToString(invoiceResp.RHash),
},
)
require.NoError(t.t, err, "unable to find invoice.")
// Send payment from alice to bob.
paymentClient, err := net.Alice.RouterClient.SendPaymentV2(
ctxb,
&routerrpc.SendPaymentRequest{
PaymentRequest: invoice.PaymentRequest,
TimeoutSeconds: 60,
},
)
require.NoError(t.t, err, "unable to send payment.")
// Make sure the payment doesn't error due to invalid parameters or so.
_, err = paymentClient.Recv()
require.NoError(t.t, err, "unable to get payment update.")
// Assert the first payment update is an inflight update.
update1, err := tracker.Recv()
require.NoError(t.t, err, "unable to receive payment update 1.")
require.Equal(
t.t, lnrpc.PaymentFailureReason_FAILURE_REASON_NONE,
update1.FailureReason,
)
require.Equal(t.t, lnrpc.Payment_IN_FLIGHT, update1.Status)
require.Equal(t.t, invoice.PaymentRequest, update1.PaymentRequest)
require.Equal(t.t, amountMsat, update1.ValueMsat)
// Assert the second payment update is a payment success update.
update2, err := tracker.Recv()
require.NoError(t.t, err, "unable to receive payment update 2.")
require.Equal(
t.t, lnrpc.PaymentFailureReason_FAILURE_REASON_NONE,
update2.FailureReason,
)
require.Equal(t.t, lnrpc.Payment_SUCCEEDED, update2.Status)
require.Equal(t.t, invoice.PaymentRequest, update2.PaymentRequest)
require.Equal(t.t, amountMsat, update2.ValueMsat)
require.Equal(
t.t, hex.EncodeToString(invoice.RPreimage),
update2.PaymentPreimage,
)
}

View file

@ -60,35 +60,48 @@ type ControlTower interface {
// SubscribePayment subscribes to updates for the payment with the given
// hash. A first update with the current state of the payment is always
// sent out immediately.
SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber,
SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber,
error)
// SubscribeAllPayments subscribes to updates for all payments. A first
// update with the current state of every inflight payment is always
// sent out immediately.
SubscribeAllPayments() (ControlTowerSubscriber, error)
}
// ControlTowerSubscriber contains the state for a payment update subscriber.
type ControlTowerSubscriber struct {
type ControlTowerSubscriber interface {
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
Updates <-chan interface{}
Updates() <-chan interface{}
queue *queue.ConcurrentQueue
quit chan struct{}
// Close signals that the subscriber is no longer interested in updates.
Close()
}
// ControlTowerSubscriberImpl contains the state for a payment update
// subscriber.
type controlTowerSubscriberImpl struct {
updates <-chan interface{}
queue *queue.ConcurrentQueue
quit chan struct{}
}
// newControlTowerSubscriber instantiates a new subscriber state object.
func newControlTowerSubscriber() *ControlTowerSubscriber {
func newControlTowerSubscriber() *controlTowerSubscriberImpl {
// Create a queue for payment updates.
queue := queue.NewConcurrentQueue(20)
queue.Start()
return &ControlTowerSubscriber{
Updates: queue.ChanOut(),
return &controlTowerSubscriberImpl{
updates: queue.ChanOut(),
queue: queue,
quit: make(chan struct{}),
}
}
// Close signals that the subscriber is no longer interested in updates.
func (s *ControlTowerSubscriber) Close() {
func (s *controlTowerSubscriberImpl) Close() {
// Close quit channel so that any pending writes to the queue are
// cancelled.
close(s.quit)
@ -97,13 +110,24 @@ func (s *ControlTowerSubscriber) Close() {
s.queue.Stop()
}
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} {
return s.updates
}
// controlTower is persistent implementation of ControlTower to restrict
// double payment sending.
type controlTower struct {
db *channeldb.PaymentControl
subscribers map[lntypes.Hash][]*ControlTowerSubscriber
subscribersMtx sync.Mutex
// subscriberIndex is used to provide a unique id for each subscriber
// to all payments. This is used to easily remove the subscriber when
// necessary.
subscriberIndex uint64
subscribersAllPayments map[uint64]*controlTowerSubscriberImpl
subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl
subscribersMtx sync.Mutex
// paymentsMtx provides synchronization on the payment level to ensure
// that no race conditions occur in between updating the database and
@ -114,8 +138,11 @@ type controlTower struct {
// NewControlTower creates a new instance of the controlTower.
func NewControlTower(db *channeldb.PaymentControl) ControlTower {
return &controlTower{
db: db,
subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber),
db: db,
subscribersAllPayments: make(
map[uint64]*controlTowerSubscriberImpl,
),
subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
paymentsMtx: multimutex.NewHashMutex(),
}
}
@ -232,7 +259,7 @@ func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
// first update with the current state of the payment is always sent out
// immediately.
func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
// Take lock before querying the db to prevent missing or duplicating an
// update.
@ -266,6 +293,39 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
return subscriber, nil
}
// SubscribeAllPayments subscribes to updates for all inflight payments. A first
// update with the current state of every inflight payment is always sent out
// immediately.
// Note: If payments are in-flight while starting a new subscription, the start
// of the payment stream could produce out-of-order and/or duplicate events. In
// order to get updates for every in-flight payment attempt make sure to
// subscribe to this method before initiating any payments.
func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) {
subscriber := newControlTowerSubscriber()
// Add the subscriber to the list before fetching in-flight payments, so
// no events are missed. If a payment attempt update occurs after
// appending and before fetching in-flight payments, an out-of-order
// duplicate may be produced, because it is then fetched in below call
// and notified through the subscription.
p.subscribersMtx.Lock()
p.subscribersAllPayments[p.subscriberIndex] = subscriber
p.subscriberIndex++
p.subscribersMtx.Unlock()
inflightPayments, err := p.db.FetchInFlightPayments()
if err != nil {
return nil, err
}
for index := range inflightPayments {
// Always write current payment state to the channel.
subscriber.queue.ChanIn() <- inflightPayments[index]
}
return subscriber, nil
}
// notifySubscribers sends a final payment event to all subscribers of this
// payment. The channel will be closed after this. Note that this function must
// be executed atomically (by means of a lock) with the database update to
@ -275,8 +335,9 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
// Get all subscribers for this payment.
p.subscribersMtx.Lock()
list, ok := p.subscribers[paymentHash]
if !ok {
subscribersPaymentHash, ok := p.subscribers[paymentHash]
if !ok && len(p.subscribersAllPayments) == 0 {
p.subscribersMtx.Unlock()
return
}
@ -287,10 +348,17 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
if terminal {
delete(p.subscribers, paymentHash)
}
// Copy subscribers to all payments locally while holding the lock in
// order to avoid concurrency issues while reading/writing the map.
subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl)
for k, v := range p.subscribersAllPayments {
subscribersAllPayments[k] = v
}
p.subscribersMtx.Unlock()
// Notify all subscribers of the event.
for _, subscriber := range list {
// Notify all subscribers that subscribed to the current payment hash.
for _, subscriber := range subscribersPaymentHash {
select {
case subscriber.queue.ChanIn() <- event:
// If this event is the last, close the incoming channel
@ -305,4 +373,18 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
case <-subscriber.quit:
}
}
// Notify all subscribers that subscribed to all payments.
for key, subscriber := range subscribersAllPayments {
select {
case subscriber.queue.ChanIn() <- event:
// If subscriber disappeared, remove it from the subscribers
// list.
case <-subscriber.quit:
p.subscribersMtx.Lock()
delete(p.subscribersAllPayments, key)
p.subscribersMtx.Unlock()
}
}
}

View file

@ -115,7 +115,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// We expect all subscribers to now report the final outcome followed by
// no other events.
subscribers := []*ControlTowerSubscriber{
subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2, subscriber3,
}
@ -123,7 +123,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight {
select {
case item := <-s.Updates:
case item := <-s.Updates():
result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result")
@ -149,7 +149,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
// After the final event, we expect the channel to be closed.
select {
case _, ok := <-s.Updates:
case _, ok := <-s.Updates():
if ok {
t.Fatal("expected channel to be closed")
}
@ -178,6 +178,236 @@ func TestPaymentControlSubscribeFail(t *testing.T) {
})
}
// TestPaymentControlSubscribeAllSuccess tests that multiple payments are
// properly sent to subscribers of TrackPayments.
func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
t.Parallel()
db, err := initDB(t, true)
require.NoError(t, err, "unable to init db: %v")
pControl := NewControlTower(channeldb.NewPaymentControl(db))
// Initiate a payment.
info1, attempt1, preimg1, err := genInfo()
require.NoError(t, err)
err = pControl.InitPayment(info1.PaymentIdentifier, info1)
require.NoError(t, err)
// Subscription should succeed and immediately report the InFlight
// status.
subscription, err := pControl.SubscribeAllPayments()
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
// Register an attempt.
err = pControl.RegisterAttempt(info1.PaymentIdentifier, attempt1)
require.NoError(t, err)
// Initiate a second payment after the subscription is already active.
info2, attempt2, preimg2, err := genInfo()
require.NoError(t, err)
err = pControl.InitPayment(info2.PaymentIdentifier, info2)
require.NoError(t, err)
// Register an attempt on the second payment.
err = pControl.RegisterAttempt(info2.PaymentIdentifier, attempt2)
require.NoError(t, err)
// Mark the first payment as successful.
settleInfo1 := channeldb.HTLCSettleInfo{
Preimage: preimg1,
}
htlcAttempt1, err := pControl.SettleAttempt(
info1.PaymentIdentifier, attempt1.AttemptID, &settleInfo1,
)
require.NoError(t, err)
require.Equal(
t, settleInfo1, *htlcAttempt1.Settle,
"unexpected settle info returned",
)
// Mark the second payment as successful.
settleInfo2 := channeldb.HTLCSettleInfo{
Preimage: preimg2,
}
htlcAttempt2, err := pControl.SettleAttempt(
info2.PaymentIdentifier, attempt2.AttemptID, &settleInfo2,
)
require.NoError(t, err)
require.Equal(
t, settleInfo2, *htlcAttempt2.Settle,
"unexpected fail info returned",
)
// The two payments will be asserted individually, store the last update
// for each payment.
results := make(map[lntypes.Hash]*channeldb.MPPayment)
// After exactly 5 updates both payments will/should have completed.
for i := 0; i < 5; i++ {
select {
case item := <-subscription.Updates():
id := item.(*channeldb.MPPayment).Info.PaymentIdentifier
results[id] = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
}
result1 := results[info1.PaymentIdentifier]
require.Equal(
t, channeldb.StatusSucceeded, result1.Status,
"unexpected payment state payment 1",
)
settle1, _ := result1.TerminalInfo()
require.Equal(
t, preimg1, settle1.Preimage, "unexpected preimage payment 1",
)
require.Len(
t, result1.HTLCs, 1, "expect 1 htlc for payment 1, got %d",
len(result1.HTLCs),
)
htlc1 := result1.HTLCs[0]
require.Equal(t, attempt1.Route, htlc1.Route, "unexpected htlc route.")
result2 := results[info2.PaymentIdentifier]
require.Equal(
t, channeldb.StatusSucceeded, result2.Status,
"unexpected payment state payment 2",
)
settle2, _ := result2.TerminalInfo()
require.Equal(
t, preimg2, settle2.Preimage, "unexpected preimage payment 2",
)
require.Len(
t, result2.HTLCs, 1, "expect 1 htlc for payment 2, got %d",
len(result2.HTLCs),
)
htlc2 := result2.HTLCs[0]
require.Equal(t, attempt2.Route, htlc2.Route, "unexpected htlc route.")
}
// TestPaymentControlSubscribeAllImmediate tests whether already inflight
// payments are reported at the start of the SubscribeAllPayments subscription.
func TestPaymentControlSubscribeAllImmediate(t *testing.T) {
t.Parallel()
db, err := initDB(t, true)
require.NoError(t, err, "unable to init db: %v")
pControl := NewControlTower(channeldb.NewPaymentControl(db))
// Initiate a payment.
info, attempt, _, err := genInfo()
require.NoError(t, err)
err = pControl.InitPayment(info.PaymentIdentifier, info)
require.NoError(t, err)
// Register a payment update.
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
require.NoError(t, err)
subscription, err := pControl.SubscribeAllPayments()
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
// Assert the new subscription receives the old update.
select {
case update := <-subscription.Updates():
require.NotNil(t, update)
require.Equal(
t, info.PaymentIdentifier,
update.(*channeldb.MPPayment).Info.PaymentIdentifier,
)
require.Len(t, subscription.Updates(), 0)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
}
// TestPaymentControlUnsubscribeSuccess tests that when unsubscribed, there are
// no more notifications to that specific subscription.
func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
t.Parallel()
db, err := initDB(t, true)
require.NoError(t, err, "unable to init db: %v")
pControl := NewControlTower(channeldb.NewPaymentControl(db))
subscription1, err := pControl.SubscribeAllPayments()
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
subscription2, err := pControl.SubscribeAllPayments()
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
// Initiate a payment.
info, attempt, _, err := genInfo()
require.NoError(t, err)
err = pControl.InitPayment(info.PaymentIdentifier, info)
require.NoError(t, err)
// Register a payment update.
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
require.NoError(t, err)
// Assert all subscriptions receive the update.
select {
case update1 := <-subscription1.Updates():
require.NotNil(t, update1)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
select {
case update2 := <-subscription2.Updates():
require.NotNil(t, update2)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
// Close the first subscription.
subscription1.Close()
// Register another update.
failInfo := channeldb.HTLCFailInfo{
Reason: channeldb.HTLCFailInternal,
}
_, err = pControl.FailAttempt(
info.PaymentIdentifier, attempt.AttemptID, &failInfo,
)
require.NoError(t, err, "unable to fail htlc")
// Assert only subscription 2 receives the update.
select {
case update2 := <-subscription2.Updates():
require.NotNil(t, update2)
case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for payment result")
}
require.Len(t, subscription1.Updates(), 0)
// Close the second subscription.
subscription2.Close()
// Register a last update.
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
require.NoError(t, err)
// Assert no subscriptions receive the update.
require.Len(t, subscription1.Updates(), 0)
require.Len(t, subscription2.Updates(), 0)
}
func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
keepFailedPaymentAttempts bool) {
@ -237,7 +467,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// We expect both subscribers to now report the final outcome followed
// by no other events.
subscribers := []*ControlTowerSubscriber{
subscribers := []ControlTowerSubscriber{
subscriber1, subscriber2,
}
@ -245,7 +475,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
var result *channeldb.MPPayment
for result == nil || result.Status == channeldb.StatusInFlight {
select {
case item := <-s.Updates:
case item := <-s.Updates():
result = item.(*channeldb.MPPayment)
case <-time.After(testTimeout):
t.Fatal("timeout waiting for payment result")
@ -283,7 +513,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
// After the final event, we expect the channel to be closed.
select {
case _, ok := <-s.Updates:
case _, ok := <-s.Updates():
if ok {
t.Fatal("expected channel to be closed")
}

View file

@ -552,7 +552,13 @@ func (m *mockControlTowerOld) FetchInFlightPayments() (
}
func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
return nil, errors.New("not implemented")
}
func (m *mockControlTowerOld) SubscribeAllPayments() (
ControlTowerSubscriber, error) {
return nil, errors.New("not implemented")
}
@ -768,10 +774,17 @@ func (m *mockControlTower) FetchInFlightPayments() (
}
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
ControlTowerSubscriber, error) {
args := m.Called(paymentHash)
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
func (m *mockControlTower) SubscribeAllPayments() (
ControlTowerSubscriber, error) {
args := m.Called()
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
type mockLink struct {