mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-23 14:40:30 +01:00
We need to know what role we're playing to be able to handle errors correctly, but the information that we need for this is held by our iterator: - Whether we had a blinding point in update add (blinding kit) - Whether we had a blinding point in payload As we're now going to use the route role return value even when our err!=nil, we rename the error to signal that we're using less canonical golang here. An alternative to this approach is to attach a RouteRole to our ErrInvalidPayload. The downside of that approach is: - Propagate context through parsing (whether we had updateAddHtlc) - Clumsy handling for errors that are not of type ErrInvalidPayload
295 lines
8 KiB
Go
295 lines
8 KiB
Go
package hop
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/btcsuite/btcd/btcec/v2"
|
|
"github.com/davecgh/go-spew/spew"
|
|
sphinx "github.com/lightningnetwork/lightning-onion"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/record"
|
|
"github.com/lightningnetwork/lnd/tlv"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestSphinxHopIteratorForwardingInstructions tests that we're able to
|
|
// properly decode an onion payload, no matter the payload type, into the
|
|
// original set of forwarding instructions.
|
|
func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// First, we'll make the hop data that the sender would create to send
|
|
// an HTLC through our imaginary route.
|
|
hopData := sphinx.HopData{
|
|
ForwardAmount: 100000,
|
|
OutgoingCltv: 4343,
|
|
}
|
|
copy(hopData.NextAddress[:], bytes.Repeat([]byte("a"), 8))
|
|
|
|
// Next, we'll make the hop forwarding information that we should
|
|
// extract each type, no matter the payload type.
|
|
nextAddrInt := binary.BigEndian.Uint64(hopData.NextAddress[:])
|
|
expectedFwdInfo := ForwardingInfo{
|
|
NextHop: lnwire.NewShortChanIDFromInt(nextAddrInt),
|
|
AmountToForward: lnwire.MilliSatoshi(hopData.ForwardAmount),
|
|
OutgoingCTLV: hopData.OutgoingCltv,
|
|
}
|
|
|
|
// For our TLV payload, we'll serialize the hop into into a TLV stream
|
|
// as we would normally in the routing network.
|
|
var b bytes.Buffer
|
|
tlvRecords := []tlv.Record{
|
|
record.NewAmtToFwdRecord(&hopData.ForwardAmount),
|
|
record.NewLockTimeRecord(&hopData.OutgoingCltv),
|
|
record.NewNextHopIDRecord(&nextAddrInt),
|
|
}
|
|
tlvStream, err := tlv.NewStream(tlvRecords...)
|
|
require.NoError(t, err, "unable to create stream")
|
|
if err := tlvStream.Encode(&b); err != nil {
|
|
t.Fatalf("unable to encode stream: %v", err)
|
|
}
|
|
|
|
var testCases = []struct {
|
|
sphinxPacket *sphinx.ProcessedPacket
|
|
expectedFwdInfo ForwardingInfo
|
|
}{
|
|
// A regular legacy payload that signals more hops.
|
|
{
|
|
sphinxPacket: &sphinx.ProcessedPacket{
|
|
Payload: sphinx.HopPayload{
|
|
Type: sphinx.PayloadLegacy,
|
|
},
|
|
Action: sphinx.MoreHops,
|
|
ForwardingInstructions: &hopData,
|
|
},
|
|
expectedFwdInfo: expectedFwdInfo,
|
|
},
|
|
// A TLV payload, which includes the sphinx action as
|
|
// cid may be zero for blinded routes (thus we require the
|
|
// action to signal whether we are at the final hop).
|
|
{
|
|
sphinxPacket: &sphinx.ProcessedPacket{
|
|
Payload: sphinx.HopPayload{
|
|
Type: sphinx.PayloadTLV,
|
|
Payload: b.Bytes(),
|
|
},
|
|
Action: sphinx.MoreHops,
|
|
},
|
|
expectedFwdInfo: expectedFwdInfo,
|
|
},
|
|
}
|
|
|
|
// Finally, we'll test that we get the same set of
|
|
// ForwardingInstructions for each payload type.
|
|
iterator := sphinxHopIterator{}
|
|
for i, testCase := range testCases {
|
|
iterator.processedPacket = testCase.sphinxPacket
|
|
|
|
pld, _, pldErr := iterator.HopPayload()
|
|
if pldErr != nil {
|
|
t.Fatalf("#%v: unable to extract forwarding "+
|
|
"instructions: %v", i, pldErr)
|
|
}
|
|
|
|
fwdInfo := pld.ForwardingInfo()
|
|
if fwdInfo != testCase.expectedFwdInfo {
|
|
t.Fatalf("#%v: wrong fwding info: expected %v, got %v",
|
|
i, spew.Sdump(testCase.expectedFwdInfo),
|
|
spew.Sdump(fwdInfo))
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestForwardingAmountCalc tests calculation of forwarding amounts from the
|
|
// hop's forwarding parameters.
|
|
func TestForwardingAmountCalc(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
incomingAmount lnwire.MilliSatoshi
|
|
baseFee uint32
|
|
proportional uint32
|
|
forwardAmount lnwire.MilliSatoshi
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "overflow",
|
|
incomingAmount: 10,
|
|
baseFee: 100,
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "trivial proportional",
|
|
incomingAmount: 100_000,
|
|
baseFee: 1000,
|
|
proportional: 10,
|
|
forwardAmount: 99000,
|
|
},
|
|
{
|
|
name: "both fees charged",
|
|
incomingAmount: 10_002_020,
|
|
baseFee: 1000,
|
|
proportional: 1,
|
|
forwardAmount: 10_001_010,
|
|
},
|
|
}
|
|
|
|
for _, testCase := range tests {
|
|
testCase := testCase
|
|
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
actual, err := calculateForwardingAmount(
|
|
testCase.incomingAmount, testCase.baseFee,
|
|
testCase.proportional,
|
|
)
|
|
|
|
require.Equal(t, testCase.expectErr, err != nil)
|
|
require.Equal(t, testCase.forwardAmount.ToSatoshis(),
|
|
actual.ToSatoshis())
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockProcessor is a mocked blinding point processor that just returns the
|
|
// data that it is called with when "decrypting".
|
|
type mockProcessor struct {
|
|
decryptErr error
|
|
}
|
|
|
|
// DecryptBlindedHopData mocks blob decryption, returning the same data that
|
|
// it was called with and an optionally configured error.
|
|
func (m *mockProcessor) DecryptBlindedHopData(_ *btcec.PublicKey,
|
|
data []byte) ([]byte, error) {
|
|
|
|
return data, m.decryptErr
|
|
}
|
|
|
|
// NextEphemeral mocks getting our next ephemeral key.
|
|
func (m *mockProcessor) NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey,
|
|
error) {
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a
|
|
// blinding kit. This test does not cover assertions on the calculations of
|
|
// forwarding information, because this is covered in a test dedicated to those
|
|
// calculations.
|
|
func TestDecryptAndValidateFwdInfo(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Encode valid blinding data that we'll fake decrypting for our test.
|
|
maxCltv := 1000
|
|
blindedData := record.NewBlindedRouteData(
|
|
lnwire.NewShortChanIDFromInt(1500), nil,
|
|
record.PaymentRelayInfo{
|
|
CltvExpiryDelta: 10,
|
|
BaseFee: 100,
|
|
FeeRate: 0,
|
|
},
|
|
&record.PaymentConstraints{
|
|
MaxCltvExpiry: 1000,
|
|
HtlcMinimumMsat: lnwire.MilliSatoshi(1),
|
|
},
|
|
nil,
|
|
)
|
|
|
|
validData, err := record.EncodeBlindedRouteData(blindedData)
|
|
require.NoError(t, err)
|
|
|
|
// Mocked error.
|
|
errDecryptFailed := errors.New("could not decrypt")
|
|
|
|
tests := []struct {
|
|
name string
|
|
data []byte
|
|
incomingCLTV uint32
|
|
updateAddBlinding *btcec.PublicKey
|
|
payloadBlinding *btcec.PublicKey
|
|
processor *mockProcessor
|
|
expectedErr error
|
|
}{
|
|
{
|
|
name: "no blinding point",
|
|
data: validData,
|
|
processor: &mockProcessor{},
|
|
expectedErr: ErrNoBlindingPoint,
|
|
},
|
|
{
|
|
name: "decryption failed",
|
|
data: validData,
|
|
updateAddBlinding: &btcec.PublicKey{},
|
|
incomingCLTV: 500,
|
|
processor: &mockProcessor{
|
|
decryptErr: errDecryptFailed,
|
|
},
|
|
expectedErr: errDecryptFailed,
|
|
},
|
|
{
|
|
name: "decode fails",
|
|
data: []byte{1, 2, 3},
|
|
updateAddBlinding: &btcec.PublicKey{},
|
|
incomingCLTV: 500,
|
|
processor: &mockProcessor{},
|
|
expectedErr: ErrDecodeFailed,
|
|
},
|
|
{
|
|
name: "validation fails",
|
|
data: validData,
|
|
updateAddBlinding: &btcec.PublicKey{},
|
|
incomingCLTV: uint32(maxCltv) + 10,
|
|
processor: &mockProcessor{},
|
|
expectedErr: ErrInvalidPayload{
|
|
Type: record.LockTimeOnionType,
|
|
Violation: InsufficientViolation,
|
|
},
|
|
},
|
|
{
|
|
name: "valid using update add",
|
|
updateAddBlinding: &btcec.PublicKey{},
|
|
data: validData,
|
|
processor: &mockProcessor{},
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "valid using payload",
|
|
payloadBlinding: &btcec.PublicKey{},
|
|
data: validData,
|
|
processor: &mockProcessor{},
|
|
expectedErr: nil,
|
|
},
|
|
}
|
|
|
|
for _, testCase := range tests {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
// We don't actually use blinding keys due to our
|
|
// mocking so they can be nil.
|
|
kit := BlindingKit{
|
|
Processor: testCase.processor,
|
|
IncomingAmount: 10000,
|
|
IncomingCltv: testCase.incomingCLTV,
|
|
}
|
|
|
|
if testCase.updateAddBlinding != nil {
|
|
kit.UpdateAddBlinding = tlv.SomeRecordT(
|
|
//nolint:lll
|
|
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding),
|
|
)
|
|
}
|
|
_, err := kit.DecryptAndValidateFwdInfo(
|
|
&Payload{
|
|
encryptedData: testCase.data,
|
|
blindingPoint: testCase.payloadBlinding,
|
|
}, false,
|
|
make(map[tlv.Type][]byte),
|
|
)
|
|
require.ErrorIs(t, err, testCase.expectedErr)
|
|
})
|
|
}
|
|
}
|