mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-22 22:25:24 +01:00
Merge pull request #6335 from JssDWt/feature/subscribe-paymentattempts
routerrpc: TrackPayments
This commit is contained in:
commit
bd69e79f84
15 changed files with 1722 additions and 712 deletions
|
@ -23,6 +23,10 @@ transaction](https://github.com/lightningnetwork/lnd/pull/6730).
|
|||
|
||||
* [Add list addresses RPC](https://github.com/lightningnetwork/lnd/pull/6596).
|
||||
|
||||
* Add [TrackPayments](https://github.com/lightningnetwork/lnd/pull/6335)
|
||||
method to the RPC to allow subscribing to updates from any inflight payment.
|
||||
Similar to TrackPaymentV2, but for any inflight payment.
|
||||
|
||||
## Wallet
|
||||
|
||||
* [Allows Taproot public keys and tap scripts to be imported as watch-only
|
||||
|
@ -88,6 +92,7 @@ minimum version needed to build the project.
|
|||
* Elle Mouton
|
||||
* ErikEk
|
||||
* hieblmi
|
||||
* Jesse de Wit
|
||||
* Olaoluwa Osuntokun
|
||||
* Oliver Gugger
|
||||
* Priyansh Rastogi
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -101,6 +101,34 @@ func request_Router_TrackPaymentV2_0(ctx context.Context, marshaler runtime.Mars
|
|||
|
||||
}
|
||||
|
||||
var (
|
||||
filter_Router_TrackPayments_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
|
||||
)
|
||||
|
||||
func request_Router_TrackPayments_0(ctx context.Context, marshaler runtime.Marshaler, client RouterClient, req *http.Request, pathParams map[string]string) (Router_TrackPaymentsClient, runtime.ServerMetadata, error) {
|
||||
var protoReq TrackPaymentsRequest
|
||||
var metadata runtime.ServerMetadata
|
||||
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_Router_TrackPayments_0); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
|
||||
stream, err := client.TrackPayments(ctx, &protoReq)
|
||||
if err != nil {
|
||||
return nil, metadata, err
|
||||
}
|
||||
header, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, metadata, err
|
||||
}
|
||||
metadata.HeaderMD = header
|
||||
return stream, metadata, nil
|
||||
|
||||
}
|
||||
|
||||
func request_Router_EstimateRouteFee_0(ctx context.Context, marshaler runtime.Marshaler, client RouterClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
|
||||
var protoReq RouteFeeRequest
|
||||
var metadata runtime.ServerMetadata
|
||||
|
@ -556,6 +584,13 @@ func RegisterRouterHandlerServer(ctx context.Context, mux *runtime.ServeMux, ser
|
|||
return
|
||||
})
|
||||
|
||||
mux.Handle("GET", pattern_Router_TrackPayments_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")
|
||||
_, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
|
||||
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
})
|
||||
|
||||
mux.Handle("POST", pattern_Router_EstimateRouteFee_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
|
@ -881,6 +916,26 @@ func RegisterRouterHandlerClient(ctx context.Context, mux *runtime.ServeMux, cli
|
|||
|
||||
})
|
||||
|
||||
mux.Handle("GET", pattern_Router_TrackPayments_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
|
||||
rctx, err := runtime.AnnotateContext(ctx, mux, req, "/routerrpc.Router/TrackPayments", runtime.WithHTTPPathPattern("/v2/router/payments"))
|
||||
if err != nil {
|
||||
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
resp, md, err := request_Router_TrackPayments_0(rctx, inboundMarshaler, client, req, pathParams)
|
||||
ctx = runtime.NewServerMetadataContext(ctx, md)
|
||||
if err != nil {
|
||||
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
forward_Router_TrackPayments_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...)
|
||||
|
||||
})
|
||||
|
||||
mux.Handle("POST", pattern_Router_EstimateRouteFee_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
|
@ -1129,6 +1184,8 @@ var (
|
|||
|
||||
pattern_Router_TrackPaymentV2_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v2", "router", "track", "payment_hash"}, ""))
|
||||
|
||||
pattern_Router_TrackPayments_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "router", "payments"}, ""))
|
||||
|
||||
pattern_Router_EstimateRouteFee_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v2", "router", "route", "estimatefee"}, ""))
|
||||
|
||||
pattern_Router_SendToRouteV2_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v2", "router", "route", "send"}, ""))
|
||||
|
@ -1159,6 +1216,8 @@ var (
|
|||
|
||||
forward_Router_TrackPaymentV2_0 = runtime.ForwardResponseStream
|
||||
|
||||
forward_Router_TrackPayments_0 = runtime.ForwardResponseStream
|
||||
|
||||
forward_Router_EstimateRouteFee_0 = runtime.ForwardResponseMessage
|
||||
|
||||
forward_Router_SendToRouteV2_0 = runtime.ForwardResponseMessage
|
||||
|
|
|
@ -107,6 +107,48 @@ func RegisterRouterJSONCallbacks(registry map[string]func(ctx context.Context,
|
|||
}()
|
||||
}
|
||||
|
||||
registry["routerrpc.Router.TrackPayments"] = func(ctx context.Context,
|
||||
conn *grpc.ClientConn, reqJSON string, callback func(string, error)) {
|
||||
|
||||
req := &TrackPaymentsRequest{}
|
||||
err := marshaler.Unmarshal([]byte(reqJSON), req)
|
||||
if err != nil {
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
|
||||
client := NewRouterClient(conn)
|
||||
stream, err := client.TrackPayments(ctx, req)
|
||||
if err != nil {
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
callback("", stream.Context().Err())
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
|
||||
respBytes, err := marshaler.Marshal(resp)
|
||||
if err != nil {
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
callback(string(respBytes), nil)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
registry["routerrpc.Router.EstimateRouteFee"] = func(ctx context.Context,
|
||||
conn *grpc.ClientConn, reqJSON string, callback func(string, error)) {
|
||||
|
||||
|
|
|
@ -22,6 +22,16 @@ service Router {
|
|||
*/
|
||||
rpc TrackPaymentV2 (TrackPaymentRequest) returns (stream lnrpc.Payment);
|
||||
|
||||
/*
|
||||
TrackPayments returns an update stream for every payment that is not in a
|
||||
terminal state. Note that if payments are in-flight while starting a new
|
||||
subscription, the start of the payment stream could produce out-of-order
|
||||
and/or duplicate events. In order to get updates for every in-flight
|
||||
payment attempt make sure to subscribe to this method before initiating any
|
||||
payments.
|
||||
*/
|
||||
rpc TrackPayments (TrackPaymentsRequest) returns (stream lnrpc.Payment);
|
||||
|
||||
/*
|
||||
EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
|
||||
may cost to send an HTLC to the target end destination.
|
||||
|
@ -303,6 +313,14 @@ message TrackPaymentRequest {
|
|||
bool no_inflight_updates = 2;
|
||||
}
|
||||
|
||||
message TrackPaymentsRequest {
|
||||
/*
|
||||
If set, only the final payment updates are streamed back. Intermediate
|
||||
updates that show which htlcs are still in flight are suppressed.
|
||||
*/
|
||||
bool no_inflight_updates = 1;
|
||||
}
|
||||
|
||||
message RouteFeeRequest {
|
||||
/*
|
||||
The destination once wishes to obtain a routing fee quote to.
|
||||
|
|
|
@ -250,6 +250,47 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v2/router/payments": {
|
||||
"get": {
|
||||
"summary": "TrackPayments returns an update stream for every payment that is not in a\nterminal state. Note that if payments are in-flight while starting a new\nsubscription, the start of the payment stream could produce out-of-order\nand/or duplicate events. In order to get updates for every in-flight\npayment attempt make sure to subscribe to this method before initiating any\npayments.",
|
||||
"operationId": "Router_TrackPayments",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A successful response.(streaming responses)",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"$ref": "#/definitions/lnrpcPayment"
|
||||
},
|
||||
"error": {
|
||||
"$ref": "#/definitions/rpcStatus"
|
||||
}
|
||||
},
|
||||
"title": "Stream result of lnrpcPayment"
|
||||
}
|
||||
},
|
||||
"default": {
|
||||
"description": "An unexpected error response.",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/rpcStatus"
|
||||
}
|
||||
}
|
||||
},
|
||||
"parameters": [
|
||||
{
|
||||
"name": "no_inflight_updates",
|
||||
"description": "If set, only the final payment updates are streamed back. Intermediate\nupdates that show which htlcs are still in flight are suppressed.",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "boolean"
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Router"
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v2/router/route": {
|
||||
"post": {
|
||||
"summary": "BuildRoute builds a fully specified route based on a list of hop public\nkeys. It retrieves the relevant channel policies from the graph in order to\ncalculate the correct fees and time locks.",
|
||||
|
|
|
@ -8,6 +8,8 @@ http:
|
|||
body: "*"
|
||||
- selector: routerrpc.Router.TrackPaymentV2
|
||||
get: "/v2/router/track/{payment_hash}"
|
||||
- selector: routerrpc.Router.TrackPayments
|
||||
get: "/v2/router/payments"
|
||||
- selector: routerrpc.Router.EstimateRouteFee
|
||||
post: "/v2/router/route/estimatefee"
|
||||
body: "*"
|
||||
|
|
|
@ -26,6 +26,13 @@ type RouterClient interface {
|
|||
// TrackPaymentV2 returns an update stream for the payment identified by the
|
||||
// payment hash.
|
||||
TrackPaymentV2(ctx context.Context, in *TrackPaymentRequest, opts ...grpc.CallOption) (Router_TrackPaymentV2Client, error)
|
||||
// TrackPayments returns an update stream for every payment that is not in a
|
||||
// terminal state. Note that if payments are in-flight while starting a new
|
||||
// subscription, the start of the payment stream could produce out-of-order
|
||||
// and/or duplicate events. In order to get updates for every in-flight
|
||||
// payment attempt make sure to subscribe to this method before initiating any
|
||||
// payments.
|
||||
TrackPayments(ctx context.Context, in *TrackPaymentsRequest, opts ...grpc.CallOption) (Router_TrackPaymentsClient, error)
|
||||
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
|
||||
// may cost to send an HTLC to the target end destination.
|
||||
EstimateRouteFee(ctx context.Context, in *RouteFeeRequest, opts ...grpc.CallOption) (*RouteFeeResponse, error)
|
||||
|
@ -165,6 +172,38 @@ func (x *routerTrackPaymentV2Client) Recv() (*lnrpc.Payment, error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
func (c *routerClient) TrackPayments(ctx context.Context, in *TrackPaymentsRequest, opts ...grpc.CallOption) (Router_TrackPaymentsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[2], "/routerrpc.Router/TrackPayments", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &routerTrackPaymentsClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type Router_TrackPaymentsClient interface {
|
||||
Recv() (*lnrpc.Payment, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type routerTrackPaymentsClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *routerTrackPaymentsClient) Recv() (*lnrpc.Payment, error) {
|
||||
m := new(lnrpc.Payment)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *routerClient) EstimateRouteFee(ctx context.Context, in *RouteFeeRequest, opts ...grpc.CallOption) (*RouteFeeResponse, error) {
|
||||
out := new(RouteFeeResponse)
|
||||
err := c.cc.Invoke(ctx, "/routerrpc.Router/EstimateRouteFee", in, out, opts...)
|
||||
|
@ -257,7 +296,7 @@ func (c *routerClient) BuildRoute(ctx context.Context, in *BuildRouteRequest, op
|
|||
}
|
||||
|
||||
func (c *routerClient) SubscribeHtlcEvents(ctx context.Context, in *SubscribeHtlcEventsRequest, opts ...grpc.CallOption) (Router_SubscribeHtlcEventsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[2], "/routerrpc.Router/SubscribeHtlcEvents", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[3], "/routerrpc.Router/SubscribeHtlcEvents", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -290,7 +329,7 @@ func (x *routerSubscribeHtlcEventsClient) Recv() (*HtlcEvent, error) {
|
|||
|
||||
// Deprecated: Do not use.
|
||||
func (c *routerClient) SendPayment(ctx context.Context, in *SendPaymentRequest, opts ...grpc.CallOption) (Router_SendPaymentClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[3], "/routerrpc.Router/SendPayment", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[4], "/routerrpc.Router/SendPayment", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -323,7 +362,7 @@ func (x *routerSendPaymentClient) Recv() (*PaymentStatus, error) {
|
|||
|
||||
// Deprecated: Do not use.
|
||||
func (c *routerClient) TrackPayment(ctx context.Context, in *TrackPaymentRequest, opts ...grpc.CallOption) (Router_TrackPaymentClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[4], "/routerrpc.Router/TrackPayment", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[5], "/routerrpc.Router/TrackPayment", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -355,7 +394,7 @@ func (x *routerTrackPaymentClient) Recv() (*PaymentStatus, error) {
|
|||
}
|
||||
|
||||
func (c *routerClient) HtlcInterceptor(ctx context.Context, opts ...grpc.CallOption) (Router_HtlcInterceptorClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[5], "/routerrpc.Router/HtlcInterceptor", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Router_ServiceDesc.Streams[6], "/routerrpc.Router/HtlcInterceptor", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -405,6 +444,13 @@ type RouterServer interface {
|
|||
// TrackPaymentV2 returns an update stream for the payment identified by the
|
||||
// payment hash.
|
||||
TrackPaymentV2(*TrackPaymentRequest, Router_TrackPaymentV2Server) error
|
||||
// TrackPayments returns an update stream for every payment that is not in a
|
||||
// terminal state. Note that if payments are in-flight while starting a new
|
||||
// subscription, the start of the payment stream could produce out-of-order
|
||||
// and/or duplicate events. In order to get updates for every in-flight
|
||||
// payment attempt make sure to subscribe to this method before initiating any
|
||||
// payments.
|
||||
TrackPayments(*TrackPaymentsRequest, Router_TrackPaymentsServer) error
|
||||
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
|
||||
// may cost to send an HTLC to the target end destination.
|
||||
EstimateRouteFee(context.Context, *RouteFeeRequest) (*RouteFeeResponse, error)
|
||||
|
@ -483,6 +529,9 @@ func (UnimplementedRouterServer) SendPaymentV2(*SendPaymentRequest, Router_SendP
|
|||
func (UnimplementedRouterServer) TrackPaymentV2(*TrackPaymentRequest, Router_TrackPaymentV2Server) error {
|
||||
return status.Errorf(codes.Unimplemented, "method TrackPaymentV2 not implemented")
|
||||
}
|
||||
func (UnimplementedRouterServer) TrackPayments(*TrackPaymentsRequest, Router_TrackPaymentsServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method TrackPayments not implemented")
|
||||
}
|
||||
func (UnimplementedRouterServer) EstimateRouteFee(context.Context, *RouteFeeRequest) (*RouteFeeResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method EstimateRouteFee not implemented")
|
||||
}
|
||||
|
@ -583,6 +632,27 @@ func (x *routerTrackPaymentV2Server) Send(m *lnrpc.Payment) error {
|
|||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _Router_TrackPayments_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(TrackPaymentsRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(RouterServer).TrackPayments(m, &routerTrackPaymentsServer{stream})
|
||||
}
|
||||
|
||||
type Router_TrackPaymentsServer interface {
|
||||
Send(*lnrpc.Payment) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type routerTrackPaymentsServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *routerTrackPaymentsServer) Send(m *lnrpc.Payment) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _Router_EstimateRouteFee_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RouteFeeRequest)
|
||||
if err := dec(in); err != nil {
|
||||
|
@ -933,6 +1003,11 @@ var Router_ServiceDesc = grpc.ServiceDesc{
|
|||
Handler: _Router_TrackPaymentV2_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "TrackPayments",
|
||||
Handler: _Router_TrackPayments_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "SubscribeHtlcEvents",
|
||||
Handler: _Router_SubscribeHtlcEvents_Handler,
|
||||
|
|
|
@ -72,6 +72,10 @@ var (
|
|||
Entity: "offchain",
|
||||
Action: "read",
|
||||
}},
|
||||
"/routerrpc.Router/TrackPayments": {{
|
||||
Entity: "offchain",
|
||||
Action: "read",
|
||||
}},
|
||||
"/routerrpc.Router/EstimateRouteFee": {{
|
||||
Entity: "offchain",
|
||||
Action: "read",
|
||||
|
@ -737,22 +741,65 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
|
|||
router := s.cfg.RouterBackend
|
||||
|
||||
// Subscribe to the outcome of this payment.
|
||||
subscription, err := router.Tower.SubscribePayment(
|
||||
identifier,
|
||||
)
|
||||
subscription, err := router.Tower.SubscribePayment(identifier)
|
||||
|
||||
switch {
|
||||
case err == channeldb.ErrPaymentNotInitiated:
|
||||
return status.Error(codes.NotFound, err.Error())
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Stream updates to the client.
|
||||
err = s.trackPaymentStream(
|
||||
stream.Context(), subscription, noInflightUpdates, stream.Send,
|
||||
)
|
||||
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("Payment stream %v canceled", identifier)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// TrackPayments returns a stream of payment state updates.
|
||||
func (s *Server) TrackPayments(request *TrackPaymentsRequest,
|
||||
stream Router_TrackPaymentsServer) error {
|
||||
|
||||
log.Debug("TrackPayments called")
|
||||
|
||||
router := s.cfg.RouterBackend
|
||||
|
||||
// Subscribe to payments.
|
||||
subscription, err := router.Tower.SubscribeAllPayments()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Stream updates to the client.
|
||||
err = s.trackPaymentStream(
|
||||
stream.Context(), subscription, request.NoInflightUpdates,
|
||||
stream.Send,
|
||||
)
|
||||
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("TrackPayments payment stream canceled.")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// trackPaymentStream streams payment updates to the client.
|
||||
func (s *Server) trackPaymentStream(context context.Context,
|
||||
subscription routing.ControlTowerSubscriber, noInflightUpdates bool,
|
||||
send func(*lnrpc.Payment) error) error {
|
||||
|
||||
defer subscription.Close()
|
||||
|
||||
// Stream updates back to the client. The first update is always the
|
||||
// current state of the payment.
|
||||
// Stream updates back to the client.
|
||||
for {
|
||||
select {
|
||||
case item, ok := <-subscription.Updates:
|
||||
case item, ok := <-subscription.Updates():
|
||||
if !ok {
|
||||
// No more payment updates.
|
||||
return nil
|
||||
|
@ -766,13 +813,15 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
|
|||
continue
|
||||
}
|
||||
|
||||
rpcPayment, err := router.MarshallPayment(result)
|
||||
rpcPayment, err := s.cfg.RouterBackend.MarshallPayment(
|
||||
result,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send event to the client.
|
||||
err = stream.Send(rpcPayment)
|
||||
err = send(rpcPayment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -780,9 +829,8 @@ func (s *Server) trackPayment(identifier lntypes.Hash,
|
|||
case <-s.quit:
|
||||
return errServerShuttingDown
|
||||
|
||||
case <-stream.Context().Done():
|
||||
log.Debugf("Payment status stream %v canceled", identifier)
|
||||
return stream.Context().Err()
|
||||
case <-context.Done():
|
||||
return context.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
216
lnrpc/routerrpc/router_server_test.go
Normal file
216
lnrpc/routerrpc/router_server_test.go
Normal file
|
@ -0,0 +1,216 @@
|
|||
package routerrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/queue"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type streamMock struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
sentFromServer chan *lnrpc.Payment
|
||||
}
|
||||
|
||||
func makeStreamMock(ctx context.Context) *streamMock {
|
||||
return &streamMock{
|
||||
ctx: ctx,
|
||||
sentFromServer: make(chan *lnrpc.Payment, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *streamMock) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *streamMock) Send(p *lnrpc.Payment) error {
|
||||
m.sentFromServer <- p
|
||||
return nil
|
||||
}
|
||||
|
||||
type controlTowerSubscriberMock struct {
|
||||
updates <-chan interface{}
|
||||
}
|
||||
|
||||
func (s controlTowerSubscriberMock) Updates() <-chan interface{} {
|
||||
return s.updates
|
||||
}
|
||||
|
||||
func (s controlTowerSubscriberMock) Close() {
|
||||
}
|
||||
|
||||
type controlTowerMock struct {
|
||||
queue *queue.ConcurrentQueue
|
||||
routing.ControlTower
|
||||
}
|
||||
|
||||
func makeControlTowerMock() *controlTowerMock {
|
||||
towerMock := &controlTowerMock{
|
||||
queue: queue.NewConcurrentQueue(20),
|
||||
}
|
||||
towerMock.queue.Start()
|
||||
|
||||
return towerMock
|
||||
}
|
||||
|
||||
func (t *controlTowerMock) SubscribeAllPayments() (
|
||||
routing.ControlTowerSubscriber, error) {
|
||||
|
||||
return &controlTowerSubscriberMock{
|
||||
updates: t.queue.ChanOut(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestTrackPaymentsReturnsOnCancelContext tests whether TrackPayments returns
|
||||
// when the stream context is cancelled.
|
||||
func TestTrackPaymentsReturnsOnCancelContext(t *testing.T) {
|
||||
// Setup mocks and request.
|
||||
request := &TrackPaymentsRequest{
|
||||
NoInflightUpdates: false,
|
||||
}
|
||||
towerMock := makeControlTowerMock()
|
||||
|
||||
streamCtx, cancelStream := context.WithCancel(context.Background())
|
||||
stream := makeStreamMock(streamCtx)
|
||||
|
||||
server := &Server{
|
||||
cfg: &Config{
|
||||
RouterBackend: &RouterBackend{
|
||||
Tower: towerMock,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Cancel stream immediately
|
||||
cancelStream()
|
||||
|
||||
// Make sure the call returns.
|
||||
err := server.TrackPayments(request, stream)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
}
|
||||
|
||||
// TestTrackPaymentsInflightUpdate tests whether all updates from the control
|
||||
// tower are propagated to the client.
|
||||
func TestTrackPaymentsInflightUpdates(t *testing.T) {
|
||||
// Setup mocks and request.
|
||||
request := &TrackPaymentsRequest{
|
||||
NoInflightUpdates: false,
|
||||
}
|
||||
towerMock := makeControlTowerMock()
|
||||
|
||||
streamCtx, cancelStream := context.WithCancel(context.Background())
|
||||
stream := makeStreamMock(streamCtx)
|
||||
defer cancelStream()
|
||||
|
||||
server := &Server{
|
||||
cfg: &Config{
|
||||
RouterBackend: &RouterBackend{
|
||||
Tower: towerMock,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Listen to payment updates in a goroutine.
|
||||
go func() {
|
||||
err := server.TrackPayments(request, stream)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
}()
|
||||
|
||||
// Enqueue some payment updates on the mock.
|
||||
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{},
|
||||
Status: channeldb.StatusInFlight,
|
||||
}
|
||||
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{},
|
||||
Status: channeldb.StatusSucceeded,
|
||||
}
|
||||
|
||||
// Wait until there's 2 updates or the deadline is exceeded.
|
||||
deadline := time.Now().Add(1 * time.Second)
|
||||
for {
|
||||
if len(stream.sentFromServer) == 2 {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
require.FailNow(t, "deadline exceeded.")
|
||||
}
|
||||
}
|
||||
|
||||
// Both updates should be sent to the client.
|
||||
require.Len(t, stream.sentFromServer, 2)
|
||||
|
||||
// The updates should be in the right order.
|
||||
payment := <-stream.sentFromServer
|
||||
require.Equal(t, lnrpc.Payment_IN_FLIGHT, payment.Status)
|
||||
payment = <-stream.sentFromServer
|
||||
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
|
||||
}
|
||||
|
||||
// TestTrackPaymentsInflightUpdate tests whether only final updates from the
|
||||
// control tower are propagated to the client when noInflightUpdates = true.
|
||||
func TestTrackPaymentsNoInflightUpdates(t *testing.T) {
|
||||
// Setup mocks and request.
|
||||
request := &TrackPaymentsRequest{
|
||||
NoInflightUpdates: true,
|
||||
}
|
||||
towerMock := &controlTowerMock{
|
||||
queue: queue.NewConcurrentQueue(20),
|
||||
}
|
||||
towerMock.queue.Start()
|
||||
|
||||
streamCtx, cancelStream := context.WithCancel(context.Background())
|
||||
stream := makeStreamMock(streamCtx)
|
||||
defer cancelStream()
|
||||
|
||||
server := &Server{
|
||||
cfg: &Config{
|
||||
RouterBackend: &RouterBackend{
|
||||
Tower: towerMock,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Listen to payment updates in a goroutine.
|
||||
go func() {
|
||||
err := server.TrackPayments(request, stream)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
}()
|
||||
|
||||
// Enqueue some payment updates on the mock.
|
||||
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{},
|
||||
Status: channeldb.StatusInFlight,
|
||||
}
|
||||
towerMock.queue.ChanIn() <- &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{},
|
||||
Status: channeldb.StatusSucceeded,
|
||||
}
|
||||
|
||||
// Wait until there's 1 update or the deadline is exceeded.
|
||||
deadline := time.Now().Add(1 * time.Second)
|
||||
for {
|
||||
if len(stream.sentFromServer) == 1 {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
require.FailNow(t, "deadline exceeded.")
|
||||
}
|
||||
}
|
||||
|
||||
// Only 1 update should be sent to the client.
|
||||
require.Len(t, stream.sentFromServer, 1)
|
||||
|
||||
// Only the final states should be sent to the client.
|
||||
payment := <-stream.sentFromServer
|
||||
require.Equal(t, lnrpc.Payment_SUCCEEDED, payment.Status)
|
||||
}
|
|
@ -439,4 +439,8 @@ var allTestCases = []*testCase{
|
|||
name: "taproot coop close",
|
||||
test: testTaprootCoopClose,
|
||||
},
|
||||
{
|
||||
name: "trackpayments",
|
||||
test: testTrackPayments,
|
||||
},
|
||||
}
|
||||
|
|
102
lntest/itest/lnd_trackpayments_test.go
Normal file
102
lntest/itest/lnd_trackpayments_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package itest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
|
||||
"github.com/lightningnetwork/lnd/lntest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTrackPayments tests whether a client that calls the TrackPayments api
|
||||
// receives payment updates.
|
||||
func testTrackPayments(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
// Open a channel between alice and bob.
|
||||
net.EnsureConnected(t.t, net.Alice, net.Bob)
|
||||
channel := openChannelAndAssert(
|
||||
t, net, net.Alice, net.Bob,
|
||||
lntest.OpenChannelParams{
|
||||
Amt: btcutil.Amount(300000),
|
||||
},
|
||||
)
|
||||
defer closeChannelAndAssert(t, net, net.Alice, channel, true)
|
||||
|
||||
err := net.Alice.WaitForNetworkChannelOpen(channel)
|
||||
require.NoError(t.t, err, "unable to wait for channel to open")
|
||||
|
||||
ctxb := context.Background()
|
||||
ctxt, cancelTracker := context.WithCancel(ctxb)
|
||||
defer cancelTracker()
|
||||
|
||||
// Call the TrackPayments api to listen for payment updates.
|
||||
tracker, err := net.Alice.RouterClient.TrackPayments(
|
||||
ctxt,
|
||||
&routerrpc.TrackPaymentsRequest{
|
||||
NoInflightUpdates: false,
|
||||
},
|
||||
)
|
||||
require.NoError(t.t, err, "failed to call TrackPayments successfully.")
|
||||
|
||||
// Create an invoice from bob.
|
||||
var amountMsat int64 = 1000
|
||||
invoiceResp, err := net.Bob.AddInvoice(
|
||||
ctxb,
|
||||
&lnrpc.Invoice{
|
||||
ValueMsat: amountMsat,
|
||||
},
|
||||
)
|
||||
require.NoError(t.t, err, "unable to add invoice.")
|
||||
|
||||
invoice, err := net.Bob.LookupInvoice(
|
||||
ctxb,
|
||||
&lnrpc.PaymentHash{
|
||||
RHashStr: hex.EncodeToString(invoiceResp.RHash),
|
||||
},
|
||||
)
|
||||
require.NoError(t.t, err, "unable to find invoice.")
|
||||
|
||||
// Send payment from alice to bob.
|
||||
paymentClient, err := net.Alice.RouterClient.SendPaymentV2(
|
||||
ctxb,
|
||||
&routerrpc.SendPaymentRequest{
|
||||
PaymentRequest: invoice.PaymentRequest,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
)
|
||||
require.NoError(t.t, err, "unable to send payment.")
|
||||
|
||||
// Make sure the payment doesn't error due to invalid parameters or so.
|
||||
_, err = paymentClient.Recv()
|
||||
require.NoError(t.t, err, "unable to get payment update.")
|
||||
|
||||
// Assert the first payment update is an inflight update.
|
||||
update1, err := tracker.Recv()
|
||||
require.NoError(t.t, err, "unable to receive payment update 1.")
|
||||
|
||||
require.Equal(
|
||||
t.t, lnrpc.PaymentFailureReason_FAILURE_REASON_NONE,
|
||||
update1.FailureReason,
|
||||
)
|
||||
require.Equal(t.t, lnrpc.Payment_IN_FLIGHT, update1.Status)
|
||||
require.Equal(t.t, invoice.PaymentRequest, update1.PaymentRequest)
|
||||
require.Equal(t.t, amountMsat, update1.ValueMsat)
|
||||
|
||||
// Assert the second payment update is a payment success update.
|
||||
update2, err := tracker.Recv()
|
||||
require.NoError(t.t, err, "unable to receive payment update 2.")
|
||||
|
||||
require.Equal(
|
||||
t.t, lnrpc.PaymentFailureReason_FAILURE_REASON_NONE,
|
||||
update2.FailureReason,
|
||||
)
|
||||
require.Equal(t.t, lnrpc.Payment_SUCCEEDED, update2.Status)
|
||||
require.Equal(t.t, invoice.PaymentRequest, update2.PaymentRequest)
|
||||
require.Equal(t.t, amountMsat, update2.ValueMsat)
|
||||
require.Equal(
|
||||
t.t, hex.EncodeToString(invoice.RPreimage),
|
||||
update2.PaymentPreimage,
|
||||
)
|
||||
}
|
|
@ -60,35 +60,48 @@ type ControlTower interface {
|
|||
// SubscribePayment subscribes to updates for the payment with the given
|
||||
// hash. A first update with the current state of the payment is always
|
||||
// sent out immediately.
|
||||
SubscribePayment(paymentHash lntypes.Hash) (*ControlTowerSubscriber,
|
||||
SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber,
|
||||
error)
|
||||
|
||||
// SubscribeAllPayments subscribes to updates for all payments. A first
|
||||
// update with the current state of every inflight payment is always
|
||||
// sent out immediately.
|
||||
SubscribeAllPayments() (ControlTowerSubscriber, error)
|
||||
}
|
||||
|
||||
// ControlTowerSubscriber contains the state for a payment update subscriber.
|
||||
type ControlTowerSubscriber struct {
|
||||
type ControlTowerSubscriber interface {
|
||||
// Updates is the channel over which *channeldb.MPPayment updates can be
|
||||
// received.
|
||||
Updates <-chan interface{}
|
||||
Updates() <-chan interface{}
|
||||
|
||||
queue *queue.ConcurrentQueue
|
||||
quit chan struct{}
|
||||
// Close signals that the subscriber is no longer interested in updates.
|
||||
Close()
|
||||
}
|
||||
|
||||
// ControlTowerSubscriberImpl contains the state for a payment update
|
||||
// subscriber.
|
||||
type controlTowerSubscriberImpl struct {
|
||||
updates <-chan interface{}
|
||||
queue *queue.ConcurrentQueue
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// newControlTowerSubscriber instantiates a new subscriber state object.
|
||||
func newControlTowerSubscriber() *ControlTowerSubscriber {
|
||||
func newControlTowerSubscriber() *controlTowerSubscriberImpl {
|
||||
// Create a queue for payment updates.
|
||||
queue := queue.NewConcurrentQueue(20)
|
||||
queue.Start()
|
||||
|
||||
return &ControlTowerSubscriber{
|
||||
Updates: queue.ChanOut(),
|
||||
return &controlTowerSubscriberImpl{
|
||||
updates: queue.ChanOut(),
|
||||
queue: queue,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Close signals that the subscriber is no longer interested in updates.
|
||||
func (s *ControlTowerSubscriber) Close() {
|
||||
func (s *controlTowerSubscriberImpl) Close() {
|
||||
// Close quit channel so that any pending writes to the queue are
|
||||
// cancelled.
|
||||
close(s.quit)
|
||||
|
@ -97,13 +110,24 @@ func (s *ControlTowerSubscriber) Close() {
|
|||
s.queue.Stop()
|
||||
}
|
||||
|
||||
// Updates is the channel over which *channeldb.MPPayment updates can be
|
||||
// received.
|
||||
func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} {
|
||||
return s.updates
|
||||
}
|
||||
|
||||
// controlTower is persistent implementation of ControlTower to restrict
|
||||
// double payment sending.
|
||||
type controlTower struct {
|
||||
db *channeldb.PaymentControl
|
||||
|
||||
subscribers map[lntypes.Hash][]*ControlTowerSubscriber
|
||||
subscribersMtx sync.Mutex
|
||||
// subscriberIndex is used to provide a unique id for each subscriber
|
||||
// to all payments. This is used to easily remove the subscriber when
|
||||
// necessary.
|
||||
subscriberIndex uint64
|
||||
subscribersAllPayments map[uint64]*controlTowerSubscriberImpl
|
||||
subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl
|
||||
subscribersMtx sync.Mutex
|
||||
|
||||
// paymentsMtx provides synchronization on the payment level to ensure
|
||||
// that no race conditions occur in between updating the database and
|
||||
|
@ -114,8 +138,11 @@ type controlTower struct {
|
|||
// NewControlTower creates a new instance of the controlTower.
|
||||
func NewControlTower(db *channeldb.PaymentControl) ControlTower {
|
||||
return &controlTower{
|
||||
db: db,
|
||||
subscribers: make(map[lntypes.Hash][]*ControlTowerSubscriber),
|
||||
db: db,
|
||||
subscribersAllPayments: make(
|
||||
map[uint64]*controlTowerSubscriberImpl,
|
||||
),
|
||||
subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
|
||||
paymentsMtx: multimutex.NewHashMutex(),
|
||||
}
|
||||
}
|
||||
|
@ -232,7 +259,7 @@ func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
|
|||
// first update with the current state of the payment is always sent out
|
||||
// immediately.
|
||||
func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
|
||||
*ControlTowerSubscriber, error) {
|
||||
ControlTowerSubscriber, error) {
|
||||
|
||||
// Take lock before querying the db to prevent missing or duplicating an
|
||||
// update.
|
||||
|
@ -266,6 +293,39 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
|
|||
return subscriber, nil
|
||||
}
|
||||
|
||||
// SubscribeAllPayments subscribes to updates for all inflight payments. A first
|
||||
// update with the current state of every inflight payment is always sent out
|
||||
// immediately.
|
||||
// Note: If payments are in-flight while starting a new subscription, the start
|
||||
// of the payment stream could produce out-of-order and/or duplicate events. In
|
||||
// order to get updates for every in-flight payment attempt make sure to
|
||||
// subscribe to this method before initiating any payments.
|
||||
func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) {
|
||||
subscriber := newControlTowerSubscriber()
|
||||
|
||||
// Add the subscriber to the list before fetching in-flight payments, so
|
||||
// no events are missed. If a payment attempt update occurs after
|
||||
// appending and before fetching in-flight payments, an out-of-order
|
||||
// duplicate may be produced, because it is then fetched in below call
|
||||
// and notified through the subscription.
|
||||
p.subscribersMtx.Lock()
|
||||
p.subscribersAllPayments[p.subscriberIndex] = subscriber
|
||||
p.subscriberIndex++
|
||||
p.subscribersMtx.Unlock()
|
||||
|
||||
inflightPayments, err := p.db.FetchInFlightPayments()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for index := range inflightPayments {
|
||||
// Always write current payment state to the channel.
|
||||
subscriber.queue.ChanIn() <- inflightPayments[index]
|
||||
}
|
||||
|
||||
return subscriber, nil
|
||||
}
|
||||
|
||||
// notifySubscribers sends a final payment event to all subscribers of this
|
||||
// payment. The channel will be closed after this. Note that this function must
|
||||
// be executed atomically (by means of a lock) with the database update to
|
||||
|
@ -275,8 +335,9 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
|
|||
|
||||
// Get all subscribers for this payment.
|
||||
p.subscribersMtx.Lock()
|
||||
list, ok := p.subscribers[paymentHash]
|
||||
if !ok {
|
||||
|
||||
subscribersPaymentHash, ok := p.subscribers[paymentHash]
|
||||
if !ok && len(p.subscribersAllPayments) == 0 {
|
||||
p.subscribersMtx.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -287,10 +348,17 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
|
|||
if terminal {
|
||||
delete(p.subscribers, paymentHash)
|
||||
}
|
||||
|
||||
// Copy subscribers to all payments locally while holding the lock in
|
||||
// order to avoid concurrency issues while reading/writing the map.
|
||||
subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl)
|
||||
for k, v := range p.subscribersAllPayments {
|
||||
subscribersAllPayments[k] = v
|
||||
}
|
||||
p.subscribersMtx.Unlock()
|
||||
|
||||
// Notify all subscribers of the event.
|
||||
for _, subscriber := range list {
|
||||
// Notify all subscribers that subscribed to the current payment hash.
|
||||
for _, subscriber := range subscribersPaymentHash {
|
||||
select {
|
||||
case subscriber.queue.ChanIn() <- event:
|
||||
// If this event is the last, close the incoming channel
|
||||
|
@ -305,4 +373,18 @@ func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
|
|||
case <-subscriber.quit:
|
||||
}
|
||||
}
|
||||
|
||||
// Notify all subscribers that subscribed to all payments.
|
||||
for key, subscriber := range subscribersAllPayments {
|
||||
select {
|
||||
case subscriber.queue.ChanIn() <- event:
|
||||
|
||||
// If subscriber disappeared, remove it from the subscribers
|
||||
// list.
|
||||
case <-subscriber.quit:
|
||||
p.subscribersMtx.Lock()
|
||||
delete(p.subscribersAllPayments, key)
|
||||
p.subscribersMtx.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
|||
|
||||
// We expect all subscribers to now report the final outcome followed by
|
||||
// no other events.
|
||||
subscribers := []*ControlTowerSubscriber{
|
||||
subscribers := []ControlTowerSubscriber{
|
||||
subscriber1, subscriber2, subscriber3,
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
|||
var result *channeldb.MPPayment
|
||||
for result == nil || result.Status == channeldb.StatusInFlight {
|
||||
select {
|
||||
case item := <-s.Updates:
|
||||
case item := <-s.Updates():
|
||||
result = item.(*channeldb.MPPayment)
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("timeout waiting for payment result")
|
||||
|
@ -149,7 +149,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
|||
|
||||
// After the final event, we expect the channel to be closed.
|
||||
select {
|
||||
case _, ok := <-s.Updates:
|
||||
case _, ok := <-s.Updates():
|
||||
if ok {
|
||||
t.Fatal("expected channel to be closed")
|
||||
}
|
||||
|
@ -178,6 +178,236 @@ func TestPaymentControlSubscribeFail(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// TestPaymentControlSubscribeAllSuccess tests that multiple payments are
|
||||
// properly sent to subscribers of TrackPayments.
|
||||
func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB(t, true)
|
||||
require.NoError(t, err, "unable to init db: %v")
|
||||
|
||||
pControl := NewControlTower(channeldb.NewPaymentControl(db))
|
||||
|
||||
// Initiate a payment.
|
||||
info1, attempt1, preimg1, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pControl.InitPayment(info1.PaymentIdentifier, info1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscription should succeed and immediately report the InFlight
|
||||
// status.
|
||||
subscription, err := pControl.SubscribeAllPayments()
|
||||
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
|
||||
|
||||
// Register an attempt.
|
||||
err = pControl.RegisterAttempt(info1.PaymentIdentifier, attempt1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initiate a second payment after the subscription is already active.
|
||||
info2, attempt2, preimg2, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pControl.InitPayment(info2.PaymentIdentifier, info2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register an attempt on the second payment.
|
||||
err = pControl.RegisterAttempt(info2.PaymentIdentifier, attempt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark the first payment as successful.
|
||||
settleInfo1 := channeldb.HTLCSettleInfo{
|
||||
Preimage: preimg1,
|
||||
}
|
||||
htlcAttempt1, err := pControl.SettleAttempt(
|
||||
info1.PaymentIdentifier, attempt1.AttemptID, &settleInfo1,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(
|
||||
t, settleInfo1, *htlcAttempt1.Settle,
|
||||
"unexpected settle info returned",
|
||||
)
|
||||
|
||||
// Mark the second payment as successful.
|
||||
settleInfo2 := channeldb.HTLCSettleInfo{
|
||||
Preimage: preimg2,
|
||||
}
|
||||
htlcAttempt2, err := pControl.SettleAttempt(
|
||||
info2.PaymentIdentifier, attempt2.AttemptID, &settleInfo2,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(
|
||||
t, settleInfo2, *htlcAttempt2.Settle,
|
||||
"unexpected fail info returned",
|
||||
)
|
||||
|
||||
// The two payments will be asserted individually, store the last update
|
||||
// for each payment.
|
||||
results := make(map[lntypes.Hash]*channeldb.MPPayment)
|
||||
|
||||
// After exactly 5 updates both payments will/should have completed.
|
||||
for i := 0; i < 5; i++ {
|
||||
select {
|
||||
case item := <-subscription.Updates():
|
||||
id := item.(*channeldb.MPPayment).Info.PaymentIdentifier
|
||||
results[id] = item.(*channeldb.MPPayment)
|
||||
case <-time.After(testTimeout):
|
||||
require.Fail(t, "timeout waiting for payment result")
|
||||
}
|
||||
}
|
||||
|
||||
result1 := results[info1.PaymentIdentifier]
|
||||
require.Equal(
|
||||
t, channeldb.StatusSucceeded, result1.Status,
|
||||
"unexpected payment state payment 1",
|
||||
)
|
||||
|
||||
settle1, _ := result1.TerminalInfo()
|
||||
require.Equal(
|
||||
t, preimg1, settle1.Preimage, "unexpected preimage payment 1",
|
||||
)
|
||||
|
||||
require.Len(
|
||||
t, result1.HTLCs, 1, "expect 1 htlc for payment 1, got %d",
|
||||
len(result1.HTLCs),
|
||||
)
|
||||
|
||||
htlc1 := result1.HTLCs[0]
|
||||
require.Equal(t, attempt1.Route, htlc1.Route, "unexpected htlc route.")
|
||||
|
||||
result2 := results[info2.PaymentIdentifier]
|
||||
require.Equal(
|
||||
t, channeldb.StatusSucceeded, result2.Status,
|
||||
"unexpected payment state payment 2",
|
||||
)
|
||||
|
||||
settle2, _ := result2.TerminalInfo()
|
||||
require.Equal(
|
||||
t, preimg2, settle2.Preimage, "unexpected preimage payment 2",
|
||||
)
|
||||
require.Len(
|
||||
t, result2.HTLCs, 1, "expect 1 htlc for payment 2, got %d",
|
||||
len(result2.HTLCs),
|
||||
)
|
||||
|
||||
htlc2 := result2.HTLCs[0]
|
||||
require.Equal(t, attempt2.Route, htlc2.Route, "unexpected htlc route.")
|
||||
}
|
||||
|
||||
// TestPaymentControlSubscribeAllImmediate tests whether already inflight
|
||||
// payments are reported at the start of the SubscribeAllPayments subscription.
|
||||
func TestPaymentControlSubscribeAllImmediate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB(t, true)
|
||||
require.NoError(t, err, "unable to init db: %v")
|
||||
|
||||
pControl := NewControlTower(channeldb.NewPaymentControl(db))
|
||||
|
||||
// Initiate a payment.
|
||||
info, attempt, _, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pControl.InitPayment(info.PaymentIdentifier, info)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register a payment update.
|
||||
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
|
||||
require.NoError(t, err)
|
||||
|
||||
subscription, err := pControl.SubscribeAllPayments()
|
||||
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
|
||||
|
||||
// Assert the new subscription receives the old update.
|
||||
select {
|
||||
case update := <-subscription.Updates():
|
||||
require.NotNil(t, update)
|
||||
require.Equal(
|
||||
t, info.PaymentIdentifier,
|
||||
update.(*channeldb.MPPayment).Info.PaymentIdentifier,
|
||||
)
|
||||
require.Len(t, subscription.Updates(), 0)
|
||||
case <-time.After(testTimeout):
|
||||
require.Fail(t, "timeout waiting for payment result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaymentControlUnsubscribeSuccess tests that when unsubscribed, there are
|
||||
// no more notifications to that specific subscription.
|
||||
func TestPaymentControlUnsubscribeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB(t, true)
|
||||
require.NoError(t, err, "unable to init db: %v")
|
||||
|
||||
pControl := NewControlTower(channeldb.NewPaymentControl(db))
|
||||
|
||||
subscription1, err := pControl.SubscribeAllPayments()
|
||||
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
|
||||
|
||||
subscription2, err := pControl.SubscribeAllPayments()
|
||||
require.NoError(t, err, "expected subscribe to succeed, but got: %v")
|
||||
|
||||
// Initiate a payment.
|
||||
info, attempt, _, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pControl.InitPayment(info.PaymentIdentifier, info)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register a payment update.
|
||||
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert all subscriptions receive the update.
|
||||
select {
|
||||
case update1 := <-subscription1.Updates():
|
||||
require.NotNil(t, update1)
|
||||
case <-time.After(testTimeout):
|
||||
require.Fail(t, "timeout waiting for payment result")
|
||||
}
|
||||
|
||||
select {
|
||||
case update2 := <-subscription2.Updates():
|
||||
require.NotNil(t, update2)
|
||||
case <-time.After(testTimeout):
|
||||
require.Fail(t, "timeout waiting for payment result")
|
||||
}
|
||||
|
||||
// Close the first subscription.
|
||||
subscription1.Close()
|
||||
|
||||
// Register another update.
|
||||
failInfo := channeldb.HTLCFailInfo{
|
||||
Reason: channeldb.HTLCFailInternal,
|
||||
}
|
||||
_, err = pControl.FailAttempt(
|
||||
info.PaymentIdentifier, attempt.AttemptID, &failInfo,
|
||||
)
|
||||
require.NoError(t, err, "unable to fail htlc")
|
||||
|
||||
// Assert only subscription 2 receives the update.
|
||||
select {
|
||||
case update2 := <-subscription2.Updates():
|
||||
require.NotNil(t, update2)
|
||||
case <-time.After(testTimeout):
|
||||
require.Fail(t, "timeout waiting for payment result")
|
||||
}
|
||||
|
||||
require.Len(t, subscription1.Updates(), 0)
|
||||
|
||||
// Close the second subscription.
|
||||
subscription2.Close()
|
||||
|
||||
// Register a last update.
|
||||
err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert no subscriptions receive the update.
|
||||
require.Len(t, subscription1.Updates(), 0)
|
||||
require.Len(t, subscription2.Updates(), 0)
|
||||
}
|
||||
|
||||
func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
||||
keepFailedPaymentAttempts bool) {
|
||||
|
||||
|
@ -237,7 +467,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
|||
|
||||
// We expect both subscribers to now report the final outcome followed
|
||||
// by no other events.
|
||||
subscribers := []*ControlTowerSubscriber{
|
||||
subscribers := []ControlTowerSubscriber{
|
||||
subscriber1, subscriber2,
|
||||
}
|
||||
|
||||
|
@ -245,7 +475,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
|||
var result *channeldb.MPPayment
|
||||
for result == nil || result.Status == channeldb.StatusInFlight {
|
||||
select {
|
||||
case item := <-s.Updates:
|
||||
case item := <-s.Updates():
|
||||
result = item.(*channeldb.MPPayment)
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("timeout waiting for payment result")
|
||||
|
@ -283,7 +513,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
|||
|
||||
// After the final event, we expect the channel to be closed.
|
||||
select {
|
||||
case _, ok := <-s.Updates:
|
||||
case _, ok := <-s.Updates():
|
||||
if ok {
|
||||
t.Fatal("expected channel to be closed")
|
||||
}
|
||||
|
|
|
@ -552,7 +552,13 @@ func (m *mockControlTowerOld) FetchInFlightPayments() (
|
|||
}
|
||||
|
||||
func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
|
||||
*ControlTowerSubscriber, error) {
|
||||
ControlTowerSubscriber, error) {
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockControlTowerOld) SubscribeAllPayments() (
|
||||
ControlTowerSubscriber, error) {
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -768,10 +774,17 @@ func (m *mockControlTower) FetchInFlightPayments() (
|
|||
}
|
||||
|
||||
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
|
||||
*ControlTowerSubscriber, error) {
|
||||
ControlTowerSubscriber, error) {
|
||||
|
||||
args := m.Called(paymentHash)
|
||||
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
|
||||
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) SubscribeAllPayments() (
|
||||
ControlTowerSubscriber, error) {
|
||||
|
||||
args := m.Called()
|
||||
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
|
||||
}
|
||||
|
||||
type mockLink struct {
|
||||
|
|
Loading…
Add table
Reference in a new issue