mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-04 01:36:24 +01:00
channeldb: prepare RevocationLog for optional fields
This commit is a pure refactor. It restructures the RevocationLog serialize and deserialize methods so that optional TLV fields can be easily added.
This commit is contained in:
parent
38dc67e1ef
commit
ce8e7ecfa7
3 changed files with 90 additions and 74 deletions
|
@ -16,8 +16,18 @@ import (
|
|||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// OutputIndexEmpty is used when the output index doesn't exist.
|
||||
const OutputIndexEmpty = math.MaxUint16
|
||||
const (
|
||||
// OutputIndexEmpty is used when the output index doesn't exist.
|
||||
OutputIndexEmpty = math.MaxUint16
|
||||
|
||||
// A set of tlv type definitions used to serialize the body of
|
||||
// revocation logs to the database.
|
||||
//
|
||||
// NOTE: A migration should be added whenever this list changes.
|
||||
revLogOurOutputIndexType tlv.Type = 0
|
||||
revLogTheirOutputIndexType tlv.Type = 1
|
||||
revLogCommitTxHashType tlv.Type = 2
|
||||
)
|
||||
|
||||
var (
|
||||
// revocationLogBucketDeprecated is dedicated for storing the necessary
|
||||
|
@ -208,29 +218,6 @@ type RevocationLog struct {
|
|||
HTLCEntries []*HTLCEntry
|
||||
}
|
||||
|
||||
// toTlvStream converts an RevocationLog record into a tlv representation.
|
||||
func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) {
|
||||
const (
|
||||
// A set of tlv type definitions used to serialize the body of
|
||||
// revocation logs to the database. We define it here instead
|
||||
// of the head of the file to avoid naming conflicts.
|
||||
//
|
||||
// NOTE: A migration should be added whenever this list
|
||||
// changes.
|
||||
ourOutputIndexType tlv.Type = 0
|
||||
theirOutputIndexType tlv.Type = 1
|
||||
commitTxHashType tlv.Type = 2
|
||||
)
|
||||
|
||||
return tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex),
|
||||
tlv.MakePrimitiveRecord(
|
||||
theirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash),
|
||||
)
|
||||
}
|
||||
|
||||
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
|
||||
// ChannelCommitment to construct a revocation log entry and saves them to
|
||||
// disk. It also saves our output index and their output index, which are
|
||||
|
@ -304,8 +291,21 @@ func fetchRevocationLog(log kvdb.RBucket,
|
|||
// serializeRevocationLog serializes a RevocationLog record based on tlv
|
||||
// format.
|
||||
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
|
||||
// Add the tlv records for all non-optional fields.
|
||||
records := []tlv.Record{
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogOurOutputIndexType, &rl.OurOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogCommitTxHashType, &rl.CommitTxHash,
|
||||
),
|
||||
}
|
||||
|
||||
// Create the tlv stream.
|
||||
tlvStream, err := rl.toTlvStream()
|
||||
tlvStream, err := tlv.NewStream(records...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -351,13 +351,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
|
|||
var rl RevocationLog
|
||||
|
||||
// Create the tlv stream.
|
||||
tlvStream, err := rl.toTlvStream()
|
||||
if err != nil {
|
||||
return rl, err
|
||||
}
|
||||
tlvStream, err := tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogOurOutputIndexType, &rl.OurOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogCommitTxHashType, &rl.CommitTxHash,
|
||||
),
|
||||
)
|
||||
|
||||
// Read the tlv stream.
|
||||
if err := readTlvStream(r, tlvStream); err != nil {
|
||||
if _, err := readTlvStream(r, tlvStream); err != nil {
|
||||
return rl, err
|
||||
}
|
||||
|
||||
|
@ -382,7 +389,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
|
|||
}
|
||||
|
||||
// Read the HTLC entry.
|
||||
if err := readTlvStream(r, tlvStream); err != nil {
|
||||
if _, err := readTlvStream(r, tlvStream); err != nil {
|
||||
// We've reached the end when hitting an EOF.
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
break
|
||||
|
@ -427,7 +434,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
|
|||
|
||||
// readTlvStream is a helper function that decodes the tlv stream from the
|
||||
// reader.
|
||||
func readTlvStream(r io.Reader, s *tlv.Stream) error {
|
||||
func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
|
||||
var bodyLen uint64
|
||||
|
||||
// Read the stream's length.
|
||||
|
@ -436,16 +443,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error {
|
|||
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
|
||||
// invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(yy): add overflow check.
|
||||
lr := io.LimitReader(r, int64(bodyLen))
|
||||
return s.Decode(lr)
|
||||
|
||||
return s.DecodeWithParsedTypes(lr)
|
||||
}
|
||||
|
||||
// fetchLogBucket returns a read bucket by visiting both the new and the old
|
||||
|
|
|
@ -12,8 +12,18 @@ import (
|
|||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// OutputIndexEmpty is used when the output index doesn't exist.
|
||||
const OutputIndexEmpty = math.MaxUint16
|
||||
const (
|
||||
// OutputIndexEmpty is used when the output index doesn't exist.
|
||||
OutputIndexEmpty = math.MaxUint16
|
||||
|
||||
// A set of tlv type definitions used to serialize the body of
|
||||
// revocation logs to the database.
|
||||
//
|
||||
// NOTE: A migration should be added whenever this list changes.
|
||||
revLogOurOutputIndexType tlv.Type = 0
|
||||
revLogTheirOutputIndexType tlv.Type = 1
|
||||
revLogCommitTxHashType tlv.Type = 2
|
||||
)
|
||||
|
||||
var (
|
||||
// revocationLogBucketDeprecated is dedicated for storing the necessary
|
||||
|
@ -196,29 +206,6 @@ type RevocationLog struct {
|
|||
HTLCEntries []*HTLCEntry
|
||||
}
|
||||
|
||||
// toTlvStream converts an RevocationLog record into a tlv representation.
|
||||
func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) {
|
||||
const (
|
||||
// A set of tlv type definitions used to serialize the body of
|
||||
// revocation logs to the database. We define it here instead
|
||||
// of the head of the file to avoid naming conflicts.
|
||||
//
|
||||
// NOTE: A migration should be added whenever this list
|
||||
// changes.
|
||||
ourOutputIndexType tlv.Type = 0
|
||||
theirOutputIndexType tlv.Type = 1
|
||||
commitTxHashType tlv.Type = 2
|
||||
)
|
||||
|
||||
return tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex),
|
||||
tlv.MakePrimitiveRecord(
|
||||
theirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash),
|
||||
)
|
||||
}
|
||||
|
||||
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
|
||||
// ChannelCommitment to construct a revocation log entry and saves them to
|
||||
// disk. It also saves our output index and their output index, which are
|
||||
|
@ -292,8 +279,21 @@ func fetchRevocationLog(log kvdb.RBucket,
|
|||
// serializeRevocationLog serializes a RevocationLog record based on tlv
|
||||
// format.
|
||||
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
|
||||
// Add the tlv records for all non-optional fields.
|
||||
records := []tlv.Record{
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogOurOutputIndexType, &rl.OurOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogCommitTxHashType, &rl.CommitTxHash,
|
||||
),
|
||||
}
|
||||
|
||||
// Create the tlv stream.
|
||||
tlvStream, err := rl.toTlvStream()
|
||||
tlvStream, err := tlv.NewStream(records...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -339,13 +339,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
|
|||
var rl RevocationLog
|
||||
|
||||
// Create the tlv stream.
|
||||
tlvStream, err := rl.toTlvStream()
|
||||
if err != nil {
|
||||
return rl, err
|
||||
}
|
||||
tlvStream, err := tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogOurOutputIndexType, &rl.OurOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(
|
||||
revLogCommitTxHashType, &rl.CommitTxHash,
|
||||
),
|
||||
)
|
||||
|
||||
// Read the tlv stream.
|
||||
if err := readTlvStream(r, tlvStream); err != nil {
|
||||
if _, err := readTlvStream(r, tlvStream); err != nil {
|
||||
return rl, err
|
||||
}
|
||||
|
||||
|
@ -370,7 +377,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
|
|||
}
|
||||
|
||||
// Read the HTLC entry.
|
||||
if err := readTlvStream(r, tlvStream); err != nil {
|
||||
if _, err := readTlvStream(r, tlvStream); err != nil {
|
||||
// We've reached the end when hitting an EOF.
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
break
|
||||
|
@ -415,7 +422,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
|
|||
|
||||
// readTlvStream is a helper function that decodes the tlv stream from the
|
||||
// reader.
|
||||
func readTlvStream(r io.Reader, s *tlv.Stream) error {
|
||||
func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
|
||||
var bodyLen uint64
|
||||
|
||||
// Read the stream's length.
|
||||
|
@ -424,16 +431,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error {
|
|||
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
|
||||
// invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(yy): add overflow check.
|
||||
lr := io.LimitReader(r, int64(bodyLen))
|
||||
return s.Decode(lr)
|
||||
|
||||
return s.DecodeWithParsedTypes(lr)
|
||||
}
|
||||
|
||||
// fetchOldRevocationLog finds the revocation log from the deprecated
|
||||
|
|
|
@ -127,7 +127,7 @@ func TestReadTLVStream(t *testing.T) {
|
|||
|
||||
// Read the tlv stream.
|
||||
buf := bytes.NewBuffer(testValueBytes)
|
||||
err = readTlvStream(buf, ts)
|
||||
_, err = readTlvStream(buf, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the bytes are read as expected.
|
||||
|
@ -150,7 +150,7 @@ func TestReadTLVStreamErr(t *testing.T) {
|
|||
|
||||
// Read the tlv stream.
|
||||
buf := bytes.NewBuffer(b)
|
||||
err = readTlvStream(buf, ts)
|
||||
_, err = readTlvStream(buf, ts)
|
||||
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
|
||||
// Check the bytes are not read.
|
||||
|
|
Loading…
Add table
Reference in a new issue