lnd/watchtower/tlv_bench_test.go
Oliver Gugger 67f5957c5e
tlv+watchtower: move bench test
The benchmark tests import both the tlv and watchtower packages. To make
it possible to extract the tlv package into its own submodule, this test
is better located in the watchtower package.
2022-02-21 13:48:31 +01:00

166 lines
4.0 KiB
Go

package watchtower_test
import (
"bytes"
"io"
"io/ioutil"
"testing"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
"github.com/stretchr/testify/require"
)
// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for
// encoding/decoding.
type CreateSessionTLV struct {
BlobType blob.Type
MaxUpdates uint16
RewardBase uint32
RewardRate uint32
SweepFeeRate chainfee.SatPerKWeight
tlvStream *tlv.Stream
}
// EBlobType is an encoder for blob.Type.
func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(*blob.Type); ok {
return tlv.EUint16T(w, uint16(*t), buf)
}
return tlv.NewTypeForEncodingErr(val, "blob.Type")
}
// EBlobType is an decoder for blob.Type.
func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if typ, ok := val.(*blob.Type); ok {
var t uint16
err := tlv.DUint16(r, &t, buf, l)
if err != nil {
return err
}
*typ = blob.Type(t)
return nil
}
return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2)
}
// ESatPerKW is an encoder for lnwallet.SatPerKWeight.
func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*chainfee.SatPerKWeight); ok {
v64 := uint64(*v)
return tlv.EUint64(w, &v64, buf)
}
return tlv.NewTypeForEncodingErr(val, "chainfee.SatPerKWeight")
}
// DSatPerKW is an decoder for lnwallet.SatPerKWeight.
func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*chainfee.SatPerKWeight); ok {
var sat uint64
err := tlv.DUint64(r, &sat, buf, l)
if err != nil {
return err
}
*v = chainfee.SatPerKWeight(sat)
return nil
}
return tlv.NewTypeForDecodingErr(val, "chainfee.SatPerKWeight", l, 8)
}
// NewCreateSessionTLV initializes a new CreateSessionTLV message.
func NewCreateSessionTLV() *CreateSessionTLV {
m := &CreateSessionTLV{}
m.tlvStream = tlv.MustNewStream(
tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType),
tlv.MakePrimitiveRecord(1, &m.MaxUpdates),
tlv.MakePrimitiveRecord(2, &m.RewardBase),
tlv.MakePrimitiveRecord(3, &m.RewardRate),
tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW),
)
return m
}
// Encode writes the CreateSessionTLV to the passed io.Writer.
func (c *CreateSessionTLV) Encode(w io.Writer) error {
return c.tlvStream.Encode(w)
}
// Decode reads the CreateSessionTLV from the passed io.Reader.
func (c *CreateSessionTLV) Decode(r io.Reader) error {
return c.tlvStream.Decode(r)
}
// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV
// CreateSession.
func BenchmarkEncodeCreateSession(t *testing.B) {
m := &wtwire.CreateSession{}
t.ReportAllocs()
t.ResetTimer()
var err error
for i := 0; i < t.N; i++ {
err = m.Encode(ioutil.Discard, 0)
}
require.NoError(t, err)
}
// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession.
func BenchmarkEncodeCreateSessionTLV(t *testing.B) {
m := NewCreateSessionTLV()
t.ReportAllocs()
t.ResetTimer()
var err error
for i := 0; i < t.N; i++ {
err = m.Encode(ioutil.Discard)
}
require.NoError(t, err)
}
// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV
// CreateSession.
func BenchmarkDecodeCreateSession(t *testing.B) {
m := &wtwire.CreateSession{}
var b bytes.Buffer
err := m.Encode(&b, 0)
require.NoError(t, err)
r := bytes.NewReader(b.Bytes())
t.ReportAllocs()
t.ResetTimer()
for i := 0; i < t.N; i++ {
r.Seek(0, 0)
err = m.Decode(r, 0)
}
require.NoError(t, err)
}
// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession.
func BenchmarkDecodeCreateSessionTLV(t *testing.B) {
m := NewCreateSessionTLV()
var b bytes.Buffer
err := m.Encode(&b)
require.NoError(t, err)
r := bytes.NewReader(b.Bytes())
t.ReportAllocs()
t.ResetTimer()
for i := 0; i < t.N; i++ {
r.Seek(0, 0)
err = m.Decode(r)
}
require.NoError(t, err)
}