mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
records/mpp: add MPP struct to create corresponding tlv.Record
Used to encode/decode MPP tlv records
This commit is contained in:
parent
b3b51923dc
commit
b1b7ff8006
98
record/mpp.go
Normal file
98
record/mpp.go
Normal file
@ -0,0 +1,98 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// MPPOnionType is the type used in the onion to reference the MPP fields:
|
||||
// total_amt and payment_addr.
|
||||
const MPPOnionType tlv.Type = 8
|
||||
|
||||
// MPP is a record that encodes the fields necessary for multi-path payments.
|
||||
type MPP struct {
|
||||
// paymentAddr is a random, receiver-generated value used to avoid
|
||||
// collisions with concurrent payers.
|
||||
paymentAddr [32]byte
|
||||
|
||||
// totalMsat is the total value of the payment, potentially spread
|
||||
// across more than one HTLC.
|
||||
totalMsat lnwire.MilliSatoshi
|
||||
}
|
||||
|
||||
// NewMPP generates a new MPP record with the given total and payment address.
|
||||
func NewMPP(total lnwire.MilliSatoshi, addr [32]byte) *MPP {
|
||||
return &MPP{
|
||||
paymentAddr: addr,
|
||||
totalMsat: total,
|
||||
}
|
||||
}
|
||||
|
||||
// PaymentAddr returns the payment address contained in the MPP record.
|
||||
func (r *MPP) PaymentAddr() [32]byte {
|
||||
return r.paymentAddr
|
||||
}
|
||||
|
||||
// TotalMsat returns the total value of an MPP payment in msats.
|
||||
func (r *MPP) TotalMsat() lnwire.MilliSatoshi {
|
||||
return r.totalMsat
|
||||
}
|
||||
|
||||
// MPPEncoder writes the MPP record to the provided io.Writer.
|
||||
func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if v, ok := val.(*MPP); ok {
|
||||
err := tlv.EBytes32(w, &v.paymentAddr, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tlv.ETUint64T(w, uint64(v.totalMsat), buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "MPP")
|
||||
}
|
||||
|
||||
const (
|
||||
// minMPPLength is the minimum length of a serialized MPP TLV record,
|
||||
// which occurs when the truncated encoding of total_amt_msat takes 0
|
||||
// bytes, leaving only the payment_addr.
|
||||
minMPPLength = 32
|
||||
|
||||
// maxMPPLength is the maximum length of a serialized MPP TLV record,
|
||||
// which occurs when the truncated encoding of total_amt_msat takes 8
|
||||
// bytes.
|
||||
maxMPPLength = 40
|
||||
)
|
||||
|
||||
// MPPDecoder reads the MPP record to the provided io.Reader.
|
||||
func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength {
|
||||
if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var total uint64
|
||||
if err := tlv.DTUint64(r, &total, buf, l-32); err != nil {
|
||||
return err
|
||||
}
|
||||
v.totalMsat = lnwire.MilliSatoshi(total)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength)
|
||||
}
|
||||
|
||||
// Record returns a tlv.Record that can be used to encode or decode this record.
|
||||
func (r *MPP) Record() tlv.Record {
|
||||
// Fixed-size, 32 byte payment address followed by truncated 64-bit
|
||||
// total msat.
|
||||
size := func() uint64 {
|
||||
return 32 + tlv.SizeTUint64(uint64(r.totalMsat))
|
||||
}
|
||||
|
||||
return tlv.MakeDynamicRecord(
|
||||
MPPOnionType, r, size, MPPEncoder, MPPDecoder,
|
||||
)
|
||||
}
|
73
record/record_test.go
Normal file
73
record/record_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package record_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type recordEncDecTest struct {
|
||||
name string
|
||||
encRecord func() tlv.RecordProducer
|
||||
decRecord func() tlv.RecordProducer
|
||||
assert func(*testing.T, interface{})
|
||||
}
|
||||
|
||||
var (
|
||||
testTotal = lnwire.MilliSatoshi(45)
|
||||
testAddr = [32]byte{0x01, 0x02}
|
||||
)
|
||||
|
||||
var recordEncDecTests = []recordEncDecTest{
|
||||
{
|
||||
name: "mpp",
|
||||
encRecord: func() tlv.RecordProducer {
|
||||
return record.NewMPP(testTotal, testAddr)
|
||||
},
|
||||
decRecord: func() tlv.RecordProducer {
|
||||
return new(record.MPP)
|
||||
},
|
||||
assert: func(t *testing.T, r interface{}) {
|
||||
mpp := r.(*record.MPP)
|
||||
if mpp.TotalMsat() != testTotal {
|
||||
t.Fatal("incorrect total msat")
|
||||
}
|
||||
if mpp.PaymentAddr() != testAddr {
|
||||
t.Fatal("incorrect payment addr")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
|
||||
// asserts that records can encode and decode themselves, and that the value of
|
||||
// the original record matches the decoded record.
|
||||
func TestRecordEncodeDecode(t *testing.T) {
|
||||
for _, test := range recordEncDecTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := test.encRecord()
|
||||
r2 := test.decRecord()
|
||||
encStream := tlv.MustNewStream(r.Record())
|
||||
decStream := tlv.MustNewStream(r2.Record())
|
||||
|
||||
test.assert(t, r)
|
||||
|
||||
var b bytes.Buffer
|
||||
err := encStream.Encode(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode record: %v", err)
|
||||
}
|
||||
|
||||
err = decStream.Decode(bytes.NewReader(b.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode record: %v", err)
|
||||
}
|
||||
|
||||
test.assert(t, r2)
|
||||
})
|
||||
}
|
||||
}
|
@ -43,6 +43,14 @@ func SizeVarBytes(e *[]byte) SizeFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// RecorderProducer is an interface for objects that can produce a Record object
|
||||
// capable of encoding and/or decoding the RecordProducer as a Record.
|
||||
type RecordProducer interface {
|
||||
// Record returns a Record that can be used to encode or decode the
|
||||
// backing object.
|
||||
Record() Record
|
||||
}
|
||||
|
||||
// Record holds the required information to encode or decode a TLV record.
|
||||
type Record struct {
|
||||
value interface{}
|
||||
|
Loading…
Reference in New Issue
Block a user