mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-22 14:22:37 +01:00
multi: validate contents in blinded data against payload
This commit is contained in:
parent
03f6c5cd0a
commit
ca6d414308
4 changed files with 110 additions and 1 deletions
|
@ -200,7 +200,8 @@ func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey,
|
|||
// DecryptAndValidateFwdInfo performs all operations required to decrypt and
|
||||
// validate a blinded route.
|
||||
func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
|
||||
isFinalHop bool) (*ForwardingInfo, error) {
|
||||
isFinalHop bool, payloadParsed map[tlv.Type][]byte) (
|
||||
*ForwardingInfo, error) {
|
||||
|
||||
// We expect this function to be called when we have encrypted data
|
||||
// present, and a blinding key is set either in the payload or the
|
||||
|
@ -227,6 +228,14 @@ func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload,
|
|||
ErrDecodeFailed, err)
|
||||
}
|
||||
|
||||
// Validate the contents of the payload against the values we've
|
||||
// just pulled out of the encrypted data blob.
|
||||
err = ValidatePayloadWithBlinded(isFinalHop, payloadParsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Validate the data in the blinded route against our incoming htlc's
|
||||
// information.
|
||||
if err := ValidateBlindedRouteData(
|
||||
routeData, b.IncomingAmount, b.IncomingCltv,
|
||||
); err != nil {
|
||||
|
|
|
@ -287,6 +287,7 @@ func TestDecryptAndValidateFwdInfo(t *testing.T) {
|
|||
encryptedData: testCase.data,
|
||||
blindingPoint: testCase.payloadBlinding,
|
||||
}, false,
|
||||
make(map[tlv.Type][]byte),
|
||||
)
|
||||
require.ErrorIs(t, err, testCase.expectedErr)
|
||||
})
|
||||
|
|
|
@ -484,3 +484,37 @@ func ValidateBlindedRouteData(blindedData *record.BlindedRouteData,
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePayloadWithBlinded validates a payload against the contents of
|
||||
// its encrypted data blob.
|
||||
func ValidatePayloadWithBlinded(isFinalHop bool,
|
||||
payloadParsed map[tlv.Type][]byte) error {
|
||||
|
||||
// Blinded routes restrict the presence of TLVs more strictly than
|
||||
// regular routes, check that intermediate and final hops only have
|
||||
// the TLVs the spec allows them to have.
|
||||
allowedTLVs := map[tlv.Type]bool{
|
||||
record.EncryptedDataOnionType: true,
|
||||
record.BlindingPointOnionType: true,
|
||||
}
|
||||
|
||||
if isFinalHop {
|
||||
allowedTLVs[record.AmtOnionType] = true
|
||||
allowedTLVs[record.LockTimeOnionType] = true
|
||||
allowedTLVs[record.TotalAmtMsatBlindedType] = true
|
||||
}
|
||||
|
||||
for tlvType := range payloadParsed {
|
||||
if _, ok := allowedTLVs[tlvType]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
return ErrInvalidPayload{
|
||||
Type: tlvType,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -695,3 +696,67 @@ func TestValidateBlindedRouteData(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidatePayloadWithBlinded tests validation of the contents of a
|
||||
// payload when it's for a blinded payment.
|
||||
func TestValidatePayloadWithBlinded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
finalHopMap := map[tlv.Type][]byte{
|
||||
record.AmtOnionType: nil,
|
||||
record.LockTimeOnionType: nil,
|
||||
record.TotalAmtMsatBlindedType: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
isFinal bool
|
||||
parsed map[tlv.Type][]byte
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
name: "final hop, valid",
|
||||
isFinal: true,
|
||||
parsed: finalHopMap,
|
||||
},
|
||||
{
|
||||
name: "intermediate hop, invalid",
|
||||
isFinal: false,
|
||||
parsed: finalHopMap,
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "intermediate hop, invalid",
|
||||
isFinal: false,
|
||||
parsed: map[tlv.Type][]byte{
|
||||
record.EncryptedDataOnionType: nil,
|
||||
record.BlindingPointOnionType: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown record, invalid",
|
||||
isFinal: false,
|
||||
parsed: map[tlv.Type][]byte{
|
||||
tlv.Type(99): nil,
|
||||
},
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range tests {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
err := hop.ValidatePayloadWithBlinded(
|
||||
testCase.isFinal, testCase.parsed,
|
||||
)
|
||||
|
||||
// We can't determine our exact error because we
|
||||
// iterate through a map (non-deterministic) in the
|
||||
// function.
|
||||
if testCase.err {
|
||||
require.NotNil(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue