mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 18:10:34 +01:00
100 lines
2.3 KiB
Go
100 lines
2.3 KiB
Go
package record_test
|
|
|
|
import (
|
|
"bytes"
|
|
"testing"
|
|
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/record"
|
|
"github.com/lightningnetwork/lnd/tlv"
|
|
)
|
|
|
|
type recordEncDecTest struct {
|
|
name string
|
|
encRecord func() tlv.RecordProducer
|
|
decRecord func() tlv.RecordProducer
|
|
assert func(*testing.T, interface{})
|
|
}
|
|
|
|
var (
|
|
testTotal = lnwire.MilliSatoshi(45)
|
|
testAddr = [32]byte{0x01, 0x02}
|
|
testShare = [32]byte{0x03, 0x04}
|
|
testSetID = [32]byte{0x05, 0x06}
|
|
testChildIndex = uint16(17)
|
|
)
|
|
|
|
var recordEncDecTests = []recordEncDecTest{
|
|
{
|
|
name: "mpp",
|
|
encRecord: func() tlv.RecordProducer {
|
|
return record.NewMPP(testTotal, testAddr)
|
|
},
|
|
decRecord: func() tlv.RecordProducer {
|
|
return new(record.MPP)
|
|
},
|
|
assert: func(t *testing.T, r interface{}) {
|
|
mpp := r.(*record.MPP)
|
|
if mpp.TotalMsat() != testTotal {
|
|
t.Fatal("incorrect total msat")
|
|
}
|
|
if mpp.PaymentAddr() != testAddr {
|
|
t.Fatal("incorrect payment addr")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "amp",
|
|
encRecord: func() tlv.RecordProducer {
|
|
return record.NewAMP(
|
|
testShare, testSetID, testChildIndex,
|
|
)
|
|
},
|
|
decRecord: func() tlv.RecordProducer {
|
|
return new(record.AMP)
|
|
},
|
|
assert: func(t *testing.T, r interface{}) {
|
|
amp := r.(*record.AMP)
|
|
if amp.RootShare() != testShare {
|
|
t.Fatal("incorrect root share")
|
|
}
|
|
if amp.SetID() != testSetID {
|
|
t.Fatal("incorrect set id")
|
|
}
|
|
if amp.ChildIndex() != testChildIndex {
|
|
t.Fatal("incorrect child index")
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
|
|
// asserts that records can encode and decode themselves, and that the value of
|
|
// the original record matches the decoded record.
|
|
func TestRecordEncodeDecode(t *testing.T) {
|
|
for _, test := range recordEncDecTests {
|
|
test := test
|
|
t.Run(test.name, func(t *testing.T) {
|
|
r := test.encRecord()
|
|
r2 := test.decRecord()
|
|
encStream := tlv.MustNewStream(r.Record())
|
|
decStream := tlv.MustNewStream(r2.Record())
|
|
|
|
test.assert(t, r)
|
|
|
|
var b bytes.Buffer
|
|
err := encStream.Encode(&b)
|
|
if err != nil {
|
|
t.Fatalf("unable to encode record: %v", err)
|
|
}
|
|
|
|
err = decStream.Decode(bytes.NewReader(b.Bytes()))
|
|
if err != nil {
|
|
t.Fatalf("unable to decode record: %v", err)
|
|
}
|
|
|
|
test.assert(t, r2)
|
|
})
|
|
}
|
|
}
|