Merge pull request #9167 from ellemouton/mcEncodingToTLV

routing+migration32: update migration 32 to use pure TLV encoding for mission control results
This commit is contained in:
Oliver Gugger 2024-11-01 13:16:04 +01:00 committed by GitHub
commit 22ae3e5ddd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1742 additions and 692 deletions

View file

@ -2,8 +2,10 @@ package lnwire
import ( import (
"fmt" "fmt"
"io"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
) )
const ( const (
@ -49,3 +51,39 @@ func (m MilliSatoshi) String() string {
} }
// TODO(roasbeef): extend with arithmetic operations? // TODO(roasbeef): extend with arithmetic operations?
// Record returns a TLV record that can be used to encode/decode a MilliSatoshi
// to/from a TLV stream.
func (m *MilliSatoshi) Record() tlv.Record {
return tlv.MakeDynamicRecord(
0, m, tlv.SizeBigSize(m), encodeMilliSatoshis,
decodeMilliSatoshis,
)
}
func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*MilliSatoshi); ok {
bigSize := uint64(*v)
return tlv.EBigSize(w, &bigSize, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi")
}
func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*MilliSatoshi); ok {
var bigSize uint64
err := tlv.DBigSize(r, &bigSize, buf, l)
if err != nil {
return err
}
*v = MilliSatoshi(bigSize)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l)
}

View file

@ -0,0 +1,37 @@
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
// TrueBoolean is a record that indicates true or false using the presence of
// the record. If the record is absent, it indicates false. If it is present,
// it indicates true.
type TrueBoolean struct{}
// Record returns the tlv record for the boolean entry.
func (b *TrueBoolean) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, b, 0, booleanEncoder, booleanDecoder,
)
}
func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error {
if _, ok := val.(*TrueBoolean); ok {
return nil
}
return tlv.NewTypeForEncodingErr(val, "TrueBoolean")
}
func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) {
return nil
}
return tlv.NewTypeForEncodingErr(val, "TrueBoolean")
}

View file

