mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
htlcswitch/hop: parse and validate AMP records
This commit is contained in:
parent
135a0a9f7f
commit
c2729cbbbd
@ -119,6 +119,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
amt uint64
|
||||
cltv uint32
|
||||
mpp = &record.MPP{}
|
||||
amp = &record.AMP{}
|
||||
)
|
||||
|
||||
tlvStream, err := tlv.NewStream(
|
||||
@ -126,6 +127,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
record.NewLockTimeRecord(&cltv),
|
||||
record.NewNextHopIDRecord(&cid),
|
||||
mpp.Record(),
|
||||
amp.Record(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -160,6 +162,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
mpp = nil
|
||||
}
|
||||
|
||||
// If no AMP field was parsed, set the MPP field on the resulting
|
||||
// payload to nil.
|
||||
if _, ok := parsedTypes[record.AMPOnionType]; !ok {
|
||||
amp = nil
|
||||
}
|
||||
|
||||
// Filter out the custom records.
|
||||
customRecords := NewCustomRecords(parsedTypes)
|
||||
|
||||
@ -171,6 +179,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
OutgoingCTLV: cltv,
|
||||
},
|
||||
MPP: mpp,
|
||||
AMP: amp,
|
||||
customRecords: customRecords,
|
||||
}, nil
|
||||
}
|
||||
@ -207,6 +216,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
|
||||
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||
_, hasMPP := parsedTypes[record.MPPOnionType]
|
||||
_, hasAMP := parsedTypes[record.AMPOnionType]
|
||||
|
||||
switch {
|
||||
|
||||
@ -243,6 +253,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
|
||||
// Intermediate nodes should never receive AMP fields.
|
||||
case !isFinalHop && hasAMP:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.AMPOnionType,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnknownRequiredType = 0x10
|
||||
@ -18,6 +19,7 @@ type decodePayloadTest struct {
|
||||
expErr error
|
||||
expCustomRecords map[uint64][]byte
|
||||
shouldHaveMPP bool
|
||||
shouldHaveAMP bool
|
||||
}
|
||||
|
||||
var decodePayloadTests = []decodePayloadTest{
|
||||
@ -183,6 +185,37 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "intermediate hop with amp",
|
||||
payload: []byte{
|
||||
// amount
|
||||
0x02, 0x00,
|
||||
// cltv
|
||||
0x04, 0x00,
|
||||
// next hop id
|
||||
0x06, 0x08,
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
// amp
|
||||
0x0e, 0x41,
|
||||
// amp.root_share
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
// amp.set_id
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
// amp.child_index
|
||||
0x09,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.AMPOnionType,
|
||||
Violation: hop.IncludedViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "final hop with mpp",
|
||||
payload: []byte{
|
||||
@ -201,6 +234,30 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
expErr: nil,
|
||||
shouldHaveMPP: true,
|
||||
},
|
||||
{
|
||||
name: "final hop with amp",
|
||||
payload: []byte{
|
||||
// amount
|
||||
0x02, 0x00,
|
||||
// cltv
|
||||
0x04, 0x00,
|
||||
// amp
|
||||
0x0e, 0x41,
|
||||
// amp.root_share
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
// amp.set_id
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
// amp.child_index
|
||||
0x09,
|
||||
},
|
||||
shouldHaveAMP: true,
|
||||
},
|
||||
}
|
||||
|
||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||
@ -223,6 +280,20 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
}
|
||||
|
||||
testRootShare = [32]byte{
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||
}
|
||||
testSetID = [32]byte{
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||
}
|
||||
testChildIndex = uint32(9)
|
||||
)
|
||||
|
||||
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||
@ -249,6 +320,17 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
t.Fatalf("unexpected MPP payload")
|
||||
}
|
||||
|
||||
if test.shouldHaveAMP {
|
||||
if p.AMP == nil {
|
||||
t.Fatalf("payload should have AMP record")
|
||||
}
|
||||
require.Equal(t, testRootShare, p.AMP.RootShare())
|
||||
require.Equal(t, testSetID, p.AMP.SetID())
|
||||
require.Equal(t, testChildIndex, p.AMP.ChildIndex())
|
||||
} else if p.AMP != nil {
|
||||
t.Fatalf("unexpected AMP payload")
|
||||
}
|
||||
|
||||
// Convert expected nil map to empty map, because we always expect an
|
||||
// initiated map from the payload.
|
||||
expCustomRecords := make(record.CustomSet)
|
||||
|
Loading…
Reference in New Issue
Block a user