lnd/invoices/modification_interceptor_test.go

108 lines
2.8 KiB
Go
Raw Permalink Normal View History

package invoices
import (
"fmt"
"testing"
"time"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
// defaultTimeout bounds how long a test waits on a channel before deciding
// whether an expected event did or did not happen.
var defaultTimeout = 50 * time.Millisecond
// TestHtlcModificationInterceptor walks an HtlcModificationInterceptor through
// its client lifecycle: intercepting with no client registered, intercepting
// with a client that answers successfully, attempting a second registration
// while a client is still connected, and finally intercepting with a
// replacement client that returns an error.
func TestHtlcModificationInterceptor(t *testing.T) {
	htlcInterceptor := NewHtlcModificationInterceptor()

	modifyReq := HtlcModifyRequest{
		WireCustomRecords: lnwire.CustomRecords{
			lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
		},
		ExitHtlcCircuitKey: CircuitKey{
			ChanID: lnwire.NewShortChanIDFromInt(1),
			HtlcID: 1,
		},
		ExitHtlcAmt: 1234,
	}
	wantResp := HtlcModifyResponse{AmountPaid: 345}

	// A buffer of one lets a callback record its invocation without
	// waiting for the test to read from the channel.
	seenRequests := make(chan HtlcModifyRequest, 1)
	seenResponses := make(chan HtlcModifyResponse, 1)

	// succeedingClient records the request it was handed and answers with
	// the canned response.
	succeedingClient := func(
		in HtlcModifyRequest) (*HtlcModifyResponse, error) {

		seenRequests <- in

		return &wantResp, nil
	}

	// failingClient records the request it was handed and then fails.
	failingClient := func(
		in HtlcModifyRequest) (*HtlcModifyResponse, error) {

		seenRequests <- in

		return nil, fmt.Errorf("something went wrong")
	}

	// onResponse records whatever response the interceptor hands back.
	onResponse := func(out HtlcModifyResponse) {
		seenResponses <- out
	}

	// With no client registered yet, intercepting is expected to succeed.
	require.NoError(t, htlcInterceptor.Intercept(modifyReq, onResponse))

	// Register the succeeding client and intercept again.
	disconnectFirst, _, err := htlcInterceptor.RegisterInterceptor(
		succeedingClient,
	)
	require.NoError(t, err)
	require.NoError(t, htlcInterceptor.Intercept(modifyReq, onResponse))

	// This time the client must have been handed the unmodified request.
	select {
	case <-time.After(defaultTimeout):
		t.Fatal("intercept callback not called")

	case got := <-seenRequests:
		require.Equal(t, modifyReq, got)
	}

	// The client's answer must then arrive at the response callback.
	select {
	case <-time.After(defaultTimeout):
		t.Fatal("response callback not called")

	case got := <-seenResponses:
		require.Equal(t, wantResp, got)
	}

	// Registering a second client while the first one is still connected
	// must be rejected.
	_, _, err = htlcInterceptor.RegisterInterceptor(succeedingClient)
	require.ErrorIs(t, err, ErrInterceptorClientAlreadyConnected)

	// Once the first client is released, a new one can take its place.
	disconnectFirst()
	disconnectSecond, _, err := htlcInterceptor.RegisterInterceptor(
		failingClient,
	)
	require.NoError(t, err)
	defer disconnectSecond()

	// The failing client's error must surface from Intercept.
	require.ErrorContains(
		t, htlcInterceptor.Intercept(modifyReq, onResponse),
		"something went wrong",
	)

	// And because the client failed, no response may reach the response
	// callback.
	select {
	case got := <-seenResponses:
		t.Fatalf("unexpected response: %v", got)

	case <-time.After(defaultTimeout):
		// Expected.
	}
}