@ -9,6 +9,7 @@ import (
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migtest" "github.com/lightningnetwork/lnd/channeldb/migtest"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
) )
var ( var (
@ -24,34 +25,17 @@ var (
_ = pubKeyY.SetByteSlice(pubkeyBytes) _ = pubKeyY.SetByteSlice(pubkeyBytes)
pubkey = btcec.NewPublicKey(new(btcec.FieldVal).SetInt(4), pubKeyY) pubkey = btcec.NewPublicKey(new(btcec.FieldVal).SetInt(4), pubKeyY)
paymentResultCommon1 = paymentResultCommon{ customRecord = map[uint64][]byte{
65536: {4, 2, 2},
}
resultOld1 = paymentResultOld{
id: 0, id: 0,
timeFwd: time.Unix(0, 1), timeFwd: time.Unix(0, 1),
timeReply: time.Unix(0, 2), timeReply: time.Unix(0, 2),
success: false, success: false,
failureSourceIdx: &failureIndex, failureSourceIdx: &failureIndex,
failure: &lnwire.FailFeeInsufficient{}, failure: &lnwire.FailFeeInsufficient{},
}
paymentResultCommon2 = paymentResultCommon{
id: 2,
timeFwd: time.Unix(0, 4),
timeReply: time.Unix(0, 7),
success: true,
}
)
// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation
// migration function correctly migrates the MC store from using the old route
// encoding to using the newer, more minimal route encoding.
func TestMigrateMCRouteSerialisation(t *testing.T) {
customRecord := map[uint64][]byte{
65536: {4, 2, 2},
}
resultsOld := []*paymentResultOld{
{
paymentResultCommon: paymentResultCommon1,
route: &Route{ route: &Route{
TotalTimeLock: 100, TotalTimeLock: 100,
TotalAmount: 400, TotalAmount: 400,
@ -65,18 +49,12 @@ func TestMigrateMCRouteSerialisation(t *testing.T) {
OutgoingTimeLock: 300, OutgoingTimeLock: 300,
AmtToForward: 500, AmtToForward: 500,
MPP: &MPP{ MPP: &MPP{
paymentAddr: [32]byte{ paymentAddr: [32]byte{4, 5},
4, 5,
},
totalMsat: 900, totalMsat: 900,
}, },
AMP: &AMP{ AMP: &AMP{
rootShare: [32]byte{ rootShare: [32]byte{0, 0},
0, 0, setID: [32]byte{5, 5, 5},
},
setID: [32]byte{
5, 5, 5,
},
childIndex: 90, childIndex: 90,
}, },
CustomRecords: customRecord, CustomRecords: customRecord,
@ -97,9 +75,7 @@ func TestMigrateMCRouteSerialisation(t *testing.T) {
OutgoingTimeLock: 4, OutgoingTimeLock: 4,
AmtToForward: 4, AmtToForward: 4,
BlindingPoint: pubkey, BlindingPoint: pubkey,
EncryptedData: []byte{ EncryptedData: []byte{1, 2, 3},
1, 2, 3,
},
TotalAmtMsat: 600, TotalAmtMsat: 600,
}, },
// A hop with a blinding key and custom // A hop with a blinding key and custom
@ -111,16 +87,18 @@ func TestMigrateMCRouteSerialisation(t *testing.T) {
AmtToForward: 4, AmtToForward: 4,
CustomRecords: customRecord, CustomRecords: customRecord,
BlindingPoint: pubkey, BlindingPoint: pubkey,
EncryptedData: []byte{ EncryptedData: []byte{1, 2, 3},
1, 2, 3,
},
TotalAmtMsat: 600, TotalAmtMsat: 600,
}, },
}, },
}, },
}, }
{
paymentResultCommon: paymentResultCommon2, resultOld2 = paymentResultOld{
id: 2,
timeFwd: time.Unix(0, 4),
timeReply: time.Unix(0, 7),
success: true,
route: &Route{ route: &Route{
TotalTimeLock: 101, TotalTimeLock: 101,
TotalAmount: 401, TotalAmount: 401,
@ -132,67 +110,128 @@ func TestMigrateMCRouteSerialisation(t *testing.T) {
OutgoingTimeLock: 4, OutgoingTimeLock: 4,
AmtToForward: 4, AmtToForward: 4,
BlindingPoint: pubkey, BlindingPoint: pubkey,
EncryptedData: []byte{ EncryptedData: []byte{1, 2, 3},
1, 2, 3, CustomRecords: customRecord,
},
TotalAmtMsat: 600, TotalAmtMsat: 600,
}, },
}, },
}, },
},
} }
expectedResultsNew := []*paymentResultNew{ //nolint:lll
{ resultNew1Hop1 = &mcHop{
paymentResultCommon: paymentResultCommon1, channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](100),
route: &mcRoute{ pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub),
sourcePubKey: testPub, amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](500),
totalAmount: 400, hasCustomRecords: tlv.SomeRecordT(
hops: []*mcHop{ tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](),
{ ),
channelID: 100,
pubKeyBytes: testPub,
amtToFwd: 500,
hasCustomRecords: true,
},
{
channelID: 800,
pubKeyBytes: testPub,
amtToFwd: 4,
},
{
channelID: 800,
pubKeyBytes: testPub,
amtToFwd: 4,
hasBlindingPoint: true,
},
{
channelID: 800,
pubKeyBytes: testPub,
amtToFwd: 4,
hasBlindingPoint: true,
hasCustomRecords: true,
},
},
},
},
{
paymentResultCommon: paymentResultCommon2,
route: &mcRoute{
sourcePubKey: testPub2,
totalAmount: 401,
hops: []*mcHop{
{
channelID: 800,
pubKeyBytes: testPub,
amtToFwd: 4,
hasBlindingPoint: true,
},
},
},
},
} }
//nolint:lll
resultNew1Hop2 = &mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub),
amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4),
}
//nolint:lll
resultNew1Hop3 = &mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub),
amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4),
hasBlindingPoint: tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](),
),
}
//nolint:lll
resultNew1Hop4 = &mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub),
amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4),
hasCustomRecords: tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](),
),
hasBlindingPoint: tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](),
),
}
//nolint:lll
resultNew2Hop1 = &mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub),
amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4),
hasCustomRecords: tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](),
),
hasBlindingPoint: tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](),
),
}
//nolint:lll
resultNew1 = paymentResultNew{
id: 0,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(time.Unix(0, 1).UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint64(time.Unix(0, 2).UnixNano()),
),
failure: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](
*newPaymentFailure(
&failureIndex,
&lnwire.FailFeeInsufficient{},
),
),
),
route: tlv.NewRecordT[tlv.TlvType2](mcRoute{
sourcePubKey: tlv.NewRecordT[tlv.TlvType0](testPub),
totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](400),
hops: tlv.NewRecordT[tlv.TlvType2, mcHops](mcHops{
resultNew1Hop1,
resultNew1Hop2,
resultNew1Hop3,
resultNew1Hop4,
}),
}),
}
//nolint:lll
resultNew2 = paymentResultNew{
id: 2,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](
uint64(time.Unix(0, 4).UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1, uint64](
uint64(time.Unix(0, 7).UnixNano()),
),
route: tlv.NewRecordT[tlv.TlvType2](mcRoute{
sourcePubKey: tlv.NewRecordT[tlv.TlvType0](testPub2),
totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](401),
hops: tlv.NewRecordT[tlv.TlvType2](mcHops{
resultNew2Hop1,
}),
}),
}
)
// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation
// migration function correctly migrates the MC store from using the old route
// encoding to using the newer, more minimal route encoding.
func TestMigrateMCRouteSerialisation(t *testing.T) {
var (
resultsOld = []*paymentResultOld{
&resultOld1, &resultOld2,
}
expectedResultsNew = []*paymentResultNew{
&resultNew1, &resultNew2,
}
)
// Prime the database with some mission control data that uses the // Prime the database with some mission control data that uses the
// old route encoding. // old route encoding.
before := func(tx kvdb.RwTx) error { before := func(tx kvdb.RwTx) error {

View file

@ -8,6 +8,8 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/tlv"
) )
const ( const (
@ -22,30 +24,22 @@ var (
resultsKey = []byte("missioncontrol-results") resultsKey = []byte("missioncontrol-results")
) )
// paymentResultCommon holds the fields that are shared by the old and new
// payment result encoding.
type paymentResultCommon struct {
id uint64
timeFwd, timeReply time.Time
success bool
failureSourceIdx *int
failure lnwire.FailureMessage
}
// paymentResultOld is the information that becomes available when a payment // paymentResultOld is the information that becomes available when a payment
// attempt completes. // attempt completes.
type paymentResultOld struct { type paymentResultOld struct {
paymentResultCommon id uint64
timeFwd, timeReply time.Time
route *Route route *Route
success bool
failureSourceIdx *int
failure lnwire.FailureMessage
} }
// deserializeOldResult deserializes a payment result using the old encoding. // deserializeOldResult deserializes a payment result using the old encoding.
func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { func deserializeOldResult(k, v []byte) (*paymentResultOld, error) {
// Parse payment id. // Parse payment id.
result := paymentResultOld{ result := paymentResultOld{
paymentResultCommon: paymentResultCommon{
id: byteOrder.Uint64(k[8:]), id: byteOrder.Uint64(k[8:]),
},
} }
r := bytes.NewReader(v) r := bytes.NewReader(v)
@ -99,67 +93,563 @@ func deserializeOldResult(k, v []byte) (*paymentResultOld, error) {
// convertPaymentResult converts a paymentResultOld to a paymentResultNew. // convertPaymentResult converts a paymentResultOld to a paymentResultNew.
func convertPaymentResult(old *paymentResultOld) *paymentResultNew { func convertPaymentResult(old *paymentResultOld) *paymentResultNew {
return &paymentResultNew{ var failure *paymentFailure
paymentResultCommon: old.paymentResultCommon, if !old.success {
route: extractMCRoute(old.route), failure = newPaymentFailure(old.failureSourceIdx, old.failure)
} }
return newPaymentResult(
old.id, extractMCRoute(old.route), old.timeFwd, old.timeReply,
failure,
)
}
// newPaymentResult constructs a new paymentResult.
func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time,
failure *paymentFailure) *paymentResultNew {
result := &paymentResultNew{
id: id,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(timeFwd.UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint64(timeReply.UnixNano()),
),
route: tlv.NewRecordT[tlv.TlvType2](*rt),
}
if failure != nil {
result.failure = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](*failure),
)
}
return result
} }
// paymentResultNew is the information that becomes available when a payment // paymentResultNew is the information that becomes available when a payment
// attempt completes. // attempt completes.
type paymentResultNew struct { type paymentResultNew struct {
paymentResultCommon id uint64
route *mcRoute timeFwd tlv.RecordT[tlv.TlvType0, uint64]
timeReply tlv.RecordT[tlv.TlvType1, uint64]
route tlv.RecordT[tlv.TlvType2, mcRoute]
// failure holds information related to the failure of a payment. The
// presence of this record indicates a payment failure. The absence of
// this record indicates a successful payment.
failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure]
}
// paymentFailure represents the presence of a payment failure. It may or may
// not include additional information about said failure.
type paymentFailure struct {
info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo]
}
// newPaymentFailure constructs a new paymentFailure struct. If the source
// index is nil, then an empty paymentFailure is returned. This represents a
// failure with unknown details. Otherwise, the index and failure message are
// used to populate the info field of the paymentFailure.
func newPaymentFailure(sourceIdx *int,
failureMsg lnwire.FailureMessage) *paymentFailure {
if sourceIdx == nil {
return &paymentFailure{}
}
info := paymentFailureInfo{
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
),
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
}
return &paymentFailure{
info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)),
}
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailure to/from a TLV stream.
func (r *paymentFailure) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailure(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailure, decodePaymentFailure,
)
}
func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailure); ok {
var recordProducers []tlv.RecordProducer
v.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
recordProducers = append(recordProducers, &r)
},
)
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure")
}
func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailure); ok {
var h paymentFailure
info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]()
typeMap, err := lnwire.DecodeRecords(
r, lnwire.ProduceRecordsSorted(&info)...,
)
if err != nil {
return err
}
if _, ok := typeMap[h.info.TlvType()]; ok {
h.info = tlv.SomeRecordT(info)
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l)
}
// paymentFailureInfo holds additional information about a payment failure.
type paymentFailureInfo struct {
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
msg tlv.RecordT[tlv.TlvType1, failureMessage]
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailureInfo to/from a TLV stream.
func (r *paymentFailureInfo) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailureInfo(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailureInfo,
decodePaymentFailureInfo,
)
}
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailureInfo); ok {
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(
&v.sourceIdx, &v.msg,
),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo")
}
func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailureInfo); ok {
var h paymentFailureInfo
_, err := lnwire.DecodeRecords(
r,
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
)
if err != nil {
return err
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(
val, "routing.paymentFailureInfo", l, l,
)
}
type failureMessage struct {
lnwire.FailureMessage
}
// Record returns a TLV record that can be used to encode/decode a list of
// failureMessage to/from a TLV stream.
func (r *failureMessage) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeFailureMessage(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeFailureMessage, decodeFailureMessage,
)
}
func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*failureMessage); ok {
var b bytes.Buffer
err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0)
if err != nil {
return err
}
_, err = w.Write(b.Bytes())
return err
}
return tlv.NewTypeForEncodingErr(val, "routing.failureMessage")
}
func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*failureMessage); ok {
msg, err := lnwire.DecodeFailureMessage(r, 0)
if err != nil {
return err
}
*v = failureMessage{
FailureMessage: msg,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l)
} }
// extractMCRoute extracts the fields required by MC from the Route struct to // extractMCRoute extracts the fields required by MC from the Route struct to
// create the more minimal mcRoute struct. // create the more minimal mcRoute struct.
func extractMCRoute(route *Route) *mcRoute { func extractMCRoute(r *Route) *mcRoute {
return &mcRoute{ return &mcRoute{
sourcePubKey: route.SourcePubKey, sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey),
totalAmount: route.TotalAmount, totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount),
hops: extractMCHops(route.Hops), hops: tlv.NewRecordT[tlv.TlvType2](
extractMCHops(r.Hops),
),
} }
} }
// extractMCHops extracts the Hop fields that MC actually uses from a slice of // extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops. // Hops.
func extractMCHops(hops []*Hop) []*mcHop { func extractMCHops(hops []*Hop) mcHops {
mcHops := make([]*mcHop, len(hops)) return fn.Map(extractMCHop, hops)
for i, hop := range hops {
mcHops[i] = extractMCHop(hop)
}
return mcHops
} }
// extractMCHop extracts the Hop fields that MC actually uses from a Hop. // extractMCHop extracts the Hop fields that MC actually uses from a Hop.
func extractMCHop(hop *Hop) *mcHop { func extractMCHop(hop *Hop) *mcHop {
return &mcHop{ h := mcHop{
channelID: hop.ChannelID, channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](
pubKeyBytes: hop.PubKeyBytes, hop.ChannelID,
amtToFwd: hop.AmtToForward, ),
hasBlindingPoint: hop.BlindingPoint != nil, pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](
hasCustomRecords: len(hop.CustomRecords) > 0, hop.PubKeyBytes,
),
amtToFwd: tlv.NewRecordT[tlv.TlvType2, lnwire.MilliSatoshi](
hop.AmtToForward,
),
} }
if hop.BlindingPoint != nil {
h.hasBlindingPoint = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3, lnwire.TrueBoolean](
lnwire.TrueBoolean{},
),
)
}
if len(hop.CustomRecords) != 0 {
h.hasCustomRecords = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean](
lnwire.TrueBoolean{},
),
)
}
return &h
} }
// mcRoute holds the bare minimum info about a payment attempt route that MC // mcRoute holds the bare minimum info about a payment attempt route that MC
// requires. // requires.
type mcRoute struct { type mcRoute struct {
sourcePubKey Vertex sourcePubKey tlv.RecordT[tlv.TlvType0, Vertex]
totalAmount lnwire.MilliSatoshi totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi]
hops []*mcHop hops tlv.RecordT[tlv.TlvType2, mcHops]
}
// Record returns a TLV record that can be used to encode/decode an mcRoute
// to/from a TLV stream.
func (r *mcRoute) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCRoute(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeMCRoute, decodeMCRoute,
)
}
func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*mcRoute); ok {
return serializeRoute(w, v)
}
return tlv.NewTypeForEncodingErr(val, "routing.mcRoute")
}
func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if v, ok := val.(*mcRoute); ok {
route, err := deserializeRoute(io.LimitReader(r, int64(l)))
if err != nil {
return err
}
*v = *route
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l)
}
// mcHops is a list of mcHop records.
type mcHops []*mcHop
// Record returns a TLV record that can be used to encode/decode a list of
// mcHop to/from a TLV stream.
func (h *mcHops) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCHops(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, encodeMCHops, decodeMCHops,
)
}
func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*mcHops); ok {
// Encode the number of hops as a var int.
if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for _, hop := range *v {
var hopBytes bytes.Buffer
if err := serializeNewHop(&hopBytes, hop); err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(hopBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
if _, err := w.Write(hopBytes.Bytes()); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "routing.mcHops")
}
func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*mcHops); ok {
// First, we'll decode the varint that encodes how many hops
// are encoded in the stream.
numHops, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numHops; i++ {
// Read out the varint that encodes the size of this
// inner TLV record.
hopSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := &io.LimitedReader{
R: r,
N: int64(hopSize),
}
hop, err := deserializeNewHop(innerTlvReader)
if err != nil {
return err
}
*v = append(*v, hop)
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l)
}
// serializeRoute serializes a mcRoute and writes the resulting bytes to the
// given io.Writer.
func serializeRoute(w io.Writer, r *mcRoute) error {
records := lnwire.ProduceRecordsSorted(
&r.sourcePubKey,
&r.totalAmount,
&r.hops,
)
return lnwire.EncodeRecordsTo(w, records)
}
// deserializeRoute deserializes the mcRoute from the given io.Reader.
func deserializeRoute(r io.Reader) (*mcRoute, error) {
var rt mcRoute
records := lnwire.ProduceRecordsSorted(
&rt.sourcePubKey,
&rt.totalAmount,
&rt.hops,
)
_, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
return &rt, nil
}
// deserializeNewHop deserializes the mcHop from the given io.Reader.
func deserializeNewHop(r io.Reader) (*mcHop, error) {
var (
h mcHop
blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]()
custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]()
)
records := lnwire.ProduceRecordsSorted(
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
&blinding,
&custom,
)
typeMap, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok {
h.hasBlindingPoint = tlv.SomeRecordT(blinding)
}
if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok {
h.hasCustomRecords = tlv.SomeRecordT(custom)
}
return &h, nil
}
// serializeNewHop serializes a mcHop and writes the resulting bytes to the
// given io.Writer.
func serializeNewHop(w io.Writer, h *mcHop) error {
recordProducers := []tlv.RecordProducer{
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
}
h.hasBlindingPoint.WhenSome(func(
hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasBlinding)
})
h.hasCustomRecords.WhenSome(func(
hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasCustom)
})
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
} }
// mcHop holds the bare minimum info about a payment attempt route hop that MC // mcHop holds the bare minimum info about a payment attempt route hop that MC
// requires. // requires.
type mcHop struct { type mcHop struct {
channelID uint64 channelID tlv.RecordT[tlv.TlvType0, uint64]
pubKeyBytes Vertex pubKeyBytes tlv.RecordT[tlv.TlvType1, Vertex]
amtToFwd lnwire.MilliSatoshi amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi]
hasBlindingPoint bool hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean]
hasCustomRecords bool hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean]
} }
// serializeOldResult serializes a payment result and returns a key and value // serializeOldResult serializes a payment result and returns a key and value
@ -225,48 +715,30 @@ func getResultKeyOld(rp *paymentResultOld) []byte {
// serializeNewResult serializes a payment result and returns a key and value // serializeNewResult serializes a payment result and returns a key and value
// byte slice to insert into the bucket. // byte slice to insert into the bucket.
func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) { func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) {
// Write timestamps, success status, failure source index and route. recordProducers := []tlv.RecordProducer{
var b bytes.Buffer &rp.timeFwd,
&rp.timeReply,
var dbFailureSourceIdx int32 &rp.route,
if rp.failureSourceIdx == nil {
dbFailureSourceIdx = unknownFailureSourceIdx
} else {
dbFailureSourceIdx = int32(*rp.failureSourceIdx)
} }
err := WriteElements( rp.failure.WhenSome(
&b, func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) {
uint64(rp.timeFwd.UnixNano()), recordProducers = append(recordProducers, &failure)
uint64(rp.timeReply.UnixNano()), },
rp.success, dbFailureSourceIdx, )
// Compose key that identifies this result.
key := getResultKeyNew(rp)
var buff bytes.Buffer
err := lnwire.EncodeRecordsTo(
&buff, lnwire.ProduceRecordsSorted(recordProducers...),
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := serializeMCRoute(&b, rp.route); err != nil { return key, buff.Bytes(), err
return nil, nil, err
}
// Write failure. If there is no failure message, write an empty
// byte slice.
var failureBytes bytes.Buffer
if rp.failure != nil {
err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0)
if err != nil {
return nil, nil, err
}
}
err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes())
if err != nil {
return nil, nil, err
}
// Compose key that identifies this result.
key := getResultKeyNew(rp)
return key, b.Bytes(), nil
} }
// getResultKeyNew returns a byte slice representing a unique key for this // getResultKeyNew returns a byte slice representing a unique key for this
@ -278,43 +750,9 @@ func getResultKeyNew(rp *paymentResultNew) []byte {
// key. This allows importing mission control data from an external // key. This allows importing mission control data from an external
// source without key collisions and keeps the records sorted // source without key collisions and keeps the records sorted
// chronologically. // chronologically.
byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val)
byteOrder.PutUint64(keyBytes[8:], rp.id) byteOrder.PutUint64(keyBytes[8:], rp.id)
copy(keyBytes[16:], rp.route.sourcePubKey[:]) copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:])
return keyBytes[:] return keyBytes[:]
} }
// serializeMCRoute serializes an mcRoute and writes the bytes to the given
// io.Writer.
func serializeMCRoute(w io.Writer, r *mcRoute) error {
if err := WriteElements(
w, r.totalAmount, r.sourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.hops))); err != nil {
return err
}
for _, h := range r.hops {
if err := serializeNewHop(w, h); err != nil {
return err
}
}
return nil
}
// serializeMCRoute serializes an mcHop and writes the bytes to the given
// io.Writer.
func serializeNewHop(w io.Writer, h *mcHop) error {
return WriteElements(w,
h.pubKeyBytes[:],
h.channelID,
h.amtToFwd,
h.hasBlindingPoint,
h.hasCustomRecords,
)
}

