channeldb: convert HTLCEntry to use tlv.RecordT

This commit is contained in:
Olaoluwa Osuntokun 2024-03-30 17:28:35 -07:00 committed by Oliver Gugger
parent c1e641e9d9
commit 1b1e7a6168
No known key found for this signature in database
GPG key ID: 8E4256593F177720
5 changed files with 212 additions and 166 deletions

View file

@ -593,15 +593,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
for i, rHtlc := range r.HTLCEntries {
cHtlc := c.Htlcs[i]
require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
"Amt mismatch")
require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
"RefundTimeout mismatch")
require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
"OutputIndex mismatch")
require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
"Incoming mismatch")
require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash")
require.Equal(
t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt",
)
require.Equal(
t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout,
"RefundTimeout",
)
require.EqualValues(
t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex,
"OutputIndex",
)
require.Equal(
t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming",
)
}
}

View file

@ -54,6 +54,74 @@ var (
ErrOutputIndexTooBig = errors.New("output index is over uint16")
)
// SparsePayHash is a type alias for a 32 byte array, which when serialized is
// able to save some space by not including an empty payment hash on disk.
type SparsePayHash [32]byte
// NewSparsePayHash creates a new SparsePayHash from a 32 byte array.
func NewSparsePayHash(rHash [32]byte) SparsePayHash {
return SparsePayHash(rHash)
}
// Record returns a tlv record for the SparsePayHash.
func (s *SparsePayHash) Record() tlv.Record {
// We use a zero for the type here, as this'll be used along with the
// RecordT type.
return tlv.MakeDynamicRecord(
0, s, s.hashLen,
sparseHashEncoder, sparseHashDecoder,
)
}
// hashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (s *SparsePayHash) hashLen() uint64 {
if bytes.Equal(s[:], lntypes.ZeroHash[:]) {
return 0
}
return 32
}
// sparseHashEncoder is the customized encoder which skips encoding the empty
// hash.
func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the value is an empty hash, we will skip encoding it.
if bytes.Equal(v[:], lntypes.ZeroHash[:]) {
return nil
}
vArray := (*[32]byte)(v)
return tlv.EBytes32(w, vArray, buf)
}
// sparseHashDecoder is the customized decoder which skips decoding the empty
// hash.
func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the length is zero, we will skip encoding the empty hash.
if l == 0 {
return nil
}
vArray := (*[32]byte)(v)
return tlv.DBytes32(r, vArray, buf, 32)
}
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
// historical HTLCs, which is useful for constructing RevocationLog when a
// breach is detected.
@ -72,118 +140,62 @@ var (
// made into tlv records without further conversion.
type HTLCEntry struct {
// RHash is the payment hash of the HTLC.
RHash [32]byte
RHash tlv.RecordT[tlv.TlvType0, SparsePayHash]
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout uint32
RefundTimeout tlv.RecordT[tlv.TlvType1, uint32]
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
//
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
// gives us a max number of HTLCs of 65K.
OutputIndex uint16
OutputIndex tlv.RecordT[tlv.TlvType2, uint16]
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
//
// NOTE: this field is the memory representation of the field
// incomingUint.
Incoming bool
Incoming tlv.RecordT[tlv.TlvType3, bool]
// Amt is the amount of satoshis this HTLC escrows.
//
// NOTE: this field is the memory representation of the field amtUint.
Amt btcutil.Amount
// amtTlv is the uint64 format of Amt. This field is created so we can
// easily make it into a tlv record and save it to disk.
//
// NOTE: we keep this field for accounting purpose only. If the disk
// space becomes an issue, we could delete this field to save us extra
// 8 bytes.
amtTlv uint64
// incomingTlv is the uint8 format of Incoming. This field is created
// so we can easily make it into a tlv record and save it to disk.
incomingTlv uint8
}
// RHashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (h *HTLCEntry) RHashLen() uint64 {
if h.RHash == lntypes.ZeroHash {
return 0
}
return 32
}
// RHashEncoder is the customized encoder which skips encoding the empty hash.
func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
// If the value is an empty hash, we will skip encoding it.
if *v == lntypes.ZeroHash {
return nil
}
return tlv.EBytes32(w, v, buf)
}
// RHashDecoder is the customized decoder which skips decoding the empty hash.
func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
// If the length is zero, we will skip encoding the empty hash.
if l == 0 {
return nil
}
return tlv.DBytes32(r, v, buf, 32)
Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]]
}
// toTlvStream converts an HTLCEntry record into a tlv representation.
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
const (
// A set of tlv type definitions used to serialize htlc entries
// to the database. We define it here instead of the head of
// the file to avoid naming conflicts.
//
// NOTE: A migration should be added whenever this list
// changes.
rHashType tlv.Type = 0
refundTimeoutType tlv.Type = 1
outputIndexType tlv.Type = 2
incomingType tlv.Type = 3
amtType tlv.Type = 4
)
return tlv.NewStream(
tlv.MakeDynamicRecord(
rHashType, &h.RHash, h.RHashLen,
RHashEncoder, RHashDecoder,
),
tlv.MakePrimitiveRecord(
refundTimeoutType, &h.RefundTimeout,
),
tlv.MakePrimitiveRecord(
outputIndexType, &h.OutputIndex,
),
tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv),
// We will save 3 bytes if the amount is less or equal to
// 4,294,967,295 msat, or roughly 0.043 bitcoin.
tlv.MakeBigSizeRecord(amtType, &h.amtTlv),
h.RHash.Record(),
h.RefundTimeout.Record(),
h.OutputIndex.Record(),
h.Incoming.Record(),
h.Amt.Record(),
)
}
// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC.
func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry {
return &HTLCEntry{
RHash: tlv.NewRecordT[tlv.TlvType0](
NewSparsePayHash(htlc.RHash),
),
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1](
htlc.RefundTimeout,
),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint16(htlc.OutputIndex),
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(htlc.Amt.ToSatoshis()),
),
}
}
// RevocationLog stores the info needed to construct a breach retribution. Its
// fields can be viewed as a subset of a ChannelCommitment's. In the database,
// all historical versions of the RevocationLog are saved using the
@ -265,13 +277,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
return ErrOutputIndexTooBig
}
entry := &HTLCEntry{
RHash: htlc.RHash,
RefundTimeout: htlc.RefundTimeout,
Incoming: htlc.Incoming,
OutputIndex: uint16(htlc.OutputIndex),
Amt: htlc.Amt.ToSatoshis(),
}
entry := NewHTLCEntryFromHTLC(htlc)
rl.HTLCEntries = append(rl.HTLCEntries, entry)
}
@ -351,14 +357,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// format.
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
for _, htlc := range htlcs {
// Patch the incomingTlv field.
if htlc.Incoming {
htlc.incomingTlv = 1
}
// Patch the amtTlv field.
htlc.amtTlv = uint64(htlc.Amt)
// Create the tlv stream.
tlvStream, err := htlc.toTlvStream()
if err != nil {
@ -447,14 +445,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
return nil, err
}
// Patch the Incoming field.
if htlc.incomingTlv == 1 {
htlc.Incoming = true
}
// Patch the Amt field.
htlc.Amt = btcutil.Amount(htlc.amtTlv)
// Append the entry.
htlcs = append(htlcs, &htlc)
}
@ -469,6 +459,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
if err := s.Encode(&b); err != nil {
return err
}
// Write the stream's length as a varint.
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
if err != nil {

View file

@ -34,12 +34,16 @@ var (
}
testHTLCEntry = HTLCEntry{
RefundTimeout: 740_000,
OutputIndex: 10,
Incoming: true,
Amt: 1000_000,
amtTlv: 1000_000,
incomingTlv: 1,
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32](
740_000,
),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](
10,
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(btcutil.Amount(1_000_000)),
),
}
testHTLCEntryBytes = []byte{
// Body length 23.
@ -56,6 +60,40 @@ var (
0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40,
}
testHTLCEntryHash = HTLCEntry{
RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash(
[32]byte{0x33, 0x44, 0x55},
)),
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32](
740_000,
),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](
10,
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(btcutil.Amount(1_000_000)),
),
}
testHTLCEntryHashBytes = []byte{
// Body length 54.
0x36,
// Rhash tlv.
0x0, 0x20,
0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// RefundTimeout tlv.
0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0,
// OutputIndex tlv.
0x2, 0x2, 0x0, 0xa,
// Incoming tlv.
0x3, 0x1, 0x1,
// Amt tlv.
0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40,
}
localBalance = lnwire.MilliSatoshi(9000)
remoteBalance = lnwire.MilliSatoshi(3000)
@ -68,11 +106,11 @@ var (
CommitTx: channels.TestFundingTx,
CommitSig: bytes.Repeat([]byte{1}, 71),
Htlcs: []HTLC{{
RefundTimeout: testHTLCEntry.RefundTimeout,
OutputIndex: int32(testHTLCEntry.OutputIndex),
Incoming: testHTLCEntry.Incoming,
RefundTimeout: testHTLCEntry.RefundTimeout.Val,
OutputIndex: int32(testHTLCEntry.OutputIndex.Val),
Incoming: testHTLCEntry.Incoming.Val,
Amt: lnwire.NewMSatFromSatoshis(
testHTLCEntry.Amt,
testHTLCEntry.Amt.Val.Int(),
),
}},
}
@ -193,11 +231,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) {
// Copy the testHTLCEntry.
entry := testHTLCEntry
// Set the internal fields to empty values so we can test the bytes are
// padded.
entry.incomingTlv = 0
entry.amtTlv = 0
// Write the tlv stream.
buf := bytes.NewBuffer([]byte{})
err := serializeHTLCEntries(buf, []*HTLCEntry{&entry})
@ -207,6 +240,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) {
require.Equal(t, testHTLCEntryBytes, buf.Bytes())
}
func TestSerializeHTLCEntriesWithRHash(t *testing.T) {
t.Parallel()
// Copy the testHTLCEntry.
entry := testHTLCEntryHash
// Write the tlv stream.
buf := bytes.NewBuffer([]byte{})
err := serializeHTLCEntries(buf, []*HTLCEntry{&entry})
require.NoError(t, err)
// Check the bytes are read as expected.
require.Equal(t, testHTLCEntryHashBytes, buf.Bytes())
}
func TestSerializeHTLCEntries(t *testing.T) {
t.Parallel()
@ -215,7 +263,7 @@ func TestSerializeHTLCEntries(t *testing.T) {
// Create a fake rHash.
rHashBytes := bytes.Repeat([]byte{10}, 32)
copy(entry.RHash[:], rHashBytes)
copy(entry.RHash.Val[:], rHashBytes)
// Construct the serialized bytes.
//
@ -269,7 +317,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) {
t, &test.revLog, test.revLogBytes,
)
testDerializeRevocationLog(
testDeserializeRevocationLog(
t, &test.revLog, test.revLogBytes,
)
})
@ -293,7 +341,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog,
require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex])
}
func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog,
func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog,
revLogBytes []byte) {
// Construct the full bytes.
@ -309,7 +357,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog,
require.Equal(t, *revLog, rl)
}
func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) {
func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) {
t.Parallel()
// Read the tlv stream.
@ -322,7 +370,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) {
require.Equal(t, &testHTLCEntry, htlcs[0])
}
func TestDerializeHTLCEntries(t *testing.T) {
func TestDeserializeHTLCEntries(t *testing.T) {
t.Parallel()
// Copy the testHTLCEntry.
@ -330,7 +378,7 @@ func TestDerializeHTLCEntries(t *testing.T) {
// Create a fake rHash.
rHashBytes := bytes.Repeat([]byte{10}, 32)
copy(entry.RHash[:], rHashBytes)
copy(entry.RHash.Val[:], rHashBytes)
// Construct the serialized bytes.
//
@ -398,11 +446,11 @@ func TestDeleteLogBucket(t *testing.T) {
err = kvdb.Update(backend, func(tx kvdb.RwTx) error {
// Create the buckets.
chanBucket, _, err := createTestRevocatoinLogBuckets(tx)
chanBucket, _, err := createTestRevocationLogBuckets(tx)
require.NoError(t, err)
// Create the buckets again should give us an error.
_, _, err = createTestRevocatoinLogBuckets(tx)
_, _, err = createTestRevocationLogBuckets(tx)
require.ErrorIs(t, err, kvdb.ErrBucketExists)
// Delete both buckets.
@ -410,7 +458,7 @@ func TestDeleteLogBucket(t *testing.T) {
require.NoError(t, err)
// Create the buckets again should give us NO error.
_, _, err = createTestRevocatoinLogBuckets(tx)
_, _, err = createTestRevocationLogBuckets(tx)
return err
}, func() {})
require.NoError(t, err)
@ -516,7 +564,7 @@ func TestPutRevocationLog(t *testing.T) {
// Construct the testing db transaction.
dbTx := func(tx kvdb.RwTx) (RevocationLog, error) {
// Create the buckets.
_, bucket, err := createTestRevocatoinLogBuckets(tx)
_, bucket, err := createTestRevocationLogBuckets(tx)
require.NoError(t, err)
// Save the log.
@ -686,7 +734,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) {
}
}
func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket,
func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket,
kvdb.RwBucket, error) {
chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket)

View file

@ -2202,8 +2202,8 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
// then from the PoV of the remote commitment state, they're the
// receiver of this HTLC.
scriptInfo, err := genHtlcScript(
chanState.ChanType, htlc.Incoming, lntypes.Remote,
htlc.RefundTimeout, htlc.RHash, keyRing,
chanState.ChanType, htlc.Incoming.Val, lntypes.Remote,
htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing,
)
if err != nil {
return emptyRetribution, err
@ -2216,7 +2216,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
WitnessScript: scriptInfo.WitnessScriptToSign(),
Output: &wire.TxOut{
PkScript: scriptInfo.PkScript(),
Value: int64(htlc.Amt),
Value: int64(htlc.Amt.Val.Int()),
},
HashType: sweepSigHash(chanState.ChanType),
}
@ -2249,10 +2249,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
SignDesc: signDesc,
OutPoint: wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
Index: uint32(htlc.OutputIndex.Val),
},
SecondLevelWitnessScript: secondLevelWitnessScript,
IsIncoming: htlc.Incoming,
IsIncoming: htlc.Incoming.Val,
SecondLevelTapTweak: secondLevelTapTweak,
}, nil
}
@ -2414,13 +2414,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment,
continue
}
entry := &channeldb.HTLCEntry{
RHash: htlc.RHash,
RefundTimeout: htlc.RefundTimeout,
OutputIndex: uint16(htlc.OutputIndex),
Incoming: htlc.Incoming,
Amt: htlc.Amt.ToSatoshis(),
}
entry := channeldb.NewHTLCEntryFromHTLC(htlc)
hr, err := createHtlcRetribution(
chanState, keyRing, commitHash,
commitmentSecret, leaseExpiry, entry,

View file

@ -9957,9 +9957,11 @@ func TestCreateHtlcRetribution(t *testing.T) {
aliceChannel.channelState,
)
htlc := &channeldb.HTLCEntry{
Amt: testAmt,
Incoming: true,
OutputIndex: 1,
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(testAmt),
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1),
}
// Create the htlc retribution.
@ -9973,8 +9975,8 @@ func TestCreateHtlcRetribution(t *testing.T) {
// Check the fields have expected values.
require.EqualValues(t, testAmt, hr.SignDesc.Output.Value)
require.Equal(t, commitHash, hr.OutPoint.Hash)
require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index)
require.Equal(t, htlc.Incoming, hr.IsIncoming)
require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index)
require.Equal(t, htlc.Incoming.Val, hr.IsIncoming)
}
// TestCreateBreachRetribution checks that `createBreachRetribution` behaves as
@ -10014,9 +10016,13 @@ func TestCreateBreachRetribution(t *testing.T) {
aliceChannel.channelState,
)
htlc := &channeldb.HTLCEntry{
Amt: btcutil.Amount(testAmt),
Incoming: true,
OutputIndex: uint16(htlcIndex),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(btcutil.Amount(testAmt)),
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint16(htlcIndex),
),
}
// Create a dummy revocation log.
@ -10143,11 +10149,12 @@ func TestCreateBreachRetribution(t *testing.T) {
require.Equal(t, remote, br.RemoteOutpoint)
for _, hr := range br.HtlcRetributions {
require.EqualValues(t, testAmt,
hr.SignDesc.Output.Value)
require.EqualValues(
t, testAmt, hr.SignDesc.Output.Value,
)
require.Equal(t, commitHash, hr.OutPoint.Hash)
require.EqualValues(t, htlcIndex, hr.OutPoint.Index)
require.Equal(t, htlc.Incoming, hr.IsIncoming)
require.Equal(t, htlc.Incoming.Val, hr.IsIncoming)
}
}