lnd/htlcswitch/hop/iterator_test.go
Elle Mouton ad0905f10e
record+htlcswitch: convert BlindedRouteData fields to optional
For the final hop in a blinded route, the SCID and RelayInfo fields will
_not_ be set. So these fields need to be converted to optional records.

The existing BlindedRouteData constructor is also renamed to
`NewNonFinalBlindedRouteData` in preparation for a
`NewFinalBlindedRouteData` constructor which will be used to construct
the blinded data for the final hop which will contain a much smaller set
of data. The SCID and RelayInfo parameters of the constructor are left
as non-pointers in order to force the caller to set them in the case
that the constructor is called for non-final nodes. The other option
would be to create a single constructor where all parameters are
optional but I think this makes it easier for the caller to make a
mistake.
2024-07-10 09:12:39 +02:00

295 lines
8 KiB
Go

package hop
import (
"bytes"
"encoding/binary"
"errors"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// TestSphinxHopIteratorForwardingInstructions tests that we're able to
// properly decode an onion payload, no matter the payload type, into the
// original set of forwarding instructions.
func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
	t.Parallel()

	// First, we'll make the hop data that the sender would create to send
	// an HTLC through our imaginary route.
	hopData := sphinx.HopData{
		ForwardAmount: 100000,
		OutgoingCltv:  4343,
	}
	copy(hopData.NextAddress[:], bytes.Repeat([]byte("a"), 8))

	// Next, we'll make the hop forwarding information that we expect to
	// extract in every test case, no matter the payload type.
	nextAddrInt := binary.BigEndian.Uint64(hopData.NextAddress[:])
	expectedFwdInfo := ForwardingInfo{
		NextHop:         lnwire.NewShortChanIDFromInt(nextAddrInt),
		AmountToForward: lnwire.MilliSatoshi(hopData.ForwardAmount),
		OutgoingCTLV:    hopData.OutgoingCltv,
	}

	// For our TLV payload, we'll serialize the hop into a TLV stream as
	// we would normally in the routing network.
	var b bytes.Buffer
	tlvRecords := []tlv.Record{
		record.NewAmtToFwdRecord(&hopData.ForwardAmount),
		record.NewLockTimeRecord(&hopData.OutgoingCltv),
		record.NewNextHopIDRecord(&nextAddrInt),
	}

	tlvStream, err := tlv.NewStream(tlvRecords...)
	require.NoError(t, err, "unable to create stream")
	require.NoError(t, tlvStream.Encode(&b), "unable to encode stream")

	testCases := []struct {
		sphinxPacket    *sphinx.ProcessedPacket
		expectedFwdInfo ForwardingInfo
	}{
		// A regular legacy payload that signals more hops.
		{
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type: sphinx.PayloadLegacy,
				},
				Action:                 sphinx.MoreHops,
				ForwardingInstructions: &hopData,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
		// A TLV payload, which includes the sphinx action as
		// cid may be zero for blinded routes (thus we require the
		// action to signal whether we are at the final hop).
		{
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type:    sphinx.PayloadTLV,
					Payload: b.Bytes(),
				},
				Action: sphinx.MoreHops,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
	}

	// Finally, we'll test that we get the same set of
	// ForwardingInstructions for each payload type. A single iterator is
	// reused, with its processed packet swapped out for every case.
	iterator := sphinxHopIterator{}
	for i, testCase := range testCases {
		iterator.processedPacket = testCase.sphinxPacket

		pld, _, err := iterator.HopPayload()
		require.NoErrorf(
			t, err, "#%v: unable to extract forwarding "+
				"instructions", i,
		)

		// The comparison is kept as a plain struct equality check so
		// that a mismatch is reported with a full spew dump of both
		// values.
		fwdInfo := pld.ForwardingInfo()
		if fwdInfo != testCase.expectedFwdInfo {
			t.Fatalf("#%v: wrong fwding info: expected %v, got %v",
				i, spew.Sdump(testCase.expectedFwdInfo),
				spew.Sdump(fwdInfo))
		}
	}
}
// TestForwardingAmountCalc runs calculateForwardingAmount against a table of
// incoming amounts and fee parameters, asserting both whether an error is
// returned and the resulting forwarding amount.
func TestForwardingAmountCalc(t *testing.T) {
	t.Parallel()

	type amountCase struct {
		name           string
		incomingAmount lnwire.MilliSatoshi
		baseFee        uint32
		proportional   uint32
		forwardAmount  lnwire.MilliSatoshi
		expectErr      bool
	}

	cases := []amountCase{
		// The base fee is larger than the incoming amount, so an
		// error is expected.
		{
			name:           "overflow",
			incomingAmount: 10,
			baseFee:        100,
			expectErr:      true,
		},
		{
			name:           "trivial proportional",
			incomingAmount: 100_000,
			baseFee:        1000,
			proportional:   10,
			forwardAmount:  99000,
		},
		{
			name:           "both fees charged",
			incomingAmount: 10_002_020,
			baseFee:        1000,
			proportional:   1,
			forwardAmount:  10_001_010,
		},
	}

	for _, tc := range cases {
		// Take a per-iteration copy, since the parallel subtest
		// closure runs after the loop has moved on.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got, err := calculateForwardingAmount(
				tc.incomingAmount, tc.baseFee, tc.proportional,
			)

			gotErr := err != nil
			require.Equal(t, tc.expectErr, gotErr)

			// Amounts are compared after conversion to satoshis.
			wantSat := tc.forwardAmount.ToSatoshis()
			require.Equal(t, wantSat, got.ToSatoshis())
		})
	}
}
// mockProcessor is a test double for a blinding point processor. Instead of
// performing any real decryption, it hands back the data it was given.
type mockProcessor struct {
	// decryptErr is the error returned from DecryptBlindedHopData. It is
	// nil unless a test configures it.
	decryptErr error
}
// DecryptBlindedHopData is a mocked blob decryption. The key argument is
// ignored and the input data is echoed back unchanged, together with the
// error configured on the mock (if any).
func (m *mockProcessor) DecryptBlindedHopData(_ *btcec.PublicKey,
	data []byte) ([]byte, error) {

	if m.decryptErr != nil {
		return data, m.decryptErr
	}

	return data, nil
}
// NextEphemeral is a stub for fetching our next ephemeral key. It always
// returns a nil key and a nil error.
func (m *mockProcessor) NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey,
	error) {

	var nextKey *btcec.PublicKey

	return nextKey, nil
}
// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a
// blinding kit. This test does not cover assertions on the calculations of
// forwarding information, because this is covered in a test dedicated to those
// calculations.
func TestDecryptAndValidateFwdInfo(t *testing.T) {
	t.Parallel()

	// maxCltv is the max CLTV expiry encoded in the payment constraints
	// below. The "validation fails" case derives its incoming CLTV from
	// the same constant so that it always exceeds the encoded value.
	const maxCltv = 1000

	// Encode valid blinding data that we'll fake decrypting for our test.
	blindedData := record.NewNonFinalBlindedRouteData(
		lnwire.NewShortChanIDFromInt(1500), nil,
		record.PaymentRelayInfo{
			CltvExpiryDelta: 10,
			BaseFee:         100,
			FeeRate:         0,
		},
		&record.PaymentConstraints{
			MaxCltvExpiry:   maxCltv,
			HtlcMinimumMsat: lnwire.MilliSatoshi(1),
		},
		nil,
	)

	validData, err := record.EncodeBlindedRouteData(blindedData)
	require.NoError(t, err)

	// Mocked error.
	errDecryptFailed := errors.New("could not decrypt")

	tests := []struct {
		name              string
		data              []byte
		incomingCLTV      uint32
		updateAddBlinding *btcec.PublicKey
		payloadBlinding   *btcec.PublicKey
		processor         *mockProcessor
		expectedErr       error
	}{
		{
			name:        "no blinding point",
			data:        validData,
			processor:   &mockProcessor{},
			expectedErr: ErrNoBlindingPoint,
		},
		{
			name:              "decryption failed",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor: &mockProcessor{
				decryptErr: errDecryptFailed,
			},
			expectedErr: errDecryptFailed,
		},
		{
			name:              "decode fails",
			data:              []byte{1, 2, 3},
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor:         &mockProcessor{},
			expectedErr:       ErrDecodeFailed,
		},
		{
			name:              "validation fails",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      maxCltv + 10,
			processor:         &mockProcessor{},
			expectedErr: ErrInvalidPayload{
				Type:      record.LockTimeOnionType,
				Violation: InsufficientViolation,
			},
		},
		{
			name:              "valid using update add",
			updateAddBlinding: &btcec.PublicKey{},
			data:              validData,
			processor:         &mockProcessor{},
			expectedErr:       nil,
		},
		{
			name:            "valid using payload",
			payloadBlinding: &btcec.PublicKey{},
			data:            validData,
			processor:       &mockProcessor{},
			expectedErr:     nil,
		},
	}

	for _, testCase := range tests {
		t.Run(testCase.name, func(t *testing.T) {
			// The mocked processor ignores the blinding keys it
			// is handed, so empty keys are sufficient here.
			kit := BlindingKit{
				Processor:      testCase.processor,
				IncomingAmount: 10000,
				IncomingCltv:   testCase.incomingCLTV,
			}

			// Only populate the update-add blinding point when the
			// case provides one, leaving it unset otherwise.
			if testCase.updateAddBlinding != nil {
				kit.UpdateAddBlinding = tlv.SomeRecordT(
					//nolint:lll
					tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding),
				)
			}

			_, err := kit.DecryptAndValidateFwdInfo(
				&Payload{
					encryptedData: testCase.data,
					blindingPoint: testCase.payloadBlinding,
				}, false,
				make(map[tlv.Type][]byte),
			)
			require.ErrorIs(t, err, testCase.expectedErr)
		})
	}
}