channeldb: prepare RevocationLog for optional fields

This commit is a pure refactor. It restructures the RevocationLog
serialize and deserialize methods so that optional TLV fields can be
easily added.
This commit is contained in:
Elle Mouton 2023-02-02 08:52:21 +02:00
parent 38dc67e1ef
commit ce8e7ecfa7
No known key found for this signature in database
GPG key ID: D7D916376026F177
3 changed files with 90 additions and 74 deletions

View file

@ -16,8 +16,18 @@ import (
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
// OutputIndexEmpty is used when the output index doesn't exist. const (
const OutputIndexEmpty = math.MaxUint16 // OutputIndexEmpty is used when the output index doesn't exist.
OutputIndexEmpty = math.MaxUint16
// A set of tlv type definitions used to serialize the body of
// revocation logs to the database.
//
// NOTE: A migration should be added whenever this list changes.
revLogOurOutputIndexType tlv.Type = 0
revLogTheirOutputIndexType tlv.Type = 1
revLogCommitTxHashType tlv.Type = 2
)
var ( var (
// revocationLogBucketDeprecated is dedicated for storing the necessary // revocationLogBucketDeprecated is dedicated for storing the necessary
@ -208,29 +218,6 @@ type RevocationLog struct {
HTLCEntries []*HTLCEntry HTLCEntries []*HTLCEntry
} }
// toTlvStream converts an RevocationLog record into a tlv representation.
func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) {
const (
// A set of tlv type definitions used to serialize the body of
// revocation logs to the database. We define it here instead
// of the head of the file to avoid naming conflicts.
//
// NOTE: A migration should be added whenever this list
// changes.
ourOutputIndexType tlv.Type = 0
theirOutputIndexType tlv.Type = 1
commitTxHashType tlv.Type = 2
)
return tlv.NewStream(
tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex),
tlv.MakePrimitiveRecord(
theirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash),
)
}
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
// ChannelCommitment to construct a revocation log entry and saves them to // ChannelCommitment to construct a revocation log entry and saves them to
// disk. It also saves our output index and their output index, which are // disk. It also saves our output index and their output index, which are
@ -304,8 +291,21 @@ func fetchRevocationLog(log kvdb.RBucket,
// serializeRevocationLog serializes a RevocationLog record based on tlv // serializeRevocationLog serializes a RevocationLog record based on tlv
// format. // format.
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// Add the tlv records for all non-optional fields.
records := []tlv.Record{
tlv.MakePrimitiveRecord(
revLogOurOutputIndexType, &rl.OurOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
}
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := rl.toTlvStream() tlvStream, err := tlv.NewStream(records...)
if err != nil { if err != nil {
return err return err
} }
@ -351,13 +351,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
var rl RevocationLog var rl RevocationLog
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := rl.toTlvStream() tlvStream, err := tlv.NewStream(
if err != nil { tlv.MakePrimitiveRecord(
return rl, err revLogOurOutputIndexType, &rl.OurOutputIndex,
} ),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
)
// Read the tlv stream. // Read the tlv stream.
if err := readTlvStream(r, tlvStream); err != nil { if _, err := readTlvStream(r, tlvStream); err != nil {
return rl, err return rl, err
} }
@ -382,7 +389,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
} }
// Read the HTLC entry. // Read the HTLC entry.
if err := readTlvStream(r, tlvStream); err != nil { if _, err := readTlvStream(r, tlvStream); err != nil {
// We've reached the end when hitting an EOF. // We've reached the end when hitting an EOF.
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
break break
@ -427,7 +434,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
// readTlvStream is a helper function that decodes the tlv stream from the // readTlvStream is a helper function that decodes the tlv stream from the
// reader. // reader.
func readTlvStream(r io.Reader, s *tlv.Stream) error { func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
var bodyLen uint64 var bodyLen uint64
// Read the stream's length. // Read the stream's length.
@ -436,16 +443,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
// invalid record. // invalid record.
case err == io.EOF: case err == io.EOF:
return io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
// Other unexpected errors. // Other unexpected errors.
case err != nil: case err != nil:
return err return nil, err
} }
// TODO(yy): add overflow check. // TODO(yy): add overflow check.
lr := io.LimitReader(r, int64(bodyLen)) lr := io.LimitReader(r, int64(bodyLen))
return s.Decode(lr)
return s.DecodeWithParsedTypes(lr)
} }
// fetchLogBucket returns a read bucket by visiting both the new and the old // fetchLogBucket returns a read bucket by visiting both the new and the old

View file

@ -12,8 +12,18 @@ import (
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
// OutputIndexEmpty is used when the output index doesn't exist. const (
const OutputIndexEmpty = math.MaxUint16 // OutputIndexEmpty is used when the output index doesn't exist.
OutputIndexEmpty = math.MaxUint16
// A set of tlv type definitions used to serialize the body of
// revocation logs to the database.
//
// NOTE: A migration should be added whenever this list changes.
revLogOurOutputIndexType tlv.Type = 0
revLogTheirOutputIndexType tlv.Type = 1
revLogCommitTxHashType tlv.Type = 2
)
var ( var (
// revocationLogBucketDeprecated is dedicated for storing the necessary // revocationLogBucketDeprecated is dedicated for storing the necessary
@ -196,29 +206,6 @@ type RevocationLog struct {
HTLCEntries []*HTLCEntry HTLCEntries []*HTLCEntry
} }
// toTlvStream converts an RevocationLog record into a tlv representation.
func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) {
const (
// A set of tlv type definitions used to serialize the body of
// revocation logs to the database. We define it here instead
// of the head of the file to avoid naming conflicts.
//
// NOTE: A migration should be added whenever this list
// changes.
ourOutputIndexType tlv.Type = 0
theirOutputIndexType tlv.Type = 1
commitTxHashType tlv.Type = 2
)
return tlv.NewStream(
tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex),
tlv.MakePrimitiveRecord(
theirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash),
)
}
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
// ChannelCommitment to construct a revocation log entry and saves them to // ChannelCommitment to construct a revocation log entry and saves them to
// disk. It also saves our output index and their output index, which are // disk. It also saves our output index and their output index, which are
@ -292,8 +279,21 @@ func fetchRevocationLog(log kvdb.RBucket,
// serializeRevocationLog serializes a RevocationLog record based on tlv // serializeRevocationLog serializes a RevocationLog record based on tlv
// format. // format.
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// Add the tlv records for all non-optional fields.
records := []tlv.Record{
tlv.MakePrimitiveRecord(
revLogOurOutputIndexType, &rl.OurOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
}
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := rl.toTlvStream() tlvStream, err := tlv.NewStream(records...)
if err != nil { if err != nil {
return err return err
} }
@ -339,13 +339,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
var rl RevocationLog var rl RevocationLog
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := rl.toTlvStream() tlvStream, err := tlv.NewStream(
if err != nil { tlv.MakePrimitiveRecord(
return rl, err revLogOurOutputIndexType, &rl.OurOutputIndex,
} ),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
)
// Read the tlv stream. // Read the tlv stream.
if err := readTlvStream(r, tlvStream); err != nil { if _, err := readTlvStream(r, tlvStream); err != nil {
return rl, err return rl, err
} }
@ -370,7 +377,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
} }
// Read the HTLC entry. // Read the HTLC entry.
if err := readTlvStream(r, tlvStream); err != nil { if _, err := readTlvStream(r, tlvStream); err != nil {
// We've reached the end when hitting an EOF. // We've reached the end when hitting an EOF.
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
break break
@ -415,7 +422,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
// readTlvStream is a helper function that decodes the tlv stream from the // readTlvStream is a helper function that decodes the tlv stream from the
// reader. // reader.
func readTlvStream(r io.Reader, s *tlv.Stream) error { func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
var bodyLen uint64 var bodyLen uint64
// Read the stream's length. // Read the stream's length.
@ -424,16 +431,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
// invalid record. // invalid record.
case err == io.EOF: case err == io.EOF:
return io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
// Other unexpected errors. // Other unexpected errors.
case err != nil: case err != nil:
return err return nil, err
} }
// TODO(yy): add overflow check. // TODO(yy): add overflow check.
lr := io.LimitReader(r, int64(bodyLen)) lr := io.LimitReader(r, int64(bodyLen))
return s.Decode(lr)
return s.DecodeWithParsedTypes(lr)
} }
// fetchOldRevocationLog finds the revocation log from the deprecated // fetchOldRevocationLog finds the revocation log from the deprecated

View file

@ -127,7 +127,7 @@ func TestReadTLVStream(t *testing.T) {
// Read the tlv stream. // Read the tlv stream.
buf := bytes.NewBuffer(testValueBytes) buf := bytes.NewBuffer(testValueBytes)
err = readTlvStream(buf, ts) _, err = readTlvStream(buf, ts)
require.NoError(t, err) require.NoError(t, err)
// Check the bytes are read as expected. // Check the bytes are read as expected.
@ -150,7 +150,7 @@ func TestReadTLVStreamErr(t *testing.T) {
// Read the tlv stream. // Read the tlv stream.
buf := bytes.NewBuffer(b) buf := bytes.NewBuffer(b)
err = readTlvStream(buf, ts) _, err = readTlvStream(buf, ts)
require.ErrorIs(t, err, io.ErrUnexpectedEOF) require.ErrorIs(t, err, io.ErrUnexpectedEOF)
// Check the bytes are not read. // Check the bytes are not read.