From 382539a6eb48f3ac6043b5de4e78d4e644ce50db Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 22 Oct 2024 12:49:34 +0200 Subject: [PATCH] channeldb/graphdb: move outpoint ser/deser funcs to graphdb We have the same helpers for writing and reading a wire.Outpoint type defined separately in a couple places. We will want to use these from the graph db package soon though so instead of defining them again there, this commit unifies things and creates a single exported set of helpers. The next commit will make use of these. --- channeldb/channel.go | 10 ++++++---- channeldb/codec.go | 30 ++---------------------------- channeldb/db.go | 22 +++++++++++++++------- channeldb/graph.go | 16 ++++++++-------- channeldb/reports.go | 7 ++++--- channeldb/reports_test.go | 3 ++- graph/db/codec.go | 33 ++++++++++++++++++++++++++++++++- 7 files changed, 69 insertions(+), 52 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index c21716a45..0f198b5b3 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -1330,7 +1331,7 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, // With the bucket for the node and chain fetched, we can now go down // another level, for this channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := chainBucket.NestedReadBucket(chanPointBuf.Bytes()) @@ -1377,7 +1378,7 @@ func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // With the bucket for the node and chain fetched, we can now go down // another level, for this channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes()) @@ -1422,7 +1423,8 @@ func (c *OpenChannel) fullSync(tx kvdb.RwTx) error { } var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { + err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) + if err != nil { return err } @@ -3822,7 +3824,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, } var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) + err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) if err != nil { return err } diff --git a/channeldb/codec.go b/channeldb/codec.go index 9917dbf86..8c39f4d73 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -18,32 +18,6 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -// writeOutpoint writes an outpoint to the passed writer using the minimal -// amount of bytes possible. -func writeOutpoint(w io.Writer, o *wire.OutPoint) error { - if _, err := w.Write(o.Hash[:]); err != nil { - return err - } - if err := binary.Write(w, byteOrder, o.Index); err != nil { - return err - } - - return nil -} - -// readOutpoint reads an outpoint from the passed reader that was previously -// written using the writeOutpoint struct. -func readOutpoint(r io.Reader, o *wire.OutPoint) error { - if _, err := io.ReadFull(r, o.Hash[:]); err != nil { - return err - } - if err := binary.Read(r, byteOrder, &o.Index); err != nil { - return err - } - - return nil -} - // UnknownElementType is an error returned when the codec is unable to encode or // decode a particular type. type UnknownElementType struct { @@ -99,7 +73,7 @@ func WriteElement(w io.Writer, element interface{}) error { } case wire.OutPoint: - return writeOutpoint(w, &e) + return graphdb.WriteOutpoint(w, &e) case lnwire.ShortChannelID: if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { @@ -289,7 +263,7 @@ func ReadElement(r io.Reader, element interface{}) error { } case *wire.OutPoint: - return readOutpoint(r, e) + return graphdb.ReadOutpoint(r, e) case *lnwire.ShortChannelID: var a uint64 diff --git a/channeldb/db.go b/channeldb/db.go index 92e0498ec..b90b09e7f 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -30,6 +30,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration33" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -646,7 +647,9 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( chanBucket := chainBucket.NestedReadBucket(chanPoint) var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(chanPoint), &outPoint, + ) if err != nil { return err } @@ -675,7 +678,8 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( *OpenChannel, error) { var targetChanPoint bytes.Buffer - if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { + err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint) + if err != nil { return nil, err } @@ -709,7 +713,9 @@ func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) ( ) err := chainBkt.ForEach(func(k, _ []byte) error { var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(k), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(k), &outPoint, + ) if err != nil { return err } @@ -1089,7 +1095,7 @@ func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( var b bytes.Buffer var err error - if err = writeOutpoint(&b, chanID); err != nil { + if err = graphdb.WriteOutpoint(&b, chanID); err != nil { return err } @@ -1131,7 +1137,9 @@ func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // We scan over all possible candidates for this channel ID. for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(op), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(op), &outPoint, + ) if err != nil { return err } @@ -1173,7 +1181,7 @@ func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { ) err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, chanPoint); err != nil { return err } @@ -1693,7 +1701,7 @@ func fetchHistoricalChanBucket(tx kvdb.RTx, // With the bucket for the node and chain fetched, we can now go down // another level, for the channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := historicalChanBucket.NestedReadBucket( diff --git a/channeldb/graph.go b/channeldb/graph.go index a52a13f3c..9e24367a6 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -1129,7 +1129,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, // Finally we add it to the channel index which maps channel points // (outpoints) to the shorter channel ID's. var b bytes.Buffer - if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, &edge.ChannelPoint); err != nil { return err } return chanIndex.Put(b.Bytes(), chanKey[:]) @@ -1336,7 +1336,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // if NOT if filter var opBytes bytes.Buffer - if err := writeOutpoint(&opBytes, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&opBytes, chanPoint); err != nil { return err } @@ -1808,7 +1808,7 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { // getChanID returns the assigned channel ID for a given channel point. func getChanID(tx kvdb.RTx, chanPoint *wire.OutPoint) (uint64, error) { var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, chanPoint); err != nil { return 0, err } @@ -2636,7 +2636,7 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, return err } var b bytes.Buffer - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { return err } if err := chanIndex.Delete(b.Bytes()); err != nil { @@ -3414,7 +3414,7 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( return ErrGraphNoEdgesFound } var b bytes.Buffer - if err := writeOutpoint(&b, op); err != nil { + if err := graphdb.WriteOutpoint(&b, op); err != nil { return err } chanID := chanIndex.Get(b.Bytes()) @@ -3660,7 +3660,7 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { chanPointReader := bytes.NewReader(chanPointBytes) var chanPoint wire.OutPoint - err := readOutpoint(chanPointReader, &chanPoint) + err := graphdb.ReadOutpoint(chanPointReader, &chanPoint) if err != nil { return err } @@ -4282,7 +4282,7 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, return err } - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { return err } if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { @@ -4366,7 +4366,7 @@ func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo, error) { } edgeInfo.ChannelPoint = wire.OutPoint{} - if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { + if err := graphdb.ReadOutpoint(r, &edgeInfo.ChannelPoint); err != nil { return models.ChannelEdgeInfo{}, err } if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { diff --git a/channeldb/reports.go b/channeldb/reports.go index c4e58d81e..4f46bd9e1 100644 --- a/channeldb/reports.go +++ b/channeldb/reports.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/tlv" ) @@ -164,7 +165,7 @@ func putReport(tx kvdb.RwTx, chainHash chainhash.Hash, // Finally write our outpoint to be used as the key for this record. var keyBuf bytes.Buffer - if err := writeOutpoint(&keyBuf, &report.OutPoint); err != nil { + if err := graphdb.WriteOutpoint(&keyBuf, &report.OutPoint); err != nil { return err } @@ -317,7 +318,7 @@ func fetchReportWriteBucket(tx kvdb.RwTx, chainHash chainhash.Hash, } var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } @@ -341,7 +342,7 @@ func fetchReportReadBucket(tx kvdb.RTx, chainHash chainhash.Hash, // With the bucket for the node and chain fetched, we can now go down // another level, for the channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } diff --git a/channeldb/reports_test.go b/channeldb/reports_test.go index 48a41914f..1148fdf03 100644 --- a/channeldb/reports_test.go +++ b/channeldb/reports_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -137,7 +138,7 @@ func TestFetchChannelWriteBucket(t *testing.T) { error) { var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &testChanPoint1) + err := graphdb.WriteOutpoint(&chanPointBuf, &testChanPoint1) require.NoError(t, err) return chainHash.CreateBucketIfNotExists(chanPointBuf.Bytes()) diff --git a/graph/db/codec.go b/graph/db/codec.go index c649f9fa3..029f9b93d 100644 --- a/graph/db/codec.go +++ b/graph/db/codec.go @@ -1,8 +1,39 @@ package graphdb -import "encoding/binary" +import ( + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/wire" +) var ( // byteOrder defines the preferred byte order, which is Big Endian. byteOrder = binary.BigEndian ) + +// WriteOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func WriteOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// ReadOutpoint reads an outpoint from the passed reader that was previously +// written using the WriteOutpoint struct. +func ReadOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +}