lnd/htlcswitch/hop/iterator_test.go
Carla Kirk-Cohen 4d051b4170
multi: handle all blinding point validation in ValidateParsedPayloadTypes
This commit moves all our validation related to the presence of fields
into ValidateParsedPayloadTypes so that we can handle them in a single
place. We draw the distinction between:
- Validation of the payload (and the context within which it's being
  parsed: final hop, blinded hop, etc.)
- Processing and validation of encrypted data, where we perform
  additional cryptographic operations and validate that the fields
  contained in the blob are valid.

This helps draw the line more clearly between the two validation types,
rather than splitting some payload-related blinded hop processing
into the encrypted data processing part. The downside of this approach
(vs doing the blinded path payload check _after_ payload validation)
is that we have to pass additional context into payload validation
(i.e., whether we got a blinding point in our UpdateAddHtlc, as we
already do for isFinalHop).
2024-04-25 09:15:57 -04:00

295 lines
8 KiB
Go

package hop
import (
"bytes"
"encoding/binary"
"errors"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// TestSphinxHopIteratorForwardingInstructions tests that we're able to
// properly decode an onion payload, no matter the payload type, into the
// original set of forwarding instructions.
func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
	t.Parallel()

	// First, we'll make the hop data that the sender would create to send
	// an HTLC through our imaginary route.
	hopData := sphinx.HopData{
		ForwardAmount: 100000,
		OutgoingCltv:  4343,
	}
	copy(hopData.NextAddress[:], bytes.Repeat([]byte("a"), 8))

	// Next, we'll make the hop forwarding information that we expect to
	// extract from each payload, no matter the payload type.
	nextAddrInt := binary.BigEndian.Uint64(hopData.NextAddress[:])
	expectedFwdInfo := ForwardingInfo{
		NextHop:         lnwire.NewShortChanIDFromInt(nextAddrInt),
		AmountToForward: lnwire.MilliSatoshi(hopData.ForwardAmount),
		OutgoingCTLV:    hopData.OutgoingCltv,
	}

	// For our TLV payload, we'll serialize the hop into a TLV stream as we
	// would normally in the routing network.
	var b bytes.Buffer
	tlvRecords := []tlv.Record{
		record.NewAmtToFwdRecord(&hopData.ForwardAmount),
		record.NewLockTimeRecord(&hopData.OutgoingCltv),
		record.NewNextHopIDRecord(&nextAddrInt),
	}
	tlvStream, err := tlv.NewStream(tlvRecords...)
	require.NoError(t, err, "unable to create stream")
	require.NoError(t, tlvStream.Encode(&b), "unable to encode stream")

	testCases := []struct {
		name            string
		sphinxPacket    *sphinx.ProcessedPacket
		expectedFwdInfo ForwardingInfo
	}{
		{
			// A regular legacy payload that signals more hops.
			name: "legacy payload",
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type: sphinx.PayloadLegacy,
				},
				Action:                 sphinx.MoreHops,
				ForwardingInstructions: &hopData,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
		{
			// A TLV payload, which includes the sphinx action
			// because the channel ID may be zero for blinded
			// routes (thus we require the action to signal
			// whether we are at the final hop).
			name: "tlv payload",
			sphinxPacket: &sphinx.ProcessedPacket{
				Payload: sphinx.HopPayload{
					Type:    sphinx.PayloadTLV,
					Payload: b.Bytes(),
				},
				Action: sphinx.MoreHops,
			},
			expectedFwdInfo: expectedFwdInfo,
		},
	}

	// Finally, we'll test that we get the same set of
	// ForwardingInstructions for each payload type. Each subtest gets its
	// own iterator so that no state is shared between cases.
	for _, testCase := range testCases {
		testCase := testCase
		t.Run(testCase.name, func(t *testing.T) {
			iterator := sphinxHopIterator{
				processedPacket: testCase.sphinxPacket,
			}

			pld, err := iterator.HopPayload()
			require.NoError(t, err, "unable to extract forwarding "+
				"instructions")

			fwdInfo := pld.ForwardingInfo()
			require.Equal(t, testCase.expectedFwdInfo, fwdInfo,
				"wrong fwding info: %v", spew.Sdump(fwdInfo))
		})
	}
}
// TestForwardingAmountCalc tests calculation of forwarding amounts from the
// hop's forwarding parameters. Results are compared at full millisatoshi
// precision: comparing whole satoshis would hide sub-satoshi rounding errors
// in the calculation.
func TestForwardingAmountCalc(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name           string
		incomingAmount lnwire.MilliSatoshi
		baseFee        uint32
		proportional   uint32
		forwardAmount  lnwire.MilliSatoshi
		expectErr      bool
	}{
		{
			// An incoming amount smaller than the base fee cannot
			// pay for the hop, so the calculation must error.
			name:           "overflow",
			incomingAmount: 10,
			baseFee:        100,
			expectErr:      true,
		},
		{
			// Without any fees the full incoming amount is
			// forwarded.
			name:           "no fees",
			incomingAmount: 100_000,
			forwardAmount:  100_000,
		},
		{
			name:           "trivial proportional",
			incomingAmount: 100_000,
			baseFee:        1000,
			proportional:   10,
			forwardAmount:  99000,
		},
		{
			name:           "both fees charged",
			incomingAmount: 10_002_020,
			baseFee:        1000,
			proportional:   1,
			forwardAmount:  10_001_010,
		},
	}

	for _, testCase := range tests {
		testCase := testCase

		t.Run(testCase.name, func(t *testing.T) {
			t.Parallel()

			actual, err := calculateForwardingAmount(
				testCase.incomingAmount, testCase.baseFee,
				testCase.proportional,
			)

			// The returned amount carries no meaning when an
			// error is returned, so only the error is checked.
			if testCase.expectErr {
				require.Error(t, err)
				return
			}

			require.NoError(t, err)
			require.Equal(t, testCase.forwardAmount, actual)
		})
	}
}
// mockProcessor is a test double for the blinding point processor. Instead of
// performing any cryptography, it hands back whatever data it is asked to
// "decrypt".
type mockProcessor struct {
	// decryptErr, when non-nil, is returned from every
	// DecryptBlindedHopData call alongside the echoed data.
	decryptErr error
}

// DecryptBlindedHopData pretends to decrypt a blinded blob. The input data is
// echoed back unchanged together with the configured decryptErr, which may be
// nil.
func (p *mockProcessor) DecryptBlindedHopData(_ *btcec.PublicKey,
	encrypted []byte) ([]byte, error) {

	plaintext := encrypted

	return plaintext, p.decryptErr
}

// NextEphemeral stands in for deriving the next ephemeral key. The mock has no
// key to offer, so it always returns a nil key and a nil error.
func (p *mockProcessor) NextEphemeral(_ *btcec.PublicKey) (*btcec.PublicKey,
	error) {

	var next *btcec.PublicKey

	return next, nil
}
// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a
// blinding kit. This test does not cover assertions on the calculations of
// forwarding information, because this is covered in a test dedicated to those
// calculations.
func TestDecryptAndValidateFwdInfo(t *testing.T) {
	t.Parallel()

	// maxCltv is the maximum CLTV expiry allowed by the payment
	// constraints encoded below. Using the same constant in the encoded
	// data and in the test cases keeps the "validation fails" case tied to
	// the actual constraint.
	const maxCltv = 1000

	// Encode valid blinding data that we'll fake decrypting for our test.
	blindedData := record.NewBlindedRouteData(
		lnwire.NewShortChanIDFromInt(1500), nil,
		record.PaymentRelayInfo{
			CltvExpiryDelta: 10,
			BaseFee:         100,
			FeeRate:         0,
		},
		&record.PaymentConstraints{
			MaxCltvExpiry:   maxCltv,
			HtlcMinimumMsat: lnwire.MilliSatoshi(1),
		},
		nil,
	)

	validData, err := record.EncodeBlindedRouteData(blindedData)
	require.NoError(t, err)

	// Mocked error.
	errDecryptFailed := errors.New("could not decrypt")

	tests := []struct {
		name              string
		data              []byte
		incomingCLTV      uint32
		updateAddBlinding *btcec.PublicKey
		payloadBlinding   *btcec.PublicKey
		processor         *mockProcessor
		expectedErr       error
	}{
		{
			name:        "no blinding point",
			data:        validData,
			processor:   &mockProcessor{},
			expectedErr: ErrNoBlindingPoint,
		},
		{
			name:              "decryption failed",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor: &mockProcessor{
				decryptErr: errDecryptFailed,
			},
			expectedErr: errDecryptFailed,
		},
		{
			name:              "decode fails",
			data:              []byte{1, 2, 3},
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      500,
			processor:         &mockProcessor{},
			expectedErr:       ErrDecodeFailed,
		},
		{
			// The incoming CLTV exceeds the encrypted max CLTV
			// expiry, which must be rejected.
			name:              "validation fails",
			data:              validData,
			updateAddBlinding: &btcec.PublicKey{},
			incomingCLTV:      maxCltv + 10,
			processor:         &mockProcessor{},
			expectedErr: ErrInvalidPayload{
				Type:      record.LockTimeOnionType,
				Violation: InsufficientViolation,
			},
		},
		{
			name:              "valid using update add",
			updateAddBlinding: &btcec.PublicKey{},
			data:              validData,
			processor:         &mockProcessor{},
			expectedErr:       nil,
		},
		{
			name:            "valid using payload",
			payloadBlinding: &btcec.PublicKey{},
			data:            validData,
			processor:       &mockProcessor{},
			expectedErr:     nil,
		},
	}

	for _, testCase := range tests {
		testCase := testCase

		t.Run(testCase.name, func(t *testing.T) {
			// The mocked processor never inspects the blinding
			// keys, so empty keys (or none at all) are sufficient.
			kit := BlindingKit{
				Processor:      testCase.processor,
				IncomingAmount: 10000,
				IncomingCltv:   testCase.incomingCLTV,
			}

			if testCase.updateAddBlinding != nil {
				kit.UpdateAddBlinding = tlv.SomeRecordT(
					//nolint:lll
					tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding),
				)
			}

			_, err := kit.DecryptAndValidateFwdInfo(
				&Payload{
					encryptedData: testCase.data,
					blindingPoint: testCase.payloadBlinding,
				}, false,
				make(map[tlv.Type][]byte),
			)
			require.ErrorIs(t, err, testCase.expectedErr)
		})
	}
}