lnd/htlcswitch/hop/iterator_test.go
Carla Kirk-Cohen 776c889267
multi: return route role from HopPayload
We need to know what role we're playing to be able to handle errors
correctly, but the information that we need for this is held by our
iterator:
- Whether we had a blinding point in update add (blinding kit)
- Whether we had a blinding point in payload

As we're now going to use the route role return value even when our
err!=nil, we rename the error to signal that we're using less
canonical golang here.

An alternative to this approach is to attach a RouteRole to our
ErrInvalidPayload. The downside of that approach is:
- Propagate context through parsing (whether we had updateAddHtlc)
- Clumsy handling for errors that are not of type ErrInvalidPayload
2024-04-25 09:46:31 -04:00

295 lines
8 KiB
Go

package hop
import (
"bytes"
"encoding/binary"
"errors"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// TestSphinxHopIteratorForwardingInstructions tests that we're able to
// properly decode an onion payload, no matter the payload type, into the
// original set of forwarding instructions.
//
// The same hop data is presented to the iterator twice: once as a legacy
// payload carrying parsed forwarding instructions, and once as a raw TLV
// stream. Both must produce the same ForwardingInfo.
func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
	t.Parallel()

	// First, we'll make the hop data that the sender would create to send
	// an HTLC through our imaginary route.
	hopData := sphinx.HopData{
		ForwardAmount: 100000,
		OutgoingCltv:  4343,
	}
	copy(hopData.NextAddress[:], bytes.Repeat([]byte("a"), 8))

	// Next, we'll make the forwarding information that we expect to
	// extract from each packet, no matter the payload type.
	nextAddrInt := binary.BigEndian.Uint64(hopData.NextAddress[:])
	expectedFwdInfo := ForwardingInfo{
		NextHop:         lnwire.NewShortChanIDFromInt(nextAddrInt),
		AmountToForward: lnwire.MilliSatoshi(hopData.ForwardAmount),
		OutgoingCTLV:    hopData.OutgoingCltv,
	}

	// For our TLV payload, we'll serialize the hop into a TLV stream as
	// we would normally in the routing network.
	var b bytes.Buffer
	tlvRecords := []tlv.Record{
		record.NewAmtToFwdRecord(&hopData.ForwardAmount),
		record.NewLockTimeRecord(&hopData.OutgoingCltv),
		record.NewNextHopIDRecord(&nextAddrInt),
	}

	tlvStream, err := tlv.NewStream(tlvRecords...)
	require.NoError(t, err, "unable to create stream")

	// Reuse the outer err rather than shadowing it, and fail through
	// require like the stream creation check above.
	err = tlvStream.Encode(&b)
	require.NoError(t, err, "unable to encode stream")

	testCases := []struct {
		sphinxPacket    *sphinx.ProcessedPacket
		expectedFwdInfo ForwardingInfo
	}{
		// A regular legacy payload that signals more hops.
		{
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type: sphinx.PayloadLegacy,
				},
				Action:                 sphinx.MoreHops,
				ForwardingInstructions: &hopData,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
		// A TLV payload, which includes the sphinx action as
		// cid may be zero for blinded routes (thus we require the
		// action to signal whether we are at the final hop).
		{
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type:    sphinx.PayloadTLV,
					Payload: b.Bytes(),
				},
				Action: sphinx.MoreHops,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
	}

	// Finally, we'll test that we get the same set of
	// ForwardingInstructions for each payload type.
	iterator := sphinxHopIterator{}
	for i, testCase := range testCases {
		iterator.processedPacket = testCase.sphinxPacket

		// The second return value of HopPayload is not needed by this
		// test, so it is discarded.
		pld, _, pldErr := iterator.HopPayload()
		require.NoErrorf(
			t, pldErr, "#%v: unable to extract forwarding "+
				"instructions", i,
		)

		// Keep the explicit comparison so a mismatch is reported with
		// a full spew dump of both structs.
		fwdInfo := pld.ForwardingInfo()
		if fwdInfo != testCase.expectedFwdInfo {
			t.Fatalf("#%v: wrong fwding info: expected %v, got %v",
				i, spew.Sdump(testCase.expectedFwdInfo),
				spew.Sdump(fwdInfo))
		}
	}
}
// TestForwardingAmountCalc runs calculateForwardingAmount over a table of
// incoming amounts and fee parameters, asserting for every entry whether an
// error is returned and which amount should be forwarded.
func TestForwardingAmountCalc(t *testing.T) {
	t.Parallel()

	// amountCase describes one invocation of calculateForwardingAmount
	// together with the outcome we expect from it.
	type amountCase struct {
		name           string
		incomingAmount lnwire.MilliSatoshi
		baseFee        uint32
		proportional   uint32
		forwardAmount  lnwire.MilliSatoshi
		expectErr      bool
	}

	cases := []amountCase{
		// The incoming amount is smaller than the base fee, so an
		// error is expected and the forward amount stays zero.
		{
			name:           "overflow",
			incomingAmount: 10,
			baseFee:        100,
			expectErr:      true,
		},
		{
			name:           "trivial proportional",
			incomingAmount: 100_000,
			baseFee:        1000,
			proportional:   10,
			forwardAmount:  99000,
		},
		{
			name:           "both fees charged",
			incomingAmount: 10_002_020,
			baseFee:        1000,
			proportional:   1,
			forwardAmount:  10_001_010,
		},
	}

	for _, tc := range cases {
		// Pin the loop variable, since the subtest closure runs in
		// parallel with later iterations.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got, err := calculateForwardingAmount(
				tc.incomingAmount, tc.baseFee, tc.proportional,
			)

			gotErr := err != nil
			require.Equal(t, tc.expectErr, gotErr)

			// Both amounts are converted with ToSatoshis before
			// being compared.
			wantSat := tc.forwardAmount.ToSatoshis()
			require.Equal(t, wantSat, got.ToSatoshis())
		})
	}
}
// mockProcessor is a test double for a blinding point processor. Instead of
// performing any real decryption it echoes back the data it is handed.
type mockProcessor struct {
	// decryptErr is returned alongside the echoed data by
	// DecryptBlindedHopData. Leave it nil to simulate success.
	decryptErr error
}
// DecryptBlindedHopData is the mock's stand-in for blob decryption. The key
// argument is ignored; the input data is handed back unchanged, paired with
// whatever error the mock was configured with (nil when none was set).
func (m *mockProcessor) DecryptBlindedHopData(_ *btcec.PublicKey,
	data []byte) ([]byte, error) {

	// No transformation is applied: the "plaintext" is the input itself.
	plaintext, err := data, m.decryptErr

	return plaintext, err
}
// NextEphemeral is the mock's stand-in for fetching the next ephemeral key. It
// ignores its argument and always yields a nil key together with a nil error.
func (m *mockProcessor) NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey,
	error) {

	var nextKey *btcec.PublicKey

	return nextKey, nil
}
// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a
// blinding kit. This test does not cover assertions on the calculations of
// forwarding information, because this is covered in a test dedicated to those
// calculations.
//
// Each case configures where the blinding point comes from (update add,
// payload, or nowhere), the encrypted data blob and the mocked processor, and
// then asserts on the error returned by DecryptAndValidateFwdInfo.
func TestDecryptAndValidateFwdInfo(t *testing.T) {
	t.Parallel()

	// maxCltv is the maximum CLTV expiry encoded in the blinded data's
	// payment constraints. It is shared with the "validation fails" case
	// below so that the two values cannot drift apart.
	const maxCltv = 1000

	// Encode valid blinding data that we'll fake decrypting for our test.
	blindedData := record.NewBlindedRouteData(
		lnwire.NewShortChanIDFromInt(1500), nil,
		record.PaymentRelayInfo{
			CltvExpiryDelta: 10,
			BaseFee:         100,
			FeeRate:         0,
		},
		&record.PaymentConstraints{
			MaxCltvExpiry:   maxCltv,
			HtlcMinimumMsat: lnwire.MilliSatoshi(1),
		},
		nil,
	)

	validData, err := record.EncodeBlindedRouteData(blindedData)
	require.NoError(t, err)

	// Mocked error.
	errDecryptFailed := errors.New("could not decrypt")

	tests := []struct {
		name              string
		data              []byte
		incomingCLTV      uint32
		updateAddBlinding *btcec.PublicKey
		payloadBlinding   *btcec.PublicKey
		processor         *mockProcessor
		expectedErr       error
	}{
		{
			name:        "no blinding point",
			data:        validData,
			processor:   &mockProcessor{},
			expectedErr: ErrNoBlindingPoint,
		},
		{
			name:              "decryption failed",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor: &mockProcessor{
				decryptErr: errDecryptFailed,
			},
			expectedErr: errDecryptFailed,
		},
		{
			name:              "decode fails",
			data:              []byte{1, 2, 3},
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor:         &mockProcessor{},
			expectedErr:       ErrDecodeFailed,
		},
		{
			// The incoming CLTV is above the encoded maximum.
			name:              "validation fails",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      maxCltv + 10,
			processor:         &mockProcessor{},
			expectedErr: ErrInvalidPayload{
				Type:      record.LockTimeOnionType,
				Violation: InsufficientViolation,
			},
		},
		{
			name:              "valid using update add",
			updateAddBlinding: &btcec.PublicKey{},
			data:              validData,
			processor:         &mockProcessor{},
			expectedErr:       nil,
		},
		{
			name:            "valid using payload",
			payloadBlinding: &btcec.PublicKey{},
			data:            validData,
			processor:       &mockProcessor{},
			expectedErr:     nil,
		},
	}

	for _, testCase := range tests {
		t.Run(testCase.name, func(t *testing.T) {
			// The mocked processor ignores the blinding keys it
			// is given, so empty placeholder keys are enough.
			kit := BlindingKit{
				Processor:      testCase.processor,
				IncomingAmount: 10000,
				IncomingCltv:   testCase.incomingCLTV,
			}

			// Only populate the update add blinding point when
			// the case provides one.
			if testCase.updateAddBlinding != nil {
				kit.UpdateAddBlinding = tlv.SomeRecordT(
					//nolint:lll
					tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding),
				)
			}

			_, err := kit.DecryptAndValidateFwdInfo(
				&Payload{
					encryptedData: testCase.data,
					blindingPoint: testCase.payloadBlinding,
				}, false,
				make(map[tlv.Type][]byte),
			)

			// With a nil expectedErr this only passes when err is
			// nil as well.
			require.ErrorIs(t, err, testCase.expectedErr)
		})
	}
}