mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
108 lines
2.8 KiB
Go
108 lines
2.8 KiB
Go
|
package invoices
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
defaultTimeout = 50 * time.Millisecond
|
||
|
)
|
||
|
|
||
|
// TestHtlcModificationInterceptor tests the basic functionality of the HTLC
|
||
|
// modification interceptor.
|
||
|
func TestHtlcModificationInterceptor(t *testing.T) {
|
||
|
interceptor := NewHtlcModificationInterceptor()
|
||
|
request := HtlcModifyRequest{
|
||
|
WireCustomRecords: lnwire.CustomRecords{
|
||
|
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
|
||
|
},
|
||
|
ExitHtlcCircuitKey: CircuitKey{
|
||
|
ChanID: lnwire.NewShortChanIDFromInt(1),
|
||
|
HtlcID: 1,
|
||
|
},
|
||
|
ExitHtlcAmt: 1234,
|
||
|
}
|
||
|
expectedResponse := HtlcModifyResponse{
|
||
|
AmountPaid: 345,
|
||
|
}
|
||
|
interceptCallbackCalled := make(chan HtlcModifyRequest, 1)
|
||
|
successInterceptCallback := func(
|
||
|
req HtlcModifyRequest) (*HtlcModifyResponse, error) {
|
||
|
|
||
|
interceptCallbackCalled <- req
|
||
|
|
||
|
return &expectedResponse, nil
|
||
|
}
|
||
|
errorInterceptCallback := func(
|
||
|
req HtlcModifyRequest) (*HtlcModifyResponse, error) {
|
||
|
|
||
|
interceptCallbackCalled <- req
|
||
|
|
||
|
return nil, fmt.Errorf("something went wrong")
|
||
|
}
|
||
|
responseCallbackCalled := make(chan HtlcModifyResponse, 1)
|
||
|
responseCallback := func(resp HtlcModifyResponse) {
|
||
|
responseCallbackCalled <- resp
|
||
|
}
|
||
|
|
||
|
// Create a session without setting a callback first.
|
||
|
err := interceptor.Intercept(request, responseCallback)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// Set the callback and create a new session.
|
||
|
done, _, err := interceptor.RegisterInterceptor(
|
||
|
successInterceptCallback,
|
||
|
)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = interceptor.Intercept(request, responseCallback)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// The intercept callback should be called now.
|
||
|
select {
|
||
|
case req := <-interceptCallbackCalled:
|
||
|
require.Equal(t, request, req)
|
||
|
|
||
|
case <-time.After(defaultTimeout):
|
||
|
t.Fatal("intercept callback not called")
|
||
|
}
|
||
|
|
||
|
// And the result should make it back to the response callback.
|
||
|
select {
|
||
|
case resp := <-responseCallbackCalled:
|
||
|
require.Equal(t, expectedResponse, resp)
|
||
|
|
||
|
case <-time.After(defaultTimeout):
|
||
|
t.Fatal("response callback not called")
|
||
|
}
|
||
|
|
||
|
// If we try to set a new callback without first returning the previous
|
||
|
// one, we should get an error.
|
||
|
_, _, err = interceptor.RegisterInterceptor(successInterceptCallback)
|
||
|
require.ErrorIs(t, err, ErrInterceptorClientAlreadyConnected)
|
||
|
|
||
|
// Reset the callback, then try to set a new one.
|
||
|
done()
|
||
|
done2, _, err := interceptor.RegisterInterceptor(errorInterceptCallback)
|
||
|
require.NoError(t, err)
|
||
|
defer done2()
|
||
|
|
||
|
// We should now get an error when intercepting.
|
||
|
err = interceptor.Intercept(request, responseCallback)
|
||
|
require.ErrorContains(t, err, "something went wrong")
|
||
|
|
||
|
// The success callback should not be called.
|
||
|
select {
|
||
|
case resp := <-responseCallbackCalled:
|
||
|
t.Fatalf("unexpected response: %v", resp)
|
||
|
|
||
|
case <-time.After(defaultTimeout):
|
||
|
// Expected.
|
||
|
}
|
||
|
}
|