View file

@ -29,6 +29,32 @@ const VertexSize = 33
// public key. // public key.
type Vertex [VertexSize]byte type Vertex [VertexSize]byte
// Record returns a TLV record that can be used to encode/decode a Vertex
// to/from a TLV stream.
func (v *Vertex) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, v, VertexSize, encodeVertex, decodeVertex,
)
}
func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*Vertex); ok {
_, err := w.Write(b[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "Vertex")
}
func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*Vertex); ok {
_, err := io.ReadFull(r, b[:])
return err
}
return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize)
}
// Route represents a path through the channel graph which runs over one or // Route represents a path through the channel graph which runs over one or
// more channels in succession. This struct carries all the information // more channels in succession. This struct carries all the information
// required to craft the Sphinx onion packet, and send the payment along the // required to craft the Sphinx onion packet, and send the payment along the

View file

@ -122,7 +122,8 @@
* [Migrate the mission control * [Migrate the mission control
store](https://github.com/lightningnetwork/lnd/pull/8911) to use a more store](https://github.com/lightningnetwork/lnd/pull/8911) to use a more
minimal encoding for payment attempt routes. minimal encoding for payment attempt routes as well as use [pure TLV
encoding](https://github.com/lightningnetwork/lnd/pull/9167).
* [Migrate the mission control * [Migrate the mission control
store](https://github.com/lightningnetwork/lnd/pull/9001) so that results are store](https://github.com/lightningnetwork/lnd/pull/9001) so that results are

View file

@ -411,7 +411,7 @@ func decodeDisableFlags(r io.Reader, val interface{}, buf *[8]byte,
} }
// TrueBoolean is a record that indicates true or false using the presence of // TrueBoolean is a record that indicates true or false using the presence of
// the record. If the record is absent, it indicates false. If it is presence, // the record. If the record is absent, it indicates false. If it is present,
// it indicates true. // it indicates true.
type TrueBoolean struct{} type TrueBoolean struct{}

View file

@ -1,8 +1,10 @@
package routing package routing
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"sync" "sync"
"time" "time"
@ -16,6 +18,7 @@ import (
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
) )
const ( const (
@ -262,11 +265,38 @@ type MissionControlPairSnapshot struct {
// attempt completes. // attempt completes.
type paymentResult struct { type paymentResult struct {
id uint64 id uint64
timeFwd, timeReply time.Time timeFwd tlv.RecordT[tlv.TlvType0, uint64]
route *mcRoute timeReply tlv.RecordT[tlv.TlvType1, uint64]
success bool route tlv.RecordT[tlv.TlvType2, mcRoute]
failureSourceIdx *int
failure lnwire.FailureMessage // failure holds information related to the failure of a payment. The
// presence of this record indicates a payment failure. The absence of
// this record indicates a successful payment.
failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure]
}
// newPaymentResult constructs a new paymentResult.
func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time,
failure *paymentFailure) *paymentResult {
result := &paymentResult{
id: id,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(timeFwd.UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint64(timeReply.UnixNano()),
),
route: tlv.NewRecordT[tlv.TlvType2](*rt),
}
if failure != nil {
result.failure = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](*failure),
)
}
return result
} }
// NewMissionController returns a new instance of MissionController. // NewMissionController returns a new instance of MissionController.
@ -590,15 +620,10 @@ func (m *MissionControl) ReportPaymentFail(paymentID uint64, rt *route.Route,
timestamp := m.cfg.clock.Now() timestamp := m.cfg.clock.Now()
result := &paymentResult{ result := newPaymentResult(
success: false, paymentID, extractMCRoute(rt), timestamp, timestamp,
timeFwd: timestamp, newPaymentFailure(failureSourceIdx, failure),
timeReply: timestamp, )
id: paymentID,
failureSourceIdx: failureSourceIdx,
failure: failure,
route: extractMCRoute(rt),
}
return m.processPaymentResult(result) return m.processPaymentResult(result)
} }
@ -610,15 +635,12 @@ func (m *MissionControl) ReportPaymentSuccess(paymentID uint64,
timestamp := m.cfg.clock.Now() timestamp := m.cfg.clock.Now()
result := &paymentResult{ result := newPaymentResult(
timeFwd: timestamp, paymentID, extractMCRoute(rt), timestamp, timestamp, nil,
timeReply: timestamp, )
id: paymentID,
success: true,
route: extractMCRoute(rt),
}
_, err := m.processPaymentResult(result) _, err := m.processPaymentResult(result)
return err return err
} }
@ -646,14 +668,11 @@ func (m *MissionControl) applyPaymentResult(
result *paymentResult) *channeldb.FailureReason { result *paymentResult) *channeldb.FailureReason {
// Interpret result. // Interpret result.
i := interpretResult( i := interpretResult(&result.route.Val, result.failure.ValOpt())
result.route, result.success, result.failureSourceIdx,
result.failure,
)
if i.policyFailure != nil { if i.policyFailure != nil {
if m.state.requestSecondChance( if m.state.requestSecondChance(
result.timeReply, time.Unix(0, int64(result.timeReply.Val)),
i.policyFailure.From, i.policyFailure.To, i.policyFailure.From, i.policyFailure.To,
) { ) {
return nil return nil
@ -681,7 +700,10 @@ func (m *MissionControl) applyPaymentResult(
m.log.Debugf("Reporting node failure to Mission Control: "+ m.log.Debugf("Reporting node failure to Mission Control: "+
"node=%v", *i.nodeFailure) "node=%v", *i.nodeFailure)
m.state.setAllFail(*i.nodeFailure, result.timeReply) m.state.setAllFail(
*i.nodeFailure,
time.Unix(0, int64(result.timeReply.Val)),
)
} }
for pair, pairResult := range i.pairResults { for pair, pairResult := range i.pairResults {
@ -698,7 +720,9 @@ func (m *MissionControl) applyPaymentResult(
} }
m.state.setLastPairResult( m.state.setLastPairResult(
pair.From, pair.To, result.timeReply, &pairResult, false, pair.From, pair.To,
time.Unix(0, int64(result.timeReply.Val)), &pairResult,
false,
) )
} }
@ -803,3 +827,158 @@ func (n *namespacedDB) purge() error {
return err return err
}, func() {}) }, func() {})
} }
// paymentFailure represents the presence of a payment failure. It may or may
// not include additional information about said failure.
type paymentFailure struct {
info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo]
}
// newPaymentFailure constructs a new paymentFailure struct. If the source
// index is nil, then an empty paymentFailure is returned. This represents a
// failure with unknown details. Otherwise, the index and failure message are
// used to populate the info field of the paymentFailure.
func newPaymentFailure(sourceIdx *int,
failureMsg lnwire.FailureMessage) *paymentFailure {
if sourceIdx == nil {
return &paymentFailure{}
}
info := paymentFailureInfo{
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
),
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
}
return &paymentFailure{
info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)),
}
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailure to/from a TLV stream.
func (r *paymentFailure) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailure(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailure, decodePaymentFailure,
)
}
func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailure); ok {
var recordProducers []tlv.RecordProducer
v.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
recordProducers = append(recordProducers, &r)
},
)
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure")
}
func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailure); ok {
var h paymentFailure
info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]()
typeMap, err := lnwire.DecodeRecords(
r, lnwire.ProduceRecordsSorted(&info)...,
)
if err != nil {
return err
}
if _, ok := typeMap[h.info.TlvType()]; ok {
h.info = tlv.SomeRecordT(info)
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l)
}
// paymentFailureInfo holds additional information about a payment failure.
type paymentFailureInfo struct {
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
msg tlv.RecordT[tlv.TlvType1, failureMessage]
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailureInfo to/from a TLV stream.
func (r *paymentFailureInfo) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailureInfo(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailureInfo,
decodePaymentFailureInfo,
)
}
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailureInfo); ok {
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(
&v.sourceIdx, &v.msg,
),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo")
}
func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailureInfo); ok {
var h paymentFailureInfo
_, err := lnwire.DecodeRecords(
r,
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
)
if err != nil {
return err
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(
val, "routing.paymentFailureInfo", l, l,
)
}

View file

@ -6,14 +6,12 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"math"
"sync" "sync"
"time" "time"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
) )
var ( var (
@ -26,12 +24,6 @@ var (
byteOrder = binary.BigEndian byteOrder = binary.BigEndian
) )
const (
// unknownFailureSourceIdx is the database encoding of an unknown error
// source.
unknownFailureSourceIdx = -1
)
// missionControlDB is an interface that defines the database methods that a // missionControlDB is an interface that defines the database methods that a
// single missionControlStore has access to. It allows the missionControlStore // single missionControlStore has access to. It allows the missionControlStore
// to be unaware of the overall DB structure and restricts its access to the DB // to be unaware of the overall DB structure and restricts its access to the DB
@ -168,132 +160,30 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
// serializeResult serializes a payment result and returns a key and value byte // serializeResult serializes a payment result and returns a key and value byte
// slice to insert into the bucket. // slice to insert into the bucket.
func serializeResult(rp *paymentResult) ([]byte, []byte, error) { func serializeResult(rp *paymentResult) ([]byte, []byte, error) {
// Write timestamps, success status, failure source index and route. recordProducers := []tlv.RecordProducer{
var b bytes.Buffer &rp.timeFwd,
&rp.timeReply,
var dbFailureSourceIdx int32 &rp.route,
if rp.failureSourceIdx == nil {
dbFailureSourceIdx = unknownFailureSourceIdx
} else {
dbFailureSourceIdx = int32(*rp.failureSourceIdx)
} }
err := channeldb.WriteElements( rp.failure.WhenSome(
&b, func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) {
uint64(rp.timeFwd.UnixNano()), recordProducers = append(recordProducers, &failure)
uint64(rp.timeReply.UnixNano()), },
rp.success, dbFailureSourceIdx,
) )
if err != nil {
return nil, nil, err
}
if err := serializeRoute(&b, rp.route); err != nil {
return nil, nil, err
}
// Write failure. If there is no failure message, write an empty
// byte slice.
var failureBytes bytes.Buffer
if rp.failure != nil {
err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0)
if err != nil {
return nil, nil, err
}
}
err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes())
if err != nil {
return nil, nil, err
}
// Compose key that identifies this result. // Compose key that identifies this result.
key := getResultKey(rp) key := getResultKey(rp)
return key, b.Bytes(), nil var buff bytes.Buffer
} err := lnwire.EncodeRecordsTo(
&buff, lnwire.ProduceRecordsSorted(recordProducers...),
// deserializeRoute deserializes the mcRoute from the given io.Reader.
func deserializeRoute(r io.Reader) (*mcRoute, error) {
var rt mcRoute
if err := channeldb.ReadElements(r, &rt.totalAmount); err != nil {
return nil, err
}
var pub []byte
if err := channeldb.ReadElements(r, &pub); err != nil {
return nil, err
}
copy(rt.sourcePubKey[:], pub)
var numHops uint32
if err := channeldb.ReadElements(r, &numHops); err != nil {
return nil, err
}
var hops []*mcHop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHop(r)
if err != nil {
return nil, err
}
hops = append(hops, hop)
}
rt.hops = hops
return &rt, nil
}
// deserializeHop deserializes the mcHop from the given io.Reader.
func deserializeHop(r io.Reader) (*mcHop, error) {
var h mcHop
var pub []byte
if err := channeldb.ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.pubKeyBytes[:], pub)
if err := channeldb.ReadElements(r,
&h.channelID, &h.amtToFwd, &h.hasBlindingPoint,
&h.hasCustomRecords,
); err != nil {
return nil, err
}
return &h, nil
}
// serializeRoute serializes a mcRoute and writes the resulting bytes to the
// given io.Writer.
func serializeRoute(w io.Writer, r *mcRoute) error {
err := channeldb.WriteElements(w, r.totalAmount, r.sourcePubKey[:])
if err != nil {
return err
}
if err := channeldb.WriteElements(w, uint32(len(r.hops))); err != nil {
return err
}
for _, h := range r.hops {
if err := serializeHop(w, h); err != nil {
return err
}
}
return nil
}
// serializeHop serializes a mcHop and writes the resulting bytes to the given
// io.Writer.
func serializeHop(w io.Writer, h *mcHop) error {
return channeldb.WriteElements(w,
h.pubKeyBytes[:],
h.channelID,
h.amtToFwd,
h.hasBlindingPoint,
h.hasCustomRecords,
) )
if err != nil {
return nil, nil, err
}
return key, buff.Bytes(), err
} }
// deserializeResult deserializes a payment result. // deserializeResult deserializes a payment result.
@ -303,57 +193,115 @@ func deserializeResult(k, v []byte) (*paymentResult, error) {
id: byteOrder.Uint64(k[8:]), id: byteOrder.Uint64(k[8:]),
} }
failure := tlv.ZeroRecordT[tlv.TlvType3, paymentFailure]()
recordProducers := []tlv.RecordProducer{
&result.timeFwd,
&result.timeReply,
&result.route,
&failure,
}
r := bytes.NewReader(v) r := bytes.NewReader(v)
typeMap, err := lnwire.DecodeRecords(
// Read timestamps, success status and failure source index. r, lnwire.ProduceRecordsSorted(recordProducers...)...,
var (
timeFwd, timeReply uint64
dbFailureSourceIdx int32
)
err := channeldb.ReadElements(
r, &timeFwd, &timeReply, &result.success, &dbFailureSourceIdx,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Convert time stamps to local time zone for consistent logging. if _, ok := typeMap[result.failure.TlvType()]; ok {
result.timeFwd = time.Unix(0, int64(timeFwd)).Local() result.failure = tlv.SomeRecordT(failure)
result.timeReply = time.Unix(0, int64(timeReply)).Local()
// Convert from unknown index magic number to nil value.
if dbFailureSourceIdx != unknownFailureSourceIdx {
failureSourceIdx := int(dbFailureSourceIdx)
result.failureSourceIdx = &failureSourceIdx
}
// Read route.
route, err := deserializeRoute(r)
if err != nil {
return nil, err
}
result.route = route
// Read failure.
failureBytes, err := wire.ReadVarBytes(
r, 0, math.MaxUint16, "failure",
)
if err != nil {
return nil, err
}
if len(failureBytes) > 0 {
result.failure, err = lnwire.DecodeFailureMessage(
bytes.NewReader(failureBytes), 0,
)
if err != nil {
return nil, err
}
} }
return &result, nil return &result, nil
} }
// serializeRoute serializes a mcRoute and writes the resulting bytes to the
// given io.Writer.
func serializeRoute(w io.Writer, r *mcRoute) error {
records := lnwire.ProduceRecordsSorted(
&r.sourcePubKey,
&r.totalAmount,
&r.hops,
)
return lnwire.EncodeRecordsTo(w, records)
}
// deserializeRoute deserializes the mcRoute from the given io.Reader.
func deserializeRoute(r io.Reader) (*mcRoute, error) {
var rt mcRoute
records := lnwire.ProduceRecordsSorted(
&rt.sourcePubKey,
&rt.totalAmount,
&rt.hops,
)
_, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
return &rt, nil
}
// deserializeHop deserializes the mcHop from the given io.Reader.
func deserializeHop(r io.Reader) (*mcHop, error) {
var (
h mcHop
blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]()
custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]()
)
records := lnwire.ProduceRecordsSorted(
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
&blinding,
&custom,
)
typeMap, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok {
h.hasBlindingPoint = tlv.SomeRecordT(blinding)
}
if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok {
h.hasCustomRecords = tlv.SomeRecordT(custom)
}
return &h, nil
}
// serializeHop serializes a mcHop and writes the resulting bytes to the given
// io.Writer.
func serializeHop(w io.Writer, h *mcHop) error {
recordProducers := []tlv.RecordProducer{
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
}
h.hasBlindingPoint.WhenSome(func(
hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasBlinding)
})
h.hasCustomRecords.WhenSome(func(
hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasCustom)
})
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
// AddResult adds a new result to the db. // AddResult adds a new result to the db.
func (b *missionControlStore) AddResult(rp *paymentResult) { func (b *missionControlStore) AddResult(rp *paymentResult) {
b.queueCond.L.Lock() b.queueCond.L.Lock()
@ -580,9 +528,70 @@ func getResultKey(rp *paymentResult) []byte {
// key. This allows importing mission control data from an external // key. This allows importing mission control data from an external
// source without key collisions and keeps the records sorted // source without key collisions and keeps the records sorted
// chronologically. // chronologically.
byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val)
byteOrder.PutUint64(keyBytes[8:], rp.id) byteOrder.PutUint64(keyBytes[8:], rp.id)
copy(keyBytes[16:], rp.route.sourcePubKey[:]) copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:])
return keyBytes[:] return keyBytes[:]
} }
// failureMessage wraps the lnwire.FailureMessage interface such that we can
// apply a Record method and use the failureMessage in a TLV encoded type.
type failureMessage struct {
lnwire.FailureMessage
}
// Record returns a TLV record that can be used to encode/decode a list of
// failureMessage to/from a TLV stream.
func (r *failureMessage) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeFailureMessage(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeFailureMessage, decodeFailureMessage,
)
}
func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*failureMessage); ok {
var b bytes.Buffer
err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0)
if err != nil {
return err
}
_, err = w.Write(b.Bytes())
return err
}
return tlv.NewTypeForEncodingErr(val, "routing.failureMessage")
}
func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*failureMessage); ok {
msg, err := lnwire.DecodeFailureMessage(r, 0)
if err != nil {
return err
}
*v = failureMessage{
FailureMessage: msg,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l)
}

