diff --git a/invoices/interface.go b/invoices/interface.go index f48aa37b6..4f5e08e32 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -207,3 +207,71 @@ type InvoiceUpdater interface { // Finalize finalizes the update before it is written to the database. Finalize(updateType UpdateType) error } + +// HtlcModifyRequest is the request that is passed to the client via callback +// during a HTLC interceptor session. The request contains the invoice that the +// given HTLC is attempting to settle. +type HtlcModifyRequest struct { + // WireCustomRecords are the custom records that were parsed from the + // HTLC wire message. These are the records of the current HTLC to be + // accepted/settled. All previously accepted/settled HTLCs for the same + // invoice are present in the Invoice field below. + WireCustomRecords lnwire.CustomRecords + + // ExitHtlcCircuitKey is the circuit key that identifies the HTLC which + // is involved in the invoice settlement. + ExitHtlcCircuitKey CircuitKey + + // ExitHtlcAmt is the amount of the HTLC which is involved in the + // invoice settlement. + ExitHtlcAmt lnwire.MilliSatoshi + + // ExitHtlcExpiry is the absolute expiry height of the HTLC which is + // involved in the invoice settlement. + ExitHtlcExpiry uint32 + + // CurrentHeight is the current block height. + CurrentHeight uint32 + + // Invoice is the invoice that is being intercepted. The HTLCs within + // the invoice are only those previously accepted/settled for the same + // invoice. + Invoice Invoice +} + +// HtlcModifyResponse is the response that the client should send back to the +// interceptor after processing the HTLC modify request. +type HtlcModifyResponse struct { + // AmountPaid is the amount that the client has decided the HTLC is + // actually worth. This might be different from the amount that the + // HTLC was originally sent with, in case additional value is carried + // along with it (which might be the case in custom channels). + AmountPaid lnwire.MilliSatoshi +} + +// HtlcModifyCallback is a function that is called when an invoice is +// intercepted by the invoice interceptor. +type HtlcModifyCallback func(HtlcModifyRequest) (*HtlcModifyResponse, error) + +// HtlcModifier is an interface that allows an intercept client to register +// itself as a modifier of HTLCs that are settling an invoice. The client can +// then modify the HTLCs based on the invoice and the HTLC that is settling it. +type HtlcModifier interface { + // RegisterInterceptor sets the client callback function that will be + // called when an invoice is intercepted. If a callback is already set, + // an error is returned. The returned function must be used to reset the + // callback to nil once the client is done or disconnects. The read-only + // channel closes when the server stops. + RegisterInterceptor(HtlcModifyCallback) (func(), <-chan struct{}, error) +} + +// HtlcInterceptor is an interface that allows the invoice registry to let +// clients intercept invoices before they are settled. +type HtlcInterceptor interface { + // Intercept generates a new intercept session for the given invoice. + // The call blocks until the client has responded to the request or an + // error occurs. The response callback is only called if a session was + // created in the first place, which is only the case if a client is + // registered. + Intercept(HtlcModifyRequest, func(HtlcModifyResponse)) error +} diff --git a/invoices/mock.go b/invoices/mock.go index 25c81a35d..5d929c227 100644 --- a/invoices/mock.go +++ b/invoices/mock.go @@ -83,3 +83,34 @@ func (m *MockInvoiceDB) DeleteCanceledInvoices(ctx context.Context) error { return args.Error(0) } + +// MockHtlcModifier is a mock implementation of the HtlcModifier interface. +type MockHtlcModifier struct { +} + +// Intercept generates a new intercept session for the given invoice. +// The call blocks until the client has responded to the request or an +// error occurs. The response callback is only called if a session was +// created in the first place, which is only the case if a client is +// registered. +func (m *MockHtlcModifier) Intercept( + _ HtlcModifyRequest, _ func(HtlcModifyResponse)) error { + + return nil +} + +// RegisterInterceptor sets the client callback function that will be +// called when an invoice is intercepted. If a callback is already set, +// an error is returned. The returned function must be used to reset the +// callback to nil once the client is done or disconnects. The read-only channel +// closes when the server stops. +func (m *MockHtlcModifier) RegisterInterceptor(HtlcModifyCallback) (func(), + <-chan struct{}, error) { + + return func() {}, make(chan struct{}), nil +} + +// Ensure that MockHtlcModifier implements the HtlcInterceptor and HtlcModifier +// interfaces. +var _ HtlcInterceptor = (*MockHtlcModifier)(nil) +var _ HtlcModifier = (*MockHtlcModifier)(nil) diff --git a/invoices/modification_interceptor.go b/invoices/modification_interceptor.go new file mode 100644 index 000000000..e2d9a9051 --- /dev/null +++ b/invoices/modification_interceptor.go @@ -0,0 +1,179 @@ +package invoices + +import ( + "errors" + "sync/atomic" + + "github.com/lightningnetwork/lnd/fn" +) + +var ( + // ErrInterceptorClientAlreadyConnected is an error that is returned + // when a client tries to connect to the interceptor service while + // another client is already connected. + ErrInterceptorClientAlreadyConnected = errors.New( + "interceptor client already connected", + ) + + // ErrInterceptorClientDisconnected is an error that is returned when + // the client disconnects during an interceptor session. + ErrInterceptorClientDisconnected = errors.New( + "interceptor client disconnected", + ) +) + +// safeCallback is a wrapper around a callback function that is safe for +// concurrent access. +type safeCallback struct { + // callback is the actual callback function that is called when an + // invoice is intercepted. This might be nil if no client is currently + // connected. + callback atomic.Pointer[HtlcModifyCallback] +} + +// Set atomically sets the callback function. If a callback is already set, an +// error is returned. The returned function can be used to reset the callback to +// nil once the client is done. +func (s *safeCallback) Set(callback HtlcModifyCallback) (func(), error) { + if !s.callback.CompareAndSwap(nil, &callback) { + return nil, ErrInterceptorClientAlreadyConnected + } + + return func() { + s.callback.Store(nil) + }, nil +} + +// IsConnected returns true if a client is currently connected. +func (s *safeCallback) IsConnected() bool { + return s.callback.Load() != nil +} + +// Exec executes the callback function if it is set. If the callback is not set, +// an error is returned. +func (s *safeCallback) Exec(req HtlcModifyRequest) (*HtlcModifyResponse, + error) { + + callback := s.callback.Load() + if callback == nil { + return nil, ErrInterceptorClientDisconnected + } + + return (*callback)(req) +} + +// HtlcModificationInterceptor is a service that intercepts HTLCs that aim to +// settle an invoice, enabling a subscribed client to modify certain aspects of +// those HTLCs. +type HtlcModificationInterceptor struct { + // callback is the wrapped client callback function that is called when + // an invoice is intercepted. This function gives the client the ability + // to determine how the invoice should be settled. + callback *safeCallback + + // quit is a channel that is closed when the interceptor is stopped. + quit chan struct{} +} + +// NewHtlcModificationInterceptor creates a new HtlcModificationInterceptor. +func NewHtlcModificationInterceptor() *HtlcModificationInterceptor { + return &HtlcModificationInterceptor{ + callback: &safeCallback{}, + } +} + +// Intercept generates a new intercept session for the given invoice. The call +// blocks until the client has responded to the request or an error occurs. The +// response callback is only called if a session was created in the first place, +// which is only the case if a client is registered. +func (s *HtlcModificationInterceptor) Intercept(clientRequest HtlcModifyRequest, + responseCallback func(HtlcModifyResponse)) error { + + // If there is no client callback set we will not handle the invoice + // further. + if !s.callback.IsConnected() { + log.Debugf("Not intercepting invoice with circuit key %v, no "+ + "intercept client connected", + clientRequest.ExitHtlcCircuitKey) + + return nil + } + + // We'll block until the client has responded to the request or an error + // occurs. + var ( + responseChan = make(chan *HtlcModifyResponse, 1) + errChan = make(chan error, 1) + ) + + // The callback function will block at the client's discretion. We will + // therefore execute it in a separate goroutine. We don't need a wait + // group because we wait for the response directly below. The caller + // needs to make sure they don't block indefinitely, by selecting on the + // quit channel they receive when registering the callback. + go func() { + log.Debugf("Waiting for client response from invoice HTLC "+ + "interceptor session with circuit key %v", + clientRequest.ExitHtlcCircuitKey) + + // By this point, we've already checked that the client callback + // is set. However, if the client disconnected since that check + // then Exec will return an error. + result, err := s.callback.Exec(clientRequest) + if err != nil { + _ = fn.SendOrQuit(errChan, err, s.quit) + + return + } + + _ = fn.SendOrQuit(responseChan, result, s.quit) + }() + + // Wait for the client to respond or an error to occur. + select { + case response := <-responseChan: + log.Debugf("Received invoice HTLC interceptor response: %v", + response) + + responseCallback(*response) + + return nil + + case err := <-errChan: + log.Errorf("Error from invoice HTLC interceptor session: %v", + err) + + return err + + case <-s.quit: + return ErrInterceptorClientDisconnected + } +} + +// RegisterInterceptor sets the client callback function that will be called +// when an invoice is intercepted. If a callback is already set, an error is +// returned. The returned function must be used to reset the callback to nil +// once the client is done or disconnects. +func (s *HtlcModificationInterceptor) RegisterInterceptor( + callback HtlcModifyCallback) (func(), <-chan struct{}, error) { + + done, err := s.callback.Set(callback) + return done, s.quit, err +} + +// Start starts the service. +func (s *HtlcModificationInterceptor) Start() error { + return nil +} + +// Stop stops the service. +func (s *HtlcModificationInterceptor) Stop() error { + close(s.quit) + + return nil +} + +// Ensure that HtlcModificationInterceptor implements the HtlcInterceptor and +// HtlcModifier interfaces. +var _ HtlcInterceptor = (*HtlcModificationInterceptor)(nil) +var _ HtlcModifier = (*HtlcModificationInterceptor)(nil) diff --git a/invoices/modification_interceptor_test.go b/invoices/modification_interceptor_test.go new file mode 100644 index 000000000..286a390ff --- /dev/null +++ b/invoices/modification_interceptor_test.go @@ -0,0 +1,107 @@ +package invoices + +import ( + "fmt" + "testing" + "time" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +var ( + defaultTimeout = 50 * time.Millisecond +) + +// TestHtlcModificationInterceptor tests the basic functionality of the HTLC +// modification interceptor. +func TestHtlcModificationInterceptor(t *testing.T) { + interceptor := NewHtlcModificationInterceptor() + request := HtlcModifyRequest{ + WireCustomRecords: lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, + }, + ExitHtlcCircuitKey: CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt(1), + HtlcID: 1, + }, + ExitHtlcAmt: 1234, + } + expectedResponse := HtlcModifyResponse{ + AmountPaid: 345, + } + interceptCallbackCalled := make(chan HtlcModifyRequest, 1) + successInterceptCallback := func( + req HtlcModifyRequest) (*HtlcModifyResponse, error) { + + interceptCallbackCalled <- req + + return &expectedResponse, nil + } + errorInterceptCallback := func( + req HtlcModifyRequest) (*HtlcModifyResponse, error) { + + interceptCallbackCalled <- req + + return nil, fmt.Errorf("something went wrong") + } + responseCallbackCalled := make(chan HtlcModifyResponse, 1) + responseCallback := func(resp HtlcModifyResponse) { + responseCallbackCalled <- resp + } + + // Create a session without setting a callback first. + err := interceptor.Intercept(request, responseCallback) + require.NoError(t, err) + + // Set the callback and create a new session. + done, _, err := interceptor.RegisterInterceptor( + successInterceptCallback, + ) + require.NoError(t, err) + + err = interceptor.Intercept(request, responseCallback) + require.NoError(t, err) + + // The intercept callback should be called now. + select { + case req := <-interceptCallbackCalled: + require.Equal(t, request, req) + + case <-time.After(defaultTimeout): + t.Fatal("intercept callback not called") + } + + // And the result should make it back to the response callback. + select { + case resp := <-responseCallbackCalled: + require.Equal(t, expectedResponse, resp) + + case <-time.After(defaultTimeout): + t.Fatal("response callback not called") + } + + // If we try to set a new callback without first returning the previous + // one, we should get an error. + _, _, err = interceptor.RegisterInterceptor(successInterceptCallback) + require.ErrorIs(t, err, ErrInterceptorClientAlreadyConnected) + + // Reset the callback, then try to set a new one. + done() + done2, _, err := interceptor.RegisterInterceptor(errorInterceptCallback) + require.NoError(t, err) + defer done2() + + // We should now get an error when intercepting. + err = interceptor.Intercept(request, responseCallback) + require.ErrorContains(t, err, "something went wrong") + + // The success callback should not be called. + select { + case resp := <-responseCallbackCalled: + t.Fatalf("unexpected response: %v", resp) + + case <-time.After(defaultTimeout): + // Expected. + } +}