View file

@ -11,27 +11,25 @@ import (
"github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const testMaxRecords = 2 const testMaxRecords = 2
var (
// mcStoreTestRoute is a test route for the mission control store tests. // mcStoreTestRoute is a test route for the mission control store tests.
mcStoreTestRoute = mcRoute{ var mcStoreTestRoute = extractMCRoute(&route.Route{
totalAmount: lnwire.MilliSatoshi(5), TotalAmount: lnwire.MilliSatoshi(5),
sourcePubKey: route.Vertex{1}, SourcePubKey: route.Vertex{1},
hops: []*mcHop{ Hops: []*route.Hop{
{ {
pubKeyBytes: route.Vertex{2}, PubKeyBytes: route.Vertex{2},
channelID: 4, ChannelID: 4,
amtToFwd: lnwire.MilliSatoshi(7), AmtToForward: lnwire.MilliSatoshi(7),
hasCustomRecords: true, CustomRecords: make(map[uint64][]byte),
hasBlindingPoint: false,
}, },
}, },
} })
)
// mcStoreTestHarness is the harness for a MissonControlStore test. // mcStoreTestHarness is the harness for a MissonControlStore test.
type mcStoreTestHarness struct { type mcStoreTestHarness struct {
@ -84,28 +82,31 @@ func TestMissionControlStore(t *testing.T) {
failureSourceIdx := 1 failureSourceIdx := 1
result1 := paymentResult{ result1 := newPaymentResult(
route: &mcStoreTestRoute, 99, mcStoreTestRoute, testTime, testTime,
failure: lnwire.NewFailIncorrectDetails(100, 1000), newPaymentFailure(
failureSourceIdx: &failureSourceIdx, &failureSourceIdx,
id: 99, lnwire.NewFailIncorrectDetails(100, 1000),
timeReply: testTime, ),
timeFwd: testTime.Add(-time.Minute), )
}
result2 := result1 result2 := newPaymentResult(
result2.timeReply = result1.timeReply.Add(time.Hour) 2, mcStoreTestRoute, testTime.Add(time.Hour),
result2.timeFwd = result1.timeReply.Add(time.Hour) testTime.Add(time.Hour),
result2.id = 2 newPaymentFailure(
&failureSourceIdx,
lnwire.NewFailIncorrectDetails(100, 1000),
),
)
// Store result. // Store result.
store.AddResult(&result2) store.AddResult(result2)
// Store again to test idempotency. // Store again to test idempotency.
store.AddResult(&result2) store.AddResult(result2)
// Store second result which has an earlier timestamp. // Store second result which has an earlier timestamp.
store.AddResult(&result1) store.AddResult(result1)
require.NoError(t, store.storeResults()) require.NoError(t, store.storeResults())
results, err = store.fetchAll() results, err = store.fetchAll()
@ -113,8 +114,8 @@ func TestMissionControlStore(t *testing.T) {
require.Len(t, results, 2) require.Len(t, results, 2)
// Check that results are stored in chronological order. // Check that results are stored in chronological order.
require.Equal(t, &result1, results[0]) require.Equal(t, result1, results[0])
require.Equal(t, &result2, results[1]) require.Equal(t, result2, results[1])
// Recreate store to test pruning. // Recreate store to test pruning.
store, err = newMissionControlStore( store, err = newMissionControlStore(
@ -124,12 +125,20 @@ func TestMissionControlStore(t *testing.T) {
// Add a newer result which failed due to mpp timeout. // Add a newer result which failed due to mpp timeout.
result3 := result1 result3 := result1
result3.timeReply = result1.timeReply.Add(2 * time.Hour) result3.timeReply = tlv.NewPrimitiveRecord[tlv.TlvType1](
result3.timeFwd = result1.timeReply.Add(2 * time.Hour) uint64(testTime.Add(2 * time.Hour).UnixNano()),
)
result3.timeFwd = tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(testTime.Add(2 * time.Hour).UnixNano()),
)
result3.id = 3 result3.id = 3
result3.failure = &lnwire.FailMPPTimeout{} result3.failure = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](*newPaymentFailure(
&failureSourceIdx, &lnwire.FailMPPTimeout{},
)),
)
store.AddResult(&result3) store.AddResult(result3)
require.NoError(t, store.storeResults()) require.NoError(t, store.storeResults())
// Check that results are pruned. // Check that results are pruned.
@ -137,8 +146,25 @@ func TestMissionControlStore(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, results, 2) require.Len(t, results, 2)
require.Equal(t, &result2, results[0]) require.Equal(t, result2, results[0])
require.Equal(t, &result3, results[1]) require.Equal(t, result3, results[1])
// Also demonstrate the persistence of a success result.
result4 := newPaymentResult(
5, mcStoreTestRoute, testTime.Add(3*time.Hour),
testTime.Add(3*time.Hour), nil,
)
store.AddResult(result4)
require.NoError(t, store.storeResults())
// We should still only have 2 results.
results, err = store.fetchAll()
require.NoError(t, err)
require.Len(t, results, 2)
// The two latest results should have been returned.
require.Equal(t, result3, results[0])
require.Equal(t, result4, results[1])
} }
// TestMissionControlStoreFlushing asserts the periodic flushing of the store // TestMissionControlStoreFlushing asserts the periodic flushing of the store
@ -156,14 +182,11 @@ func TestMissionControlStoreFlushing(t *testing.T) {
) )
nextResult := func() *paymentResult { nextResult := func() *paymentResult {
lastID += 1 lastID += 1
return &paymentResult{ return newPaymentResult(
route: &mcStoreTestRoute, lastID, mcStoreTestRoute, testTime.Add(-time.Hour),
failure: failureDetails, testTime,
failureSourceIdx: &failureSourceIdx, newPaymentFailure(&failureSourceIdx, failureDetails),
id: lastID, )
timeReply: testTime,
timeFwd: testTime.Add(-time.Minute),
}
} }
// Helper to assert the number of results is correct. // Helper to assert the number of results is correct.
@ -260,14 +283,14 @@ func BenchmarkMissionControlStoreFlushing(b *testing.B) {
var lastID uint64 var lastID uint64
for i := 0; i < testMaxRecords; i++ { for i := 0; i < testMaxRecords; i++ {
lastID++ lastID++
result := &paymentResult{ result := newPaymentResult(
route: &mcStoreTestRoute, lastID, mcStoreTestRoute, testTimeFwd,
failure: failureDetails, testTime,
failureSourceIdx: &failureSourceIdx, newPaymentFailure(
id: lastID, &failureSourceIdx,
timeReply: testTime, failureDetails,
timeFwd: testTimeFwd, ),
} )
store.AddResult(result) store.AddResult(result)
} }
@ -278,13 +301,14 @@ func BenchmarkMissionControlStoreFlushing(b *testing.B) {
// Create the additional results. // Create the additional results.
results := make([]*paymentResult, tc) results := make([]*paymentResult, tc)
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
results[i] = &paymentResult{ results[i] = newPaymentResult(
route: &mcStoreTestRoute, 0, mcStoreTestRoute, testTimeFwd,
failure: failureDetails, testTime,
failureSourceIdx: &failureSourceIdx, newPaymentFailure(
timeReply: testTime, &failureSourceIdx,
timeFwd: testTimeFwd, failureDetails,
} ),
)
} }
// Run the actual benchmark. // Run the actual benchmark.

View file

@ -1,11 +1,15 @@
package routing package routing
import ( import (
"bytes"
"fmt" "fmt"
"io"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
) )
// Instantiate variables to allow taking a reference from the failure reason. // Instantiate variables to allow taking a reference from the failure reason.
@ -76,63 +80,73 @@ type interpretedResult struct {
// interpretResult interprets a payment outcome and returns an object that // interpretResult interprets a payment outcome and returns an object that
// contains information required to update mission control. // contains information required to update mission control.
func interpretResult(rt *mcRoute, success bool, failureSrcIdx *int, func interpretResult(rt *mcRoute,
failure lnwire.FailureMessage) *interpretedResult { failure fn.Option[paymentFailure]) *interpretedResult {
i := &interpretedResult{ i := &interpretedResult{
pairResults: make(map[DirectedNodePair]pairResult), pairResults: make(map[DirectedNodePair]pairResult),
} }
if success { return fn.ElimOption(failure, func() *interpretedResult {
i.processSuccess(rt) i.processSuccess(rt)
} else {
i.processFail(rt, failureSrcIdx, failure)
}
return i return i
}, func(info paymentFailure) *interpretedResult {
i.processFail(rt, info)
return i
})
} }
// processSuccess processes a successful payment attempt. // processSuccess processes a successful payment attempt.
func (i *interpretedResult) processSuccess(route *mcRoute) { func (i *interpretedResult) processSuccess(route *mcRoute) {
// For successes, all nodes must have acted in the right way. Therefore // For successes, all nodes must have acted in the right way. Therefore
// we mark all of them with a success result. // we mark all of them with a success result.
i.successPairRange(route, 0, len(route.hops)-1) i.successPairRange(route, 0, len(route.hops.Val)-1)
} }
// processFail processes a failed payment attempt. // processFail processes a failed payment attempt.
func (i *interpretedResult) processFail(rt *mcRoute, errSourceIdx *int, func (i *interpretedResult) processFail(rt *mcRoute, failure paymentFailure) {
failure lnwire.FailureMessage) { if failure.info.IsNone() {
if errSourceIdx == nil {
i.processPaymentOutcomeUnknown(rt) i.processPaymentOutcomeUnknown(rt)
return return
} }
var (
idx int
failMsg lnwire.FailureMessage
)
failure.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
idx = int(r.Val.sourceIdx.Val)
failMsg = r.Val.msg.Val.FailureMessage
},
)
// If the payment was to a blinded route and we received an error from // If the payment was to a blinded route and we received an error from
// after the introduction point, handle this error separately - there // after the introduction point, handle this error separately - there
// has been a protocol violation from the introduction node. This // has been a protocol violation from the introduction node. This
// penalty applies regardless of the error code that is returned. // penalty applies regardless of the error code that is returned.
introIdx, isBlinded := introductionPointIndex(rt) introIdx, isBlinded := introductionPointIndex(rt)
if isBlinded && introIdx < *errSourceIdx { if isBlinded && introIdx < idx {
i.processPaymentOutcomeBadIntro(rt, introIdx, *errSourceIdx) i.processPaymentOutcomeBadIntro(rt, introIdx, idx)
return return
} }
switch *errSourceIdx { switch idx {
// We are the source of the failure. // We are the source of the failure.
case 0: case 0:
i.processPaymentOutcomeSelf(rt, failure) i.processPaymentOutcomeSelf(rt, failMsg)
// A failure from the final hop was received. // A failure from the final hop was received.
case len(rt.hops): case len(rt.hops.Val):
i.processPaymentOutcomeFinal(rt, failure) i.processPaymentOutcomeFinal(rt, failMsg)
// An intermediate hop failed. Interpret the outcome, update reputation // An intermediate hop failed. Interpret the outcome, update reputation
// and try again. // and try again.
default: default:
i.processPaymentOutcomeIntermediate( i.processPaymentOutcomeIntermediate(rt, idx, failMsg)
rt, *errSourceIdx, failure,
)
} }
} }
@ -158,7 +172,7 @@ func (i *interpretedResult) processPaymentOutcomeBadIntro(route *mcRoute,
// a final failure reason because the recipient can't process the // a final failure reason because the recipient can't process the
// payment (independent of the introduction failing to convert the // payment (independent of the introduction failing to convert the
// error, we can't complete the payment if the last hop fails). // error, we can't complete the payment if the last hop fails).
if errSourceIdx == len(route.hops) { if errSourceIdx == len(route.hops.Val) {
i.finalFailureReason = &reasonError i.finalFailureReason = &reasonError
} }
} }
@ -178,7 +192,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute,
i.failNode(rt, 1) i.failNode(rt, 1)
// If this was a payment to a direct peer, we can stop trying. // If this was a payment to a direct peer, we can stop trying.
if len(rt.hops) == 1 { if len(rt.hops.Val) == 1 {
i.finalFailureReason = &reasonError i.finalFailureReason = &reasonError
} }
@ -188,7 +202,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute,
// available in the link has been updated. // available in the link has been updated.
default: default:
log.Warnf("Routing failure for local channel %v occurred", log.Warnf("Routing failure for local channel %v occurred",
rt.hops[0].channelID) rt.hops.Val[0].channelID)
} }
} }
@ -196,7 +210,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute,
func (i *interpretedResult) processPaymentOutcomeFinal(route *mcRoute, func (i *interpretedResult) processPaymentOutcomeFinal(route *mcRoute,
failure lnwire.FailureMessage) { failure lnwire.FailureMessage) {
n := len(route.hops) n := len(route.hops.Val)
failNode := func() { failNode := func() {
i.failNode(route, n) i.failNode(route, n)
@ -396,8 +410,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute,
// Set the node pair for which a channel update may be out of // Set the node pair for which a channel update may be out of
// date. The second chance logic uses the policyFailure field. // date. The second chance logic uses the policyFailure field.
i.policyFailure = &DirectedNodePair{ i.policyFailure = &DirectedNodePair{
From: route.hops[errorSourceIdx-1].pubKeyBytes, From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val,
To: route.hops[errorSourceIdx].pubKeyBytes, To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val,
} }
reportOutgoing() reportOutgoing()
@ -425,8 +439,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute,
// Set the node pair for which a channel update may be out of // Set the node pair for which a channel update may be out of
// date. The second chance logic uses the policyFailure field. // date. The second chance logic uses the policyFailure field.
i.policyFailure = &DirectedNodePair{ i.policyFailure = &DirectedNodePair{
From: route.hops[errorSourceIdx-1].pubKeyBytes, From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val,
To: route.hops[errorSourceIdx].pubKeyBytes, To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val,
} }
// We report incoming channel. If a second pair is granted in // We report incoming channel. If a second pair is granted in
@ -500,14 +514,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute,
// Note that if LND is extended to support multiple blinded // Note that if LND is extended to support multiple blinded
// routes, this will terminate the payment without re-trying // routes, this will terminate the payment without re-trying
// the other routes. // the other routes.
if introIdx == len(route.hops)-1 { if introIdx == len(route.hops.Val)-1 {
i.finalFailureReason = &reasonError i.finalFailureReason = &reasonError
} else { } else {
// If there are other hops between the recipient and // If there are other hops between the recipient and
// introduction node, then we just penalize the last // introduction node, then we just penalize the last
// hop in the blinded route to minimize the storage of // hop in the blinded route to minimize the storage of
// results for ephemeral keys. // results for ephemeral keys.
i.failPairBalance(route, len(route.hops)-1) i.failPairBalance(route, len(route.hops.Val)-1)
} }
// In all other cases, we penalize the reporting node. These are all // In all other cases, we penalize the reporting node. These are all
@ -522,8 +536,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute,
// (i.e., that we consider our own node to be at index zero). A boolean is // (i.e., that we consider our own node to be at index zero). A boolean is
// returned to indicate whether the route contains a blinded portion at all. // returned to indicate whether the route contains a blinded portion at all.
func introductionPointIndex(route *mcRoute) (int, bool) { func introductionPointIndex(route *mcRoute) (int, bool) {
for i, hop := range route.hops { for i, hop := range route.hops.Val {
if hop.hasBlindingPoint { if hop.hasBlindingPoint.IsSome() {
return i + 1, true return i + 1, true
} }
} }
@ -534,7 +548,7 @@ func introductionPointIndex(route *mcRoute) (int, bool) {
// processPaymentOutcomeUnknown processes a payment outcome for which no failure // processPaymentOutcomeUnknown processes a payment outcome for which no failure
// message or source is available. // message or source is available.
func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) { func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) {
n := len(route.hops) n := len(route.hops.Val)
// If this is a direct payment, the destination must be at fault. // If this is a direct payment, the destination must be at fault.
if n == 1 { if n == 1 {
@ -551,52 +565,204 @@ func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) {
// extractMCRoute extracts the fields required by MC from the Route struct to // extractMCRoute extracts the fields required by MC from the Route struct to
// create the more minimal mcRoute struct. // create the more minimal mcRoute struct.
func extractMCRoute(route *route.Route) *mcRoute { func extractMCRoute(r *route.Route) *mcRoute {
return &mcRoute{ return &mcRoute{
sourcePubKey: route.SourcePubKey, sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey),
totalAmount: route.TotalAmount, totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount),
hops: extractMCHops(route.Hops), hops: tlv.NewRecordT[tlv.TlvType2](
extractMCHops(r.Hops),
),
} }
} }
// extractMCHops extracts the Hop fields that MC actually uses from a slice of // extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops. // Hops.
func extractMCHops(hops []*route.Hop) []*mcHop { func extractMCHops(hops []*route.Hop) mcHops {
mcHops := make([]*mcHop, len(hops)) return fn.Map(extractMCHop, hops)
for i, hop := range hops {
mcHops[i] = extractMCHop(hop)
}
return mcHops
} }
// extractMCHop extracts the Hop fields that MC actually uses from a Hop. // extractMCHop extracts the Hop fields that MC actually uses from a Hop.
func extractMCHop(hop *route.Hop) *mcHop { func extractMCHop(hop *route.Hop) *mcHop {
return &mcHop{ h := mcHop{
channelID: hop.ChannelID, channelID: tlv.NewPrimitiveRecord[tlv.TlvType0](
pubKeyBytes: hop.PubKeyBytes, hop.ChannelID,
amtToFwd: hop.AmtToForward, ),
hasBlindingPoint: hop.BlindingPoint != nil, pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](hop.PubKeyBytes),
hasCustomRecords: len(hop.CustomRecords) > 0, amtToFwd: tlv.NewRecordT[tlv.TlvType2](hop.AmtToForward),
} }
if hop.BlindingPoint != nil {
h.hasBlindingPoint = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](lnwire.TrueBoolean{}),
)
}
if hop.CustomRecords != nil {
h.hasCustomRecords = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](lnwire.TrueBoolean{}),
)
}
return &h
} }
// mcRoute holds the bare minimum info about a payment attempt route that MC // mcRoute holds the bare minimum info about a payment attempt route that MC
// requires. // requires.
type mcRoute struct { type mcRoute struct {
sourcePubKey route.Vertex sourcePubKey tlv.RecordT[tlv.TlvType0, route.Vertex]
totalAmount lnwire.MilliSatoshi totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi]
hops []*mcHop hops tlv.RecordT[tlv.TlvType2, mcHops]
}
// Record returns a TLV record that can be used to encode/decode an mcRoute
// to/from a TLV stream.
func (r *mcRoute) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCRoute(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeMCRoute, decodeMCRoute,
)
}
func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*mcRoute); ok {
return serializeRoute(w, v)
}
return tlv.NewTypeForEncodingErr(val, "routing.mcRoute")
}
func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if v, ok := val.(*mcRoute); ok {
route, err := deserializeRoute(io.LimitReader(r, int64(l)))
if err != nil {
return err
}
*v = *route
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l)
}
// mcHops is a list of mcHop records.
type mcHops []*mcHop
// Record returns a TLV record that can be used to encode/decode a list of
// mcHop to/from a TLV stream.
func (h *mcHops) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCHops(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, encodeMCHops, decodeMCHops,
)
}
func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*mcHops); ok {
// Encode the number of hops as a var int.
if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for _, hop := range *v {
var hopBytes bytes.Buffer
if err := serializeHop(&hopBytes, hop); err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(hopBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
if _, err := w.Write(hopBytes.Bytes()); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "routing.mcHops")
}
func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*mcHops); ok {
// First, we'll decode the varint that encodes how many hops
// are encoded in the stream.
numHops, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numHops; i++ {
// Read out the varint that encodes the size of this
// inner TLV record.
hopSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := &io.LimitedReader{
R: r,
N: int64(hopSize),
}
hop, err := deserializeHop(innerTlvReader)
if err != nil {
return err
}
*v = append(*v, hop)
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l)
} }
// mcHop holds the bare minimum info about a payment attempt route hop that MC // mcHop holds the bare minimum info about a payment attempt route hop that MC
// requires. // requires.
type mcHop struct { type mcHop struct {
channelID uint64 channelID tlv.RecordT[tlv.TlvType0, uint64]
pubKeyBytes route.Vertex pubKeyBytes tlv.RecordT[tlv.TlvType1, route.Vertex]
amtToFwd lnwire.MilliSatoshi amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi]
hasBlindingPoint bool hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean]
hasCustomRecords bool hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean]
} }
// failNode marks the node indicated by idx in the route as failed. It also // failNode marks the node indicated by idx in the route as failed. It also
@ -604,7 +770,7 @@ type mcHop struct {
// intentionally panics when the self node is failed. // intentionally panics when the self node is failed.
func (i *interpretedResult) failNode(rt *mcRoute, idx int) { func (i *interpretedResult) failNode(rt *mcRoute, idx int) {
// Mark the node as failing. // Mark the node as failing.
i.nodeFailure = &rt.hops[idx-1].pubKeyBytes i.nodeFailure = &rt.hops.Val[idx-1].pubKeyBytes.Val
// Mark the incoming connection as failed for the node. We intent to // Mark the incoming connection as failed for the node. We intent to
// penalize as much as we can for a node level failure, including future // penalize as much as we can for a node level failure, including future
@ -620,7 +786,7 @@ func (i *interpretedResult) failNode(rt *mcRoute, idx int) {
// If not the ultimate node, mark the outgoing connection as failed for // If not the ultimate node, mark the outgoing connection as failed for
// the node. // the node.
if idx < len(rt.hops) { if idx < len(rt.hops.Val) {
outgoingChannelIdx := idx outgoingChannelIdx := idx
outPair, _ := getPair(rt, outgoingChannelIdx) outPair, _ := getPair(rt, outgoingChannelIdx)
i.pairResults[outPair] = failPairResult(0) i.pairResults[outPair] = failPairResult(0)
@ -667,18 +833,18 @@ func (i *interpretedResult) successPairRange(rt *mcRoute, fromIdx, toIdx int) {
func getPair(rt *mcRoute, channelIdx int) (DirectedNodePair, func getPair(rt *mcRoute, channelIdx int) (DirectedNodePair,
lnwire.MilliSatoshi) { lnwire.MilliSatoshi) {
nodeTo := rt.hops[channelIdx].pubKeyBytes nodeTo := rt.hops.Val[channelIdx].pubKeyBytes.Val
var ( var (
nodeFrom route.Vertex nodeFrom route.Vertex
amt lnwire.MilliSatoshi amt lnwire.MilliSatoshi
) )
if channelIdx == 0 { if channelIdx == 0 {
nodeFrom = rt.sourcePubKey nodeFrom = rt.sourcePubKey.Val
amt = rt.totalAmount amt = rt.totalAmount.Val
} else { } else {
nodeFrom = rt.hops[channelIdx-1].pubKeyBytes nodeFrom = rt.hops.Val[channelIdx-1].pubKeyBytes.Val
amt = rt.hops[channelIdx-1].amtToFwd amt = rt.hops.Val[channelIdx-1].amtToFwd.Val
} }
pair := NewDirectedNodePair(nodeFrom, nodeTo) pair := NewDirectedNodePair(nodeFrom, nodeTo)

View file

@ -4,7 +4,9 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
@ -14,110 +16,170 @@ var (
{1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4},
} }
routeOneHop = mcRoute{ routeOneHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99}, {
PubKeyBytes: hops[1],
AmtToForward: 99,
}, },
} },
})
routeTwoHop = mcRoute{ routeTwoHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99}, {
{pubKeyBytes: hops[2], amtToFwd: 97}, PubKeyBytes: hops[1],
AmtToForward: 99,
}, },
} {
PubKeyBytes: hops[2],
AmtToForward: 97,
},
},
})
routeThreeHop = mcRoute{ routeThreeHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99}, {
{pubKeyBytes: hops[2], amtToFwd: 97}, PubKeyBytes: hops[1],
{pubKeyBytes: hops[3], amtToFwd: 94}, AmtToForward: 99,
}, },
} {
PubKeyBytes: hops[2],
AmtToForward: 97,
},
{
PubKeyBytes: hops[3],
AmtToForward: 94,
},
},
})
routeFourHop = mcRoute{ routeFourHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99}, {
{pubKeyBytes: hops[2], amtToFwd: 97}, PubKeyBytes: hops[1],
{pubKeyBytes: hops[3], amtToFwd: 94}, AmtToForward: 99,
{pubKeyBytes: hops[4], amtToFwd: 90},
}, },
} {
PubKeyBytes: hops[2],
AmtToForward: 97,
},
{
PubKeyBytes: hops[3],
AmtToForward: 94,
},
{
PubKeyBytes: hops[4],
AmtToForward: 90,
},
},
})
// blindedMultiHop is a blinded path where there are cleartext hops // blindedMultiHop is a blinded path where there are cleartext hops
// before the introduction node, and an intermediate blinded hop before // before the introduction node, and an intermediate blinded hop before
// the recipient after it. // the recipient after it.
blindedMultiHop = mcRoute{ blindedMultiHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99},
{ {
pubKeyBytes: hops[2], PubKeyBytes: hops[1],
amtToFwd: 95, AmtToForward: 99,
hasBlindingPoint: true,
}, },
{pubKeyBytes: hops[3], amtToFwd: 88}, {
{pubKeyBytes: hops[4], amtToFwd: 77}, PubKeyBytes: hops[2],
AmtToForward: 95,
BlindingPoint: genTestPubKey(),
}, },
} {
PubKeyBytes: hops[3],
AmtToForward: 88,
},
{
PubKeyBytes: hops[4],
AmtToForward: 77,
},
},
})
// blindedSingleHop is a blinded path with a single blinded hop after // blindedSingleHop is a blinded path with a single blinded hop after
// the introduction node. // the introduction node.
blindedSingleHop = mcRoute{ blindedSingleHop = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 99},
{ {
pubKeyBytes: hops[2], PubKeyBytes: hops[1],
amtToFwd: 95, AmtToForward: 99,
hasBlindingPoint: true,
}, },
{pubKeyBytes: hops[3], amtToFwd: 88}, {
PubKeyBytes: hops[2],
AmtToForward: 95,
BlindingPoint: genTestPubKey(),
}, },
} {
PubKeyBytes: hops[3],
AmtToForward: 88,
},
},
})
// blindedMultiToIntroduction is a blinded path which goes directly // blindedMultiToIntroduction is a blinded path which goes directly
// to the introduction node, with multiple blinded hops after it. // to the introduction node, with multiple blinded hops after it.
blindedMultiToIntroduction = mcRoute{ blindedMultiToIntroduction = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{ {
pubKeyBytes: hops[1], PubKeyBytes: hops[1],
amtToFwd: 90, AmtToForward: 90,
hasBlindingPoint: true, BlindingPoint: genTestPubKey(),
}, },
{pubKeyBytes: hops[2], amtToFwd: 75}, {
{pubKeyBytes: hops[3], amtToFwd: 58}, PubKeyBytes: hops[2],
AmtToForward: 75,
}, },
} {
PubKeyBytes: hops[3],
AmtToForward: 58,
},
},
})
// blindedIntroReceiver is a blinded path where the introduction node // blindedIntroReceiver is a blinded path where the introduction node
// is the recipient. // is the recipient.
blindedIntroReceiver = mcRoute{ blindedIntroReceiver = extractMCRoute(&route.Route{
sourcePubKey: hops[0], SourcePubKey: hops[0],
totalAmount: 100, TotalAmount: 100,
hops: []*mcHop{ Hops: []*route.Hop{
{pubKeyBytes: hops[1], amtToFwd: 95},
{ {
pubKeyBytes: hops[2], PubKeyBytes: hops[1],
amtToFwd: 90, AmtToForward: 95,
hasBlindingPoint: true, },
{
PubKeyBytes: hops[2],
AmtToForward: 90,
BlindingPoint: genTestPubKey(),
}, },
}, },
} })
) )
func genTestPubKey() *btcec.PublicKey {
key, _ := btcec.NewPrivateKey()
return key.PubKey()
}
func getTestPair(from, to int) DirectedNodePair { func getTestPair(from, to int) DirectedNodePair {
return NewDirectedNodePair(hops[from], hops[to]) return NewDirectedNodePair(hops[from], hops[to])
} }
@ -142,7 +204,7 @@ var resultTestCases = []resultTestCase{
// interpreted. // interpreted.
{ {
name: "fail", name: "fail",
route: &routeTwoHop, route: routeTwoHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: lnwire.NewTemporaryChannelFailure(nil), failure: lnwire.NewTemporaryChannelFailure(nil),
@ -157,7 +219,7 @@ var resultTestCases = []resultTestCase{
// Tests that an expiry too soon failure result is properly interpreted. // Tests that an expiry too soon failure result is properly interpreted.
{ {
name: "fail expiry too soon", name: "fail expiry too soon",
route: &routeFourHop, route: routeFourHop,
failureSrcIdx: 3, failureSrcIdx: 3,
failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}),
@ -177,7 +239,7 @@ var resultTestCases = []resultTestCase{
// failure, but mark all pairs along the route as successful. // failure, but mark all pairs along the route as successful.
{ {
name: "fail incorrect details", name: "fail incorrect details",
route: &routeTwoHop, route: routeTwoHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: lnwire.NewFailIncorrectDetails(97, 0), failure: lnwire.NewFailIncorrectDetails(97, 0),
@ -193,7 +255,7 @@ var resultTestCases = []resultTestCase{
// Tests a successful direct payment. // Tests a successful direct payment.
{ {
name: "success direct", name: "success direct",
route: &routeOneHop, route: routeOneHop,
success: true, success: true,
expectedResult: &interpretedResult{ expectedResult: &interpretedResult{
@ -206,7 +268,7 @@ var resultTestCases = []resultTestCase{
// Tests a successful two hop payment. // Tests a successful two hop payment.
{ {
name: "success", name: "success",
route: &routeTwoHop, route: routeTwoHop,
success: true, success: true,
expectedResult: &interpretedResult{ expectedResult: &interpretedResult{
@ -220,7 +282,7 @@ var resultTestCases = []resultTestCase{
// Tests a malformed htlc from a direct peer. // Tests a malformed htlc from a direct peer.
{ {
name: "fail malformed htlc from direct peer", name: "fail malformed htlc from direct peer",
route: &routeTwoHop, route: routeTwoHop,
failureSrcIdx: 0, failureSrcIdx: 0,
failure: lnwire.NewInvalidOnionKey(nil), failure: lnwire.NewInvalidOnionKey(nil),
@ -239,7 +301,7 @@ var resultTestCases = []resultTestCase{
// destination. // destination.
{ {
name: "fail malformed htlc from direct final peer", name: "fail malformed htlc from direct final peer",
route: &routeOneHop, route: routeOneHop,
failureSrcIdx: 0, failureSrcIdx: 0,
failure: lnwire.NewInvalidOnionKey(nil), failure: lnwire.NewInvalidOnionKey(nil),
@ -259,7 +321,7 @@ var resultTestCases = []resultTestCase{
// in a policy failure for the outgoing hop. // in a policy failure for the outgoing hop.
{ {
name: "fail fee insufficient intermediate", name: "fail fee insufficient intermediate",
route: &routeFourHop, route: routeFourHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: lnwire.NewFeeInsufficient( failure: lnwire.NewFeeInsufficient(
0, lnwire.ChannelUpdate1{}, 0, lnwire.ChannelUpdate1{},
@ -282,7 +344,7 @@ var resultTestCases = []resultTestCase{
// failure is terminal since the receiver can't process our onion. // failure is terminal since the receiver can't process our onion.
{ {
name: "fail invalid onion payload final hop four", name: "fail invalid onion payload final hop four",
route: &routeFourHop, route: routeFourHop,
failureSrcIdx: 4, failureSrcIdx: 4,
failure: lnwire.NewInvalidOnionPayload(0, 0), failure: lnwire.NewInvalidOnionPayload(0, 0),
@ -311,7 +373,7 @@ var resultTestCases = []resultTestCase{
// Tests an invalid onion payload from a final hop on a three hop route. // Tests an invalid onion payload from a final hop on a three hop route.
{ {
name: "fail invalid onion payload final hop three", name: "fail invalid onion payload final hop three",
route: &routeThreeHop, route: routeThreeHop,
failureSrcIdx: 3, failureSrcIdx: 3,
failure: lnwire.NewInvalidOnionPayload(0, 0), failure: lnwire.NewInvalidOnionPayload(0, 0),
@ -338,7 +400,7 @@ var resultTestCases = []resultTestCase{
// can still try other paths. // can still try other paths.
{ {
name: "fail invalid onion payload intermediate", name: "fail invalid onion payload intermediate",
route: &routeFourHop, route: routeFourHop,
failureSrcIdx: 3, failureSrcIdx: 3,
failure: lnwire.NewInvalidOnionPayload(0, 0), failure: lnwire.NewInvalidOnionPayload(0, 0),
@ -366,7 +428,7 @@ var resultTestCases = []resultTestCase{
// since the remote node can't process our onion. // since the remote node can't process our onion.
{ {
name: "fail invalid onion payload direct", name: "fail invalid onion payload direct",
route: &routeOneHop, route: routeOneHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: lnwire.NewInvalidOnionPayload(0, 0), failure: lnwire.NewInvalidOnionPayload(0, 0),
@ -385,7 +447,7 @@ var resultTestCases = []resultTestCase{
// penalize mpp timeouts. // penalize mpp timeouts.
{ {
name: "one hop mpp timeout", name: "one hop mpp timeout",
route: &routeOneHop, route: routeOneHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: &lnwire.FailMPPTimeout{}, failure: &lnwire.FailMPPTimeout{},
@ -402,7 +464,7 @@ var resultTestCases = []resultTestCase{
// temporary measure while we decide how to penalize mpp timeouts. // temporary measure while we decide how to penalize mpp timeouts.
{ {
name: "two hop mpp timeout", name: "two hop mpp timeout",
route: &routeTwoHop, route: routeTwoHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: &lnwire.FailMPPTimeout{}, failure: &lnwire.FailMPPTimeout{},
@ -419,7 +481,7 @@ var resultTestCases = []resultTestCase{
// disabled channel should be penalized for any amount. // disabled channel should be penalized for any amount.
{ {
name: "two hop channel disabled", name: "two hop channel disabled",
route: &routeTwoHop, route: routeTwoHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: &lnwire.FailChannelDisabled{}, failure: &lnwire.FailChannelDisabled{},
@ -437,7 +499,7 @@ var resultTestCases = []resultTestCase{
// has not followed the specification properly. // has not followed the specification properly.
{ {
name: "error after introduction", name: "error after introduction",
route: &blindedMultiToIntroduction, route: blindedMultiToIntroduction,
failureSrcIdx: 2, failureSrcIdx: 2,
// Note that the failure code doesn't matter in this case - // Note that the failure code doesn't matter in this case -
// all we're worried about is errors originating after the // all we're worried about is errors originating after the
@ -460,7 +522,7 @@ var resultTestCases = []resultTestCase{
// hop when we expected the introduction node to convert. // hop when we expected the introduction node to convert.
{ {
name: "final failure expected intro", name: "final failure expected intro",
route: &blindedMultiHop, route: blindedMultiHop,
failureSrcIdx: 4, failureSrcIdx: 4,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -482,7 +544,7 @@ var resultTestCases = []resultTestCase{
// introduction point. // introduction point.
{ {
name: "blinded multi-hop introduction", name: "blinded multi-hop introduction",
route: &blindedMultiHop, route: blindedMultiHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -498,7 +560,7 @@ var resultTestCases = []resultTestCase{
// introduction point, which is a direct peer. // introduction point, which is a direct peer.
{ {
name: "blinded multi-hop introduction peer", name: "blinded multi-hop introduction peer",
route: &blindedMultiToIntroduction, route: blindedMultiToIntroduction,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -513,7 +575,7 @@ var resultTestCases = []resultTestCase{
// connected to the introduction node. // connected to the introduction node.
{ {
name: "blinded single hop introduction failure", name: "blinded single hop introduction failure",
route: &blindedSingleHop, route: blindedSingleHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -529,7 +591,7 @@ var resultTestCases = []resultTestCase{
// blinding error and is penalized for returning the wrong error. // blinding error and is penalized for returning the wrong error.
{ {
name: "error before introduction", name: "error before introduction",
route: &blindedMultiHop, route: blindedMultiHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -549,7 +611,7 @@ var resultTestCases = []resultTestCase{
// successful hop before the incorrect error. // successful hop before the incorrect error.
{ {
name: "intermediate unexpected blinding", name: "intermediate unexpected blinding",
route: &routeThreeHop, route: routeThreeHop,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -570,7 +632,7 @@ var resultTestCases = []resultTestCase{
// hops before the erring incoming link (the erring node if our peer). // hops before the erring incoming link (the erring node if our peer).
{ {
name: "peer unexpected blinding", name: "peer unexpected blinding",
route: &routeThreeHop, route: routeThreeHop,
failureSrcIdx: 1, failureSrcIdx: 1,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -588,7 +650,7 @@ var resultTestCases = []resultTestCase{
// A node in a non-blinded route returns a blinding related error. // A node in a non-blinded route returns a blinding related error.
{ {
name: "final node unexpected blinding", name: "final node unexpected blinding",
route: &routeThreeHop, route: routeThreeHop,
failureSrcIdx: 3, failureSrcIdx: 3,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -606,7 +668,7 @@ var resultTestCases = []resultTestCase{
// Introduction node returns invalid blinding erroneously. // Introduction node returns invalid blinding erroneously.
{ {
name: "final node intro blinding", name: "final node intro blinding",
route: &blindedIntroReceiver, route: blindedIntroReceiver,
failureSrcIdx: 2, failureSrcIdx: 2,
failure: &lnwire.FailInvalidBlinding{}, failure: &lnwire.FailInvalidBlinding{},
@ -629,10 +691,15 @@ func TestResultInterpretation(t *testing.T) {
for _, testCase := range resultTestCases { for _, testCase := range resultTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
i := interpretResult( var failure fn.Option[paymentFailure]
testCase.route, testCase.success, if !testCase.success {
&testCase.failureSrcIdx, testCase.failure, failure = fn.Some(*newPaymentFailure(
) &testCase.failureSrcIdx,
testCase.failure,
))
}
i := interpretResult(testCase.route, failure)
expected := testCase.expectedResult expected := testCase.expectedResult

View file

@ -94,6 +94,32 @@ func (v Vertex) String() string {
return fmt.Sprintf("%x", v[:]) return fmt.Sprintf("%x", v[:])
} }
// Record returns a TLV record that can be used to encode/decode a Vertex
// to/from a TLV stream.
func (v *Vertex) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, v, VertexSize, encodeVertex, decodeVertex,
)
}
func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*Vertex); ok {
_, err := w.Write(b[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "Vertex")
}
func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*Vertex); ok {
_, err := io.ReadFull(r, b[:])
return err
}
return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize)
}
// Hop represents an intermediate or final node of the route. This naming // Hop represents an intermediate or final node of the route. This naming
// is in line with the definition given in BOLT #4: Onion Routing Protocol. // is in line with the definition given in BOLT #4: Onion Routing Protocol.
// The struct houses the channel along which this hop can be reached and // The struct houses the channel along which this hop can be reached and