Mirror of https://github.com/lightningnetwork/lnd.git (synced 2025-01-19 14:45:23 +01:00)

Merge pull request #7380 from ellemouton/wtclientDiskQueue

watchtower: replace in-mem task queue with a disk over-flow queue

Commit bdb41e5867

@@ -12,6 +12,8 @@
   wtdb.BackupIDs](https://github.com/lightningnetwork/lnd/pull/7623) instead of
   the entire retribution struct. This reduces the amount of data that needs to
   be held in memory.
+* [Replace in-mem task pipeline with a disk-overflow
+  queue](https://github.com/lightningnetwork/lnd/pull/7380)

 ## Misc

@@ -23,6 +23,10 @@ type WtClient struct {
     // before sending the DeleteSession message to the tower server.
     SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."`
+
+    // MaxTasksInMemQueue is the maximum number of back-up tasks that should
+    // be queued in memory before overflowing to disk.
+    MaxTasksInMemQueue uint64 `long:"max-tasks-in-mem-queue" description:"The maximum number of updates that should be queued in memory before overflowing to disk."`

     // MaxUpdates is the maximum number of updates to be backed up in a
     // single tower session.
     MaxUpdates uint16 `long:"max-updates" description:"The maximum number of updates to be backed up in a single session."`

@@ -1007,6 +1007,10 @@ litecoin.node=ltcd
 ; The maximum number of updates to include in a tower session.
 ; wtclient.max-updates=1024

+; The maximum number of back-up tasks that should be queued in memory before
+; overflowing to disk.
+; wtclient.max-tasks-in-mem-queue=2000
+
 [healthcheck]

 ; The number of times we should attempt to query our chain backend before

server.go (51 changed lines)
@@ -1516,6 +1516,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
         sessionCloseRange = cfg.WtClient.SessionCloseRange
     }

+    maxTasksInMemQueue := uint64(wtclient.DefaultMaxTasksInMemQueue)
+    if cfg.WtClient.MaxTasksInMemQueue != 0 {
+        maxTasksInMemQueue = cfg.WtClient.MaxTasksInMemQueue
+    }
+
     if err := policy.Validate(); err != nil {
         return nil, err
     }
@@ -1568,17 +1573,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
             return s.channelNotifier.
                 SubscribeChannelEvents()
         },
-        Signer:         cc.Wallet.Cfg.Signer,
-        NewAddress:     newSweepPkScriptGen(cc.Wallet),
-        SecretKeyRing:  s.cc.KeyRing,
-        Dial:           cfg.net.Dial,
-        AuthDial:       authDial,
-        DB:             dbs.TowerClientDB,
-        Policy:         policy,
-        ChainHash:      *s.cfg.ActiveNetParams.GenesisHash,
-        MinBackoff:     10 * time.Second,
-        MaxBackoff:     5 * time.Minute,
-        ForceQuitDelay: wtclient.DefaultForceQuitDelay,
+        Signer:             cc.Wallet.Cfg.Signer,
+        NewAddress:         newSweepPkScriptGen(cc.Wallet),
+        SecretKeyRing:      s.cc.KeyRing,
+        Dial:               cfg.net.Dial,
+        AuthDial:           authDial,
+        DB:                 dbs.TowerClientDB,
+        Policy:             policy,
+        ChainHash:          *s.cfg.ActiveNetParams.GenesisHash,
+        MinBackoff:         10 * time.Second,
+        MaxBackoff:         5 * time.Minute,
+        ForceQuitDelay:     wtclient.DefaultForceQuitDelay,
+        MaxTasksInMemQueue: maxTasksInMemQueue,
     })
     if err != nil {
         return nil, err
@@ -1601,17 +1607,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
             return s.channelNotifier.
                 SubscribeChannelEvents()
         },
-        Signer:         cc.Wallet.Cfg.Signer,
-        NewAddress:     newSweepPkScriptGen(cc.Wallet),
-        SecretKeyRing:  s.cc.KeyRing,
-        Dial:           cfg.net.Dial,
-        AuthDial:       authDial,
-        DB:             dbs.TowerClientDB,
-        Policy:         anchorPolicy,
-        ChainHash:      *s.cfg.ActiveNetParams.GenesisHash,
-        MinBackoff:     10 * time.Second,
-        MaxBackoff:     5 * time.Minute,
-        ForceQuitDelay: wtclient.DefaultForceQuitDelay,
+        Signer:             cc.Wallet.Cfg.Signer,
+        NewAddress:         newSweepPkScriptGen(cc.Wallet),
+        SecretKeyRing:      s.cc.KeyRing,
+        Dial:               cfg.net.Dial,
+        AuthDial:           authDial,
+        DB:                 dbs.TowerClientDB,
+        Policy:             anchorPolicy,
+        ChainHash:          *s.cfg.ActiveNetParams.GenesisHash,
+        MinBackoff:         10 * time.Second,
+        MaxBackoff:         5 * time.Minute,
+        ForceQuitDelay:     wtclient.DefaultForceQuitDelay,
+        MaxTasksInMemQueue: maxTasksInMemQueue,
     })
     if err != nil {
         return nil, err

@@ -51,6 +51,10 @@ const (
     // random number of blocks to delay closing a session after its last
     // channel has been closed.
     DefaultSessionCloseRange = 288
+
+    // DefaultMaxTasksInMemQueue is the maximum number of items to be held
+    // in the in-memory queue.
+    DefaultMaxTasksInMemQueue = 2000
 )

 // genSessionFilter constructs a filter that can be used to select sessions only
@@ -240,6 +244,10 @@ type Config struct {
     // number of blocks to delay closing a session after its last channel
     // has been closed.
     SessionCloseRange uint32
+
+    // MaxTasksInMemQueue is the maximum number of backup tasks that should
+    // be kept in-memory. Any more tasks will overflow to disk.
+    MaxTasksInMemQueue uint64
 }

 // BreachRetributionBuilder is a function that can be used to construct a
@@ -293,7 +301,7 @@ type TowerClient struct {

     log btclog.Logger

-    pipeline *taskPipeline
+    pipeline *DiskOverflowQueue[*wtdb.BackupID]

     negotiator      SessionNegotiator
     candidateTowers TowerCandidateIterator
@@ -355,10 +363,21 @@ func New(config *Config) (*TowerClient, error) {
         return nil, err
     }

+    var (
+        policy  = cfg.Policy.BlobType.String()
+        queueDB = cfg.DB.GetDBQueue([]byte(policy))
+    )
+    queue, err := NewDiskOverflowQueue[*wtdb.BackupID](
+        queueDB, cfg.MaxTasksInMemQueue, plog,
+    )
+    if err != nil {
+        return nil, err
+    }
+
     c := &TowerClient{
         cfg:               cfg,
         log:               plog,
-        pipeline:          newTaskPipeline(plog),
+        pipeline:          queue,
         chanCommitHeights: make(map[lnwire.ChannelID]uint64),
         activeSessions:    make(sessionQueueSet),
         summaries:         chanSummaries,
@@ -671,6 +690,7 @@ func (c *TowerClient) Start() error {

 // Stop idempotently initiates a graceful shutdown of the watchtower client.
 func (c *TowerClient) Stop() error {
+    var returnErr error
     c.stopped.Do(func() {
         c.log.Debugf("Stopping watchtower client")

@@ -693,7 +713,10 @@
         // updates from being accepted. In practice, the links should be
         // shutdown before the client has been stopped, so all updates
         // would have been added prior.
-        c.pipeline.Stop()
+        err := c.pipeline.Stop()
+        if err != nil {
+            returnErr = err
+        }

         // 3. Once the backup queue has shutdown, wait for the main
         // dispatcher to exit. The backup queue will signal it's
@@ -724,7 +747,8 @@

         c.log.Debugf("Client successfully stopped, stats: %s", c.stats)
     })
-    return nil
+
+    return returnErr
 }

 // ForceQuit idempotently initiates an unclean shutdown of the watchtower
@@ -737,7 +761,10 @@ func (c *TowerClient) ForceQuit() {
         // updates from being accepted. In practice, the links should be
         // shutdown before the client has been stopped, so all updates
         // would have been added prior.
-        c.pipeline.ForceQuit()
+        err := c.pipeline.Stop()
+        if err != nil {
+            c.log.Errorf("could not stop backup queue: %v", err)
+        }

         // 2. Once the backup queue has shutdown, wait for the main
         // dispatcher to exit. The backup queue will signal it's
@@ -836,7 +863,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
         CommitHeight: stateNum,
     }

-    return c.pipeline.QueueBackupTask(id)
+    return c.pipeline.QueueBackupID(id)
 }

 // nextSessionQueue attempts to fetch an active session from our set of
@@ -1323,7 +1350,7 @@ func (c *TowerClient) backupDispatcher() {

         // Process each backup task serially from the queue of
         // revoked states.
-        case task, ok := <-c.pipeline.NewBackupTasks():
+        case task, ok := <-c.pipeline.NextBackupID():
             // All backups in the pipeline have been
             // processed, it is now safe to exit.
             if !ok {
@@ -1635,8 +1662,6 @@ func (c *TowerClient) AddTower(addr *lnwire.NetAddress) error {
     }:
     case <-c.pipeline.quit:
         return ErrClientExiting
-    case <-c.pipeline.forceQuit:
-        return ErrClientExiting
     }

     select {
@@ -1644,8 +1669,6 @@
         return err
     case <-c.pipeline.quit:
         return ErrClientExiting
-    case <-c.pipeline.forceQuit:
-        return ErrClientExiting
     }
 }

@@ -1702,8 +1725,6 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey,
     }:
     case <-c.pipeline.quit:
         return ErrClientExiting
-    case <-c.pipeline.forceQuit:
-        return ErrClientExiting
     }

     select {
@@ -1711,8 +1732,6 @@
         return err
     case <-c.pipeline.quit:
         return ErrClientExiting
-    case <-c.pipeline.forceQuit:
-        return ErrClientExiting
     }
 }

@ -509,12 +509,13 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||
NewAddress: func() ([]byte, error) {
|
||||
return addrScript, nil
|
||||
},
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
SessionCloseRange: 1,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
MinBackoff: time.Millisecond,
|
||||
MaxBackoff: time.Second,
|
||||
ForceQuitDelay: 10 * time.Second,
|
||||
SessionCloseRange: 1,
|
||||
MaxTasksInMemQueue: 2,
|
||||
}
|
||||
|
||||
h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID,
|
||||
@ -1094,10 +1095,6 @@ var clientTests = []clientTest{
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Stop the client in the background, to assert the
|
||||
// pipeline is always flushed before it exits.
|
||||
go h.client.Stop()
|
||||
|
||||
// Wait for all the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, time.Second)
|
||||
@ -1238,10 +1235,6 @@ var clientTests = []clientTest{
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Stop the client in the background, to assert the
|
||||
// pipeline is always flushed before it exits.
|
||||
go h.client.Stop()
|
||||
|
||||
// Give the client time to saturate a large number of
|
||||
// session queues for which the server has not acked the
|
||||
// state updates that it has received.
|
||||
@ -1346,9 +1339,6 @@ var clientTests = []clientTest{
|
||||
h.backupStates(id, 0, numUpdates, nil)
|
||||
}
|
||||
|
||||
// Test reliable flush under multi-client scenario.
|
||||
go h.client.Stop()
|
||||
|
||||
// Wait for all the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, 10*time.Second)
|
||||
@ -1395,9 +1385,6 @@ var clientTests = []clientTest{
|
||||
// identical one.
|
||||
h.startClient()
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, waitTime)
|
||||
@ -1449,9 +1436,6 @@ var clientTests = []clientTest{
|
||||
h.clientCfg.Policy.SweepFeeRate *= 2
|
||||
h.startClient()
|
||||
|
||||
// Now, queue the retributions for backup.
|
||||
h.backupStates(chanID, 0, numUpdates, nil)
|
||||
|
||||
// Wait for all the updates to be populated in the
|
||||
// server's database.
|
||||
h.waitServerUpdates(hints, waitTime)
|
||||
@ -2037,7 +2021,7 @@ var clientTests = []clientTest{
|
||||
},
|
||||
},
|
||||
{
|
||||
// Demonstrate that the client is unable to recover after
|
||||
// Demonstrate that the client is able to recover after
|
||||
// deleting its database by skipping through key indices until
|
||||
// it gets to one that does not result in the
|
||||
// CreateSessionCodeAlreadyExists error code being returned from
|
||||
@ -2088,6 +2072,51 @@ var clientTests = []clientTest{
|
||||
h.waitServerUpdates(hints[numUpdates/2:], waitTime)
|
||||
},
|
||||
},
|
||||
{
|
||||
// This test demonstrates that if there is no active session,
|
||||
// the updates are persisted to disk on restart and reliably
|
||||
// sent.
|
||||
name: "in-mem updates not lost on restart",
|
||||
cfg: harnessCfg{
|
||||
localBalance: localBalance,
|
||||
remoteBalance: remoteBalance,
|
||||
policy: wtpolicy.Policy{
|
||||
TxPolicy: defaultTxPolicy,
|
||||
MaxUpdates: 5,
|
||||
},
|
||||
// noServerStart ensures that the server does not
|
||||
// automatically start on creation of the test harness.
|
||||
// This ensures that the client does not initially have
|
||||
// any active sessions.
|
||||
noServerStart: true,
|
||||
},
|
||||
fn: func(h *testHarness) {
|
||||
const (
|
||||
chanID = 0
|
||||
numUpdates = 5
|
||||
)
|
||||
|
||||
// Try to back up the first few states of the client's
|
||||
// channel. Since the server has not yet started, the
|
||||
// client should have no active session yet and so these
|
||||
// updates will just be kept in an in-memory queue.
|
||||
hints := h.advanceChannelN(chanID, numUpdates)
|
||||
|
||||
h.backupStates(chanID, 0, numUpdates/2, nil)
|
||||
|
||||
// Restart the client (force quit) and now also start
|
||||
// the server.
|
||||
h.client.ForceQuit()
|
||||
h.startServer()
|
||||
h.startClient()
|
||||
|
||||
// Back up a few more states.
|
||||
h.backupStates(chanID, numUpdates/2, numUpdates, nil)
|
||||
|
||||
// Assert that the server does receive ALL the updates.
|
||||
h.waitServerUpdates(hints[0:numUpdates], waitTime)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestClient executes the client test suite, asserting the ability to backup
|
||||
|
@@ -131,6 +131,10 @@ type DB interface {
     // update identified by seqNum was received and saved. The returned
     // lastApplied will be recorded.
     AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
+
+    // GetDBQueue returns a BackupID Queue instance under the given
+    // namespace.
+    GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID]
 }

 // AuthDialer connects to a remote node using an authenticated transport, such

watchtower/wtclient/queue.go (new file, 566 lines)
@@ -0,0 +1,566 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
const (
|
||||
// dbErrorBackoff is the length of time we will back off before retrying
|
||||
// any DB action that failed.
|
||||
dbErrorBackoff = time.Second * 5
|
||||
)
|
||||
|
||||
// internalTask wraps a BackupID task with a success channel.
|
||||
type internalTask[T any] struct {
|
||||
task T
|
||||
success chan bool
|
||||
}
|
||||
|
||||
// newInternalTask creates a new internalTask with the given task.
|
||||
func newInternalTask[T any](task T) *internalTask[T] {
|
||||
return &internalTask[T]{
|
||||
task: task,
|
||||
success: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// DiskOverflowQueue is a queue that must be initialised with a certain maximum
|
||||
// buffer size which represents the maximum number of elements that the queue
|
||||
// should hold in memory. If the queue is full, then any new elements added to
|
||||
// the queue will be persisted to disk instead. Once a consumer starts reading
|
||||
// from the front of the queue again then items on disk will be moved into the
|
||||
// queue again. The queue is also re-start safe. When it is stopped, any items
|
||||
// in the memory queue, will be persisted to disk. On start up, the queue will
|
||||
// be re-initialised with the items on disk.
|
||||
type DiskOverflowQueue[T any] struct {
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
log btclog.Logger
|
||||
|
||||
// db is the database that will be used to persist queue items to disk.
|
||||
db wtdb.Queue[T]
|
||||
|
||||
// toDisk represents the current mode of operation of the queue.
|
||||
toDisk atomic.Bool
|
||||
|
||||
// We use an unbounded list for the input of the queue so that producers
|
||||
// putting items into the queue are never blocked.
|
||||
inputListMu sync.Mutex
|
||||
inputListCond *sync.Cond
|
||||
inputList *list.List
|
||||
|
||||
// inputChan is an unbuffered channel used to pass items from
|
||||
// drainInputList to feedMemQueue.
|
||||
inputChan chan *internalTask[T]
|
||||
|
||||
// memQueue is a buffered channel used to pass items from
|
||||
// feedMemQueue to feedOutputChan.
|
||||
memQueue chan T
|
||||
|
||||
// outputChan is an unbuffered channel from which items at the head of
|
||||
// the queue can be read.
|
||||
outputChan chan T
|
||||
|
||||
// newDiskItemSignal is used to signal that there is a new item in the
|
||||
// main disk queue. There should only be one reader and one writer for
|
||||
// this channel.
|
||||
newDiskItemSignal chan struct{}
|
||||
|
||||
// leftOverItem1 will be a non-nil task on shutdown if the
|
||||
// feedOutputChan method was holding an unhandled task at shutdown
|
||||
// time. Since feedOutputChan handles the very head of the queue, this
|
||||
// item should be the first to be reloaded on restart.
|
||||
leftOverItem1 *T
|
||||
|
||||
// leftOverItems2 will be non-empty on shutdown if the feedMemQueue
|
||||
// method was holding any unhandled tasks at shutdown time. Since
|
||||
// feedMemQueue manages the input to the queue, the tasks should be
|
||||
// pushed to the head of the disk queue.
|
||||
leftOverItems2 []T
|
||||
|
||||
// leftOverItem3 will be non-nil on shutdown if drainInputList was
|
||||
// holding an unhandled task at shutdown time. This task should be put
|
||||
// at the tail of the disk queue but should come before any input list
|
||||
// task.
|
||||
leftOverItem3 *T
|
||||
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDiskOverflowQueue constructs a new DiskOverflowQueue.
|
||||
func NewDiskOverflowQueue[T any](db wtdb.Queue[T], maxQueueSize uint64,
|
||||
logger btclog.Logger) (*DiskOverflowQueue[T], error) {
|
||||
|
||||
if maxQueueSize < 2 {
|
||||
return nil, errors.New("the in-memory queue buffer size " +
"must be at least 2")
|
||||
}
|
||||
|
||||
q := &DiskOverflowQueue[T]{
|
||||
log: logger,
|
||||
db: db,
|
||||
inputList: list.New(),
|
||||
newDiskItemSignal: make(chan struct{}, 1),
|
||||
inputChan: make(chan *internalTask[T]),
|
||||
memQueue: make(chan T, maxQueueSize-2),
|
||||
outputChan: make(chan T),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
q.inputListCond = sync.NewCond(&q.inputListMu)
|
||||
|
||||
return q, nil
|
||||
}
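
For orientation, here is a minimal usage sketch of the API this file introduces (NewDiskOverflowQueue, Start, QueueBackupID, NextBackupID, Stop). It is not part of the PR: the in-memory wtmock.NewQueueDB backing store is borrowed from the benchmark in queue_test.go further down, and using btclog.Disabled as a no-op logger is an assumption.

package main

import (
    "fmt"

    "github.com/btcsuite/btclog"
    "github.com/lightningnetwork/lnd/watchtower/wtclient"
    "github.com/lightningnetwork/lnd/watchtower/wtdb"
    "github.com/lightningnetwork/lnd/watchtower/wtmock"
)

func main() {
    // An in-memory wtdb.Queue stands in for the bolt-backed disk queue.
    db := wtmock.NewQueueDB[*wtdb.BackupID]()

    // Hold at most 5 tasks in memory; anything beyond that overflows to
    // the backing store until the consumer catches up.
    q, err := wtclient.NewDiskOverflowQueue[*wtdb.BackupID](
        db, 5, btclog.Disabled,
    )
    if err != nil {
        panic(err)
    }
    if err := q.Start(); err != nil {
        panic(err)
    }
    defer func() { _ = q.Stop() }()

    // Producer side: QueueBackupID never blocks.
    for i := 0; i < 10; i++ {
        err := q.QueueBackupID(&wtdb.BackupID{CommitHeight: uint64(i)})
        if err != nil {
            panic(err)
        }
    }

    // Consumer side: tasks come back out of the head of the queue, in
    // order, via the NextBackupID channel.
    for i := 0; i < 10; i++ {
        task := <-q.NextBackupID()
        fmt.Println(task.CommitHeight)
    }
}

The point the sketch illustrates is that producers are never blocked: whatever the bounded in-memory buffer cannot absorb is handed to the wtdb.Queue backing store and reloaded once the consumer drains the head of the queue.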
|
||||
|
||||
// Start kicks off all the goroutines that are required to manage the queue.
|
||||
func (q *DiskOverflowQueue[T]) Start() error {
|
||||
var err error
|
||||
q.startOnce.Do(func() {
|
||||
err = q.start()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// start kicks off all the goroutines that are required to manage the queue.
|
||||
func (q *DiskOverflowQueue[T]) start() error {
|
||||
numDisk, err := q.db.Len()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if numDisk != 0 {
|
||||
q.toDisk.Store(true)
|
||||
}
|
||||
|
||||
// Kick off the three goroutines which will handle the input list, the
|
||||
// in-memory queue and the output channel.
|
||||
// The three goroutines are moving items according to the following
|
||||
// diagram:
|
||||
//
|
||||
// ┌─────────┐ drainInputList ┌──────────┐
|
||||
// │inputList├─────┬──────────►│disk/db │
|
||||
// └─────────┘ │ └──────────┘
|
||||
// │ (depending on mode)
|
||||
// │ ┌──────────┐
|
||||
// └──────────►│inputChan │
|
||||
// └──────────┘
|
||||
//
|
||||
// ┌─────────┐ feedMemQueue ┌──────────┐
|
||||
// │disk/db ├───────┬────────►│memQueue │
|
||||
// └─────────┘ │ └──────────┘
|
||||
// │ (depending on mode)
|
||||
// ┌─────────┐ │
|
||||
// │inputChan├───────┘
|
||||
// └─────────┘
|
||||
//
|
||||
// ┌─────────┐ feedOutputChan ┌──────────┐
|
||||
// │memQueue ├────────────────►│outputChan│
|
||||
// └─────────┘ └──────────┘
|
||||
//
|
||||
q.wg.Add(3)
|
||||
go q.drainInputList()
|
||||
go q.feedMemQueue()
|
||||
go q.feedOutputChan()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the queue and persists any items in the memory queue to disk.
|
||||
func (q *DiskOverflowQueue[T]) Stop() error {
|
||||
var err error
|
||||
q.stopOnce.Do(func() {
|
||||
err = q.stop()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// stop stops the queue and persists any items in the memory queue to disk.
|
||||
func (q *DiskOverflowQueue[T]) stop() error {
|
||||
close(q.quit)
|
||||
|
||||
// Signal on the inputListCond until all the goroutines have returned.
|
||||
shutdown := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
q.inputListCond.Signal()
|
||||
case <-shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
q.wg.Wait()
|
||||
close(shutdown)
|
||||
|
||||
// queueHead will be the items that will be pushed to the head of
|
||||
// the queue.
|
||||
var queueHead []T
|
||||
|
||||
// First, we append leftOverItem1 since this task is the current head
|
||||
// of the queue.
|
||||
if q.leftOverItem1 != nil {
|
||||
queueHead = append(queueHead, *q.leftOverItem1)
|
||||
}
|
||||
|
||||
// Next, drain the buffered queue.
|
||||
for {
|
||||
task, ok := <-q.memQueue
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
queueHead = append(queueHead, task)
|
||||
}
|
||||
|
||||
// Then, any items held in leftOverItems2 would have been next to join
|
||||
// the memQueue. So those get added next.
|
||||
if len(q.leftOverItems2) != 0 {
|
||||
queueHead = append(queueHead, q.leftOverItems2...)
|
||||
}
|
||||
|
||||
// Now, push these items to the head of the queue.
|
||||
err := q.db.PushHead(queueHead...)
|
||||
if err != nil {
|
||||
q.log.Errorf("Could not add tasks to queue head: %v", err)
|
||||
}
|
||||
|
||||
// Next we handle any items that need to be added to the main disk
|
||||
// queue.
|
||||
var diskQueue []T
|
||||
|
||||
// Any item in leftOverItem3 is the first item that should join the
|
||||
// disk queue.
|
||||
if q.leftOverItem3 != nil {
|
||||
diskQueue = append(diskQueue, *q.leftOverItem3)
|
||||
}
|
||||
|
||||
// Lastly, drain any items in the unbuffered input list.
|
||||
q.inputListCond.L.Lock()
|
||||
for q.inputList.Front() != nil {
|
||||
e := q.inputList.Front()
|
||||
|
||||
//nolint:forcetypeassert
|
||||
task := q.inputList.Remove(e).(T)
|
||||
|
||||
diskQueue = append(diskQueue, task)
|
||||
}
|
||||
q.inputListCond.L.Unlock()
|
||||
|
||||
// Now persist these items to the main disk queue.
|
||||
err = q.db.Push(diskQueue...)
|
||||
if err != nil {
|
||||
q.log.Errorf("Could not add tasks to queue tail: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueBackupID adds a wtdb.BackupID to the queue. It will only return an error
|
||||
// if the queue has been stopped. It is non-blocking.
|
||||
func (q *DiskOverflowQueue[T]) QueueBackupID(item *wtdb.BackupID) error {
|
||||
// Return an error if the queue has been stopped
|
||||
select {
|
||||
case <-q.quit:
|
||||
return ErrClientExiting
|
||||
default:
|
||||
}
|
||||
|
||||
// Add the new item to the unbound input list.
|
||||
q.inputListCond.L.Lock()
|
||||
q.inputList.PushBack(item)
|
||||
q.inputListCond.L.Unlock()
|
||||
|
||||
// Signal that there is a new item in the input list.
|
||||
q.inputListCond.Signal()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextBackupID can be used to read from the head of the DiskOverflowQueue.
|
||||
func (q *DiskOverflowQueue[T]) NextBackupID() <-chan T {
|
||||
return q.outputChan
|
||||
}
|
||||
|
||||
// drainInputList handles the input to the DiskOverflowQueue. It takes from the
|
||||
// un-bounded input list and then, depending on what mode the queue is in,
|
||||
// either puts the new item straight onto the persisted disk queue or attempts
|
||||
// to feed it into the memQueue. On exit, any unhandled task will be assigned to
|
||||
// leftOverItem3.
|
||||
func (q *DiskOverflowQueue[T]) drainInputList() {
|
||||
defer q.wg.Done()
|
||||
|
||||
for {
|
||||
// Wait for the input list to not be empty.
|
||||
q.inputListCond.L.Lock()
|
||||
for q.inputList.Front() == nil {
|
||||
q.inputListCond.Wait()
|
||||
|
||||
select {
|
||||
case <-q.quit:
|
||||
q.inputListCond.L.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Pop the first element from the queue.
|
||||
e := q.inputList.Front()
|
||||
|
||||
//nolint:forcetypeassert
|
||||
task := q.inputList.Remove(e).(T)
|
||||
q.inputListCond.L.Unlock()
|
||||
|
||||
// What we do with this new item depends on what the mode of the
|
||||
// queue currently is.
|
||||
for q.pushToActiveQueue(task) {
|
||||
}
|
||||
|
||||
// If the above returned false because the quit channel was
|
||||
// closed, then we exit.
|
||||
select {
|
||||
case <-q.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pushToActiveQueue handles the input of a new task to the queue. It returns
|
||||
// true if the task should be retried and false if the task was handled or the
|
||||
// quit channel fired.
|
||||
func (q *DiskOverflowQueue[T]) pushToActiveQueue(task T) bool {
|
||||
// If the queue is in disk mode then any new items should be put
|
||||
// straight into the disk queue.
|
||||
if q.toDisk.Load() {
|
||||
err := q.db.Push(task)
|
||||
if err != nil {
|
||||
// Log and back off for a few seconds and then
|
||||
// try again with the same task.
|
||||
q.log.Errorf("could not persist %s to disk. "+
|
||||
"Retrying after backoff", task)
|
||||
|
||||
select {
|
||||
// Backoff for a bit and then re-check the mode
|
||||
// and try again to handle the task.
|
||||
case <-time.After(dbErrorBackoff):
|
||||
return true
|
||||
|
||||
// If the queue is quit at this moment, then the
|
||||
// unhandled task is assigned to leftOverItem3
|
||||
// so that it can be handled by the stop method.
|
||||
case <-q.quit:
|
||||
q.leftOverItem3 = &task
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Send a signal that there is a new item in the main
|
||||
// disk queue.
|
||||
select {
|
||||
case q.newDiskItemSignal <- struct{}{}:
|
||||
case <-q.quit:
|
||||
|
||||
// Because there might already be a signal in the
|
||||
// newDiskItemSignal channel, we can skip sending another
|
||||
// signal. The channel only has a buffer of one, so we would
|
||||
// block here if we didn't have a default case.
|
||||
default:
|
||||
}
|
||||
|
||||
// If we got here, we were able to store the task in the disk
|
||||
// queue, so we can return false as no retry is necessary.
|
||||
return false
|
||||
}
|
||||
|
||||
// If the mode is memory mode, then try to feed it to the feedMemQueue
|
||||
// handler via the un-buffered inputChan channel. We wrap it in an
|
||||
// internal task so that we can find out if feedMemQueue successfully
|
||||
// handled the item. If it did, we continue in memory mode and if not,
|
||||
// then we switch to disk mode so that we can persist the item to the
|
||||
// disk queue instead.
|
||||
it := newInternalTask(task)
|
||||
|
||||
select {
|
||||
// Try to feed the task to the feedMemQueue handler. The handler, if it
|
||||
// does take the task, is guaranteed to respond via the success channel
|
||||
// of the task to indicate if the task was successfully added to the
|
||||
// in-mem queue. This is guaranteed even if the queue is being stopped.
|
||||
case q.inputChan <- it:
|
||||
|
||||
// If the queue is quit at this moment, then the unhandled task is
|
||||
// assigned to leftOverItem3 so that it can be handled by the stop
|
||||
// method.
|
||||
case <-q.quit:
|
||||
q.leftOverItem3 = &task
|
||||
|
||||
return false
|
||||
|
||||
default:
|
||||
// The task was not accepted. So maybe the mode changed.
|
||||
return true
|
||||
}
|
||||
|
||||
// If we get here, it means that the feedMemQueue handler took the task.
|
||||
// It is guaranteed to respond via the success channel, so we wait for
|
||||
// that response here.
|
||||
s := <-it.success
|
||||
if s {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the task was not successfully handled by feedMemQueue, then we
|
||||
// switch to disk mode so that the task can be persisted in the disk
|
||||
// queue instead.
|
||||
q.toDisk.Store(true)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// feedMemQueue manages which items should be fed onto the buffered
|
||||
// memQueue. If the queue is in disk mode, then the handler will read new
|
||||
// tasks from the disk queue until it is empty. After that, it will switch
|
||||
// between reading from the input channel or the disk queue depending on the
|
||||
// queue mode.
|
||||
func (q *DiskOverflowQueue[T]) feedMemQueue() {
|
||||
defer func() {
|
||||
close(q.memQueue)
|
||||
q.wg.Done()
|
||||
}()
|
||||
|
||||
feedFromDisk := func() {
|
||||
select {
|
||||
case <-q.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for {
|
||||
// Ideally, we want to do batch reads from the DB. So
|
||||
// we check how much capacity there is in the memQueue
|
||||
// and fetch enough tasks to fill that capacity. If
|
||||
// there is no capacity, however, then we at least want
|
||||
// to fetch one task.
|
||||
numToPop := cap(q.memQueue) - len(q.memQueue)
|
||||
if numToPop == 0 {
|
||||
numToPop = 1
|
||||
}
|
||||
|
||||
tasks, err := q.db.PopUpTo(numToPop)
|
||||
if errors.Is(err, wtdb.ErrEmptyQueue) {
|
||||
q.toDisk.Store(false)
|
||||
|
||||
return
|
||||
} else if err != nil {
|
||||
q.log.Errorf("Could not load next task from " +
|
||||
"disk. Retrying.")
|
||||
|
||||
select {
|
||||
case <-time.After(dbErrorBackoff):
|
||||
continue
|
||||
case <-q.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i, task := range tasks {
|
||||
select {
|
||||
case q.memQueue <- task:
|
||||
|
||||
// If the queue is quit at this moment, then the
|
||||
// unhandled tasks are assigned to
|
||||
// leftOverItems2 so that they can be handled
|
||||
// by the stop method.
|
||||
case <-q.quit:
|
||||
q.leftOverItems2 = tasks[i:]
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the queue is in disk mode, then the memQueue is fed with tasks
|
||||
// from the disk queue until it is empty.
|
||||
if q.toDisk.Load() {
|
||||
feedFromDisk()
|
||||
}
|
||||
|
||||
// Now the queue enters its normal operation.
|
||||
for {
|
||||
select {
|
||||
case <-q.quit:
|
||||
return
|
||||
|
||||
// If there is a signal that a new item has been added to disk
|
||||
// then we use the disk queue as the source of the next task
|
||||
// to feed into memQueue.
|
||||
case <-q.newDiskItemSignal:
|
||||
feedFromDisk()
|
||||
|
||||
// If any items come through on the inputChan, then we try to feed
|
||||
// these directly into the memQueue. If there is space in the
|
||||
// memQueue then we respond with success to the producer,
|
||||
// otherwise we respond with failure so that the producer can
|
||||
// instead persist the task to disk. After the producer,
|
||||
// drainInputList, has pushed an item to inputChan, it is
|
||||
// guaranteed to await a response on the task's success channel
|
||||
// before quitting. Therefore, it is not required to listen on
|
||||
// the quit channel here.
|
||||
case task := <-q.inputChan:
|
||||
select {
|
||||
case q.memQueue <- task.task:
|
||||
task.success <- true
|
||||
continue
|
||||
default:
|
||||
task.success <- false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// feedOutputChan will pop an item from the buffered memQueue and block until
|
||||
// the item is taken from the un-buffered outputChan. This is done repeatedly
|
||||
// for the lifetime of the DiskOverflowQueue. On shutdown of the queue, any
|
||||
// item not consumed by the outputChan but held by this method is assigned to
|
||||
// the leftOverItem1 member so that the Stop method can persist the item to
|
||||
// disk so that it is reloaded on restart.
|
||||
//
|
||||
// NOTE: This must be run as a goroutine.
|
||||
func (q *DiskOverflowQueue[T]) feedOutputChan() {
|
||||
defer func() {
|
||||
close(q.outputChan)
|
||||
q.wg.Done()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case nextTask, ok := <-q.memQueue:
|
||||
// If the memQueue is closed, then the queue is
|
||||
// stopping.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case q.outputChan <- nextTask:
|
||||
case <-q.quit:
|
||||
q.leftOverItem1 = &nextTask
|
||||
return
|
||||
}
|
||||
|
||||
case <-q.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
watchtower/wtclient/queue_test.go (new file, 435 lines)
@@ -0,0 +1,435 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
maxInMemItems = 5
|
||||
waitTime = time.Second * 2
|
||||
)
|
||||
|
||||
type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID]
|
||||
|
||||
// TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected.
|
||||
func TestDiskOverflowQueue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbs := []struct {
|
||||
name string
|
||||
init initQueue
|
||||
}{
|
||||
{
|
||||
name: "kvdb",
|
||||
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
|
||||
dbCfg := &kvdb.BoltConfig{
|
||||
DBTimeout: kvdb.DefaultDBTimeout,
|
||||
}
|
||||
|
||||
bdb, err := wtdb.NewBoltBackendCreator(
|
||||
true, t.TempDir(), "wtclient.db",
|
||||
)(dbCfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := wtdb.OpenClientDB(bdb)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
return db.GetDBQueue([]byte("test-namespace"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mock",
|
||||
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
|
||||
db := wtmock.NewClientDB()
|
||||
|
||||
return db.GetDBQueue([]byte("test-namespace"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*testing.T, initQueue)
|
||||
}{
|
||||
{
|
||||
name: "overflow to disk",
|
||||
run: testOverflowToDisk,
|
||||
},
|
||||
{
|
||||
name: "startup with smaller buffer size",
|
||||
run: testRestartWithSmallerBufferSize,
|
||||
},
|
||||
{
|
||||
name: "start stop queue",
|
||||
run: testStartStopQueue,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
db := database
|
||||
t.Run(db.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
test.run(t, db.init)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testOverflowToDisk is a basic test that ensures that the queue correctly
|
||||
// overflows items to disk and then correctly reloads them.
|
||||
func testOverflowToDisk(t *testing.T, initQueue initQueue) {
|
||||
// Generate some backup IDs that we want to add to the queue.
|
||||
tasks := genBackupIDs(10)
|
||||
|
||||
// Init the DB.
|
||||
db := initQueue(t)
|
||||
|
||||
// New mock logger.
|
||||
log := newMockLogger(t.Logf)
|
||||
|
||||
// Init the queue with the mock DB.
|
||||
q, err := NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, maxInMemItems, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start the queue.
|
||||
require.NoError(t, q.Start())
|
||||
|
||||
// Initially there should be no items on disk.
|
||||
assertNumDisk(t, db, 0)
|
||||
|
||||
// Start filling up the queue.
|
||||
enqueue(t, q, tasks[0])
|
||||
enqueue(t, q, tasks[1])
|
||||
enqueue(t, q, tasks[2])
|
||||
enqueue(t, q, tasks[3])
|
||||
enqueue(t, q, tasks[4])
|
||||
|
||||
// The queue should now be full, so any new items should be persisted to
|
||||
// disk.
|
||||
enqueue(t, q, tasks[5])
|
||||
waitForNumDisk(t, db, 1)
|
||||
|
||||
// Now pop all items from the queue to ensure that the item
|
||||
// from disk is loaded in properly once there is space.
|
||||
require.Equal(t, tasks[0], getNext(t, q, 0))
|
||||
require.Equal(t, tasks[1], getNext(t, q, 1))
|
||||
require.Equal(t, tasks[2], getNext(t, q, 2))
|
||||
require.Equal(t, tasks[3], getNext(t, q, 3))
|
||||
require.Equal(t, tasks[4], getNext(t, q, 4))
|
||||
require.Equal(t, tasks[5], getNext(t, q, 5))
|
||||
|
||||
// There should no longer be any items in the disk queue.
|
||||
assertNumDisk(t, db, 0)
|
||||
|
||||
require.NoError(t, q.Stop())
|
||||
}
|
||||
|
||||
// testRestartWithSmallerBufferSize tests that if the queue is restarted with
|
||||
// a smaller in-memory buffer size that it was initially started with, then
|
||||
// tasks are still loaded in the correct order.
|
||||
func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
|
||||
const (
|
||||
firstMaxInMemItems = 5
|
||||
secondMaxInMemItems = 2
|
||||
)
|
||||
|
||||
// Generate some backup IDs that we want to add to the queue.
|
||||
tasks := genBackupIDs(10)
|
||||
|
||||
// Create a db.
|
||||
db := newQueue(t)
|
||||
|
||||
// New mock logger.
|
||||
log := newMockLogger(t.Logf)
|
||||
|
||||
// Init the queue with the mock DB and an initial max in-mem
|
||||
// items number.
|
||||
q, err := NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, firstMaxInMemItems, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, q.Start())
|
||||
|
||||
// Add 7 items to the queue. The first 5 will go into the in-mem
|
||||
// queue, the other 2 will be persisted to the main disk queue.
|
||||
enqueue(t, q, tasks[0])
|
||||
enqueue(t, q, tasks[1])
|
||||
enqueue(t, q, tasks[2])
|
||||
enqueue(t, q, tasks[3])
|
||||
enqueue(t, q, tasks[4])
|
||||
enqueue(t, q, tasks[5])
|
||||
enqueue(t, q, tasks[6])
|
||||
|
||||
waitForNumDisk(t, db, 2)
|
||||
|
||||
// Now stop the queue and re-initialise it with a smaller
|
||||
// buffer maximum.
|
||||
require.NoError(t, q.Stop())
|
||||
|
||||
// Check that there are now 7 items in the disk queue.
|
||||
waitForNumDisk(t, db, 7)
|
||||
|
||||
// Re-init the queue with a smaller max buffer size.
|
||||
q, err = NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, secondMaxInMemItems, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, q.Start())
|
||||
|
||||
// Once more we shall repeat the above restart process just to ensure
|
||||
// that in-memory items are correctly re-written and read from the DB.
|
||||
waitForNumDisk(t, db, 5)
|
||||
require.NoError(t, q.Stop())
|
||||
waitForNumDisk(t, db, 7)
|
||||
q, err = NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, secondMaxInMemItems, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, q.Start())
|
||||
waitForNumDisk(t, db, 5)
|
||||
|
||||
// Make sure that items are popped off the queue in the correct
|
||||
// order.
|
||||
require.Equal(t, tasks[0], getNext(t, q, 0))
|
||||
require.Equal(t, tasks[1], getNext(t, q, 1))
|
||||
require.Equal(t, tasks[2], getNext(t, q, 2))
|
||||
require.Equal(t, tasks[3], getNext(t, q, 3))
|
||||
require.Equal(t, tasks[4], getNext(t, q, 4))
|
||||
require.Equal(t, tasks[5], getNext(t, q, 5))
|
||||
require.Equal(t, tasks[6], getNext(t, q, 6))
|
||||
|
||||
require.NoError(t, q.Stop())
|
||||
}
|
||||
|
||||
// testStartStopQueue is a stress test that pushes a large number of tasks
|
||||
// through the queue while also restarting the queue a couple of times
|
||||
// throughout.
|
||||
func testStartStopQueue(t *testing.T, newQueue initQueue) {
|
||||
// Generate a lot of backup IDs that we want to add to the
|
||||
// queue one after the other.
|
||||
tasks := genBackupIDs(200_000)
|
||||
|
||||
// Construct the ClientDB.
|
||||
db := newQueue(t)
|
||||
|
||||
// New mock logger.
|
||||
log := newMockLogger(t.Logf)
|
||||
|
||||
// Init the queue with the mock DB.
|
||||
q, err := NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, DefaultMaxTasksInMemQueue, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start the queue.
|
||||
require.NoError(t, q.Start())
|
||||
|
||||
// Initially there should be no items on disk.
|
||||
assertNumDisk(t, db, 0)
|
||||
|
||||
// We need to guard the queue with a mutex since we will be
|
||||
// stopping, re-creating and starting the queue multiple times.
|
||||
var (
|
||||
queueMtx sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
sendDone = make(chan struct{})
|
||||
)
|
||||
|
||||
// This goroutine will constantly try to add new items to the
|
||||
// queue, even if the queue is stopped.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for idx := range tasks {
|
||||
queueMtx.RLock()
|
||||
err := q.QueueBackupID(tasks[idx])
|
||||
require.NoError(t, err)
|
||||
queueMtx.RUnlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// This goroutine will repeatedly stop, re-create and start the
|
||||
// queue until we're done sending items.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
numRestarts := 0
|
||||
for {
|
||||
select {
|
||||
case <-sendDone:
|
||||
t.Logf("Restarted queue %d times",
|
||||
numRestarts)
|
||||
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
queueMtx.Lock()
|
||||
require.NoError(t, q.Stop())
|
||||
q, err = NewDiskOverflowQueue[*wtdb.BackupID](
|
||||
db, DefaultMaxTasksInMemQueue, log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, q.Start())
|
||||
queueMtx.Unlock()
|
||||
|
||||
numRestarts++
|
||||
}
|
||||
}()
|
||||
|
||||
// We should be able to read all items from the queue, not being
|
||||
// affected by restarts, other than needing to wait for the
|
||||
// queue to be started again.
|
||||
results := make([]*wtdb.BackupID, 0, len(tasks))
|
||||
for i := 0; i < len(tasks); i++ {
|
||||
queueMtx.RLock()
|
||||
task := getNext(t, q, i)
|
||||
queueMtx.RUnlock()
|
||||
|
||||
results = append(results, task)
|
||||
}
|
||||
close(sendDone)
|
||||
require.Equal(t, tasks, results)
|
||||
|
||||
require.NoError(t, q.Stop())
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func getNext(t *testing.T, q *DiskOverflowQueue[*wtdb.BackupID],
|
||||
i int) *wtdb.BackupID {
|
||||
|
||||
var item *wtdb.BackupID
|
||||
select {
|
||||
case item = <-q.NextBackupID():
|
||||
case <-time.After(waitTime):
|
||||
t.Fatalf("task %d not received in time", i)
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
func enqueue(t *testing.T, q *DiskOverflowQueue[*wtdb.BackupID],
|
||||
task *wtdb.BackupID) {
|
||||
|
||||
err := q.QueueBackupID(task)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func waitForNumDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID], num int) {
|
||||
err := wait.Predicate(func() bool {
|
||||
n, err := db.Len()
|
||||
require.NoError(t, err)
|
||||
|
||||
return n == uint64(num)
|
||||
}, waitTime)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func assertNumDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID], num int) {
|
||||
n, err := db.Len()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, num, n)
|
||||
}
|
||||
|
||||
func genBackupIDs(num int) []*wtdb.BackupID {
|
||||
ids := make([]*wtdb.BackupID, num)
|
||||
for i := 0; i < num; i++ {
|
||||
ids[i] = newBackupID(i)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func newBackupID(id int) *wtdb.BackupID {
|
||||
return &wtdb.BackupID{CommitHeight: uint64(id)}
|
||||
}
|
||||
|
||||
// BenchmarkDiskOverflowQueue benchmarks the performance of adding and removing
|
||||
// items from the DiskOverflowQueue using an in-memory disk db.
|
||||
func BenchmarkDiskOverflowQueue(b *testing.B) {
|
||||
enqueue := func(q *DiskOverflowQueue[*wtdb.BackupID],
|
||||
task *wtdb.BackupID) {
|
||||
|
||||
err := q.QueueBackupID(task)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
getNext := func(q *DiskOverflowQueue[*wtdb.BackupID],
|
||||
i int) *wtdb.BackupID {
|
||||
|
||||
var item *wtdb.BackupID
|
||||
select {
|
||||
case item = <-q.NextBackupID():
|
||||
case <-time.After(time.Second * 2):
|
||||
b.Fatalf("task %d not received in time", i)
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
// Generate some backup IDs that we want to add to the queue.
|
||||
tasks := genBackupIDs(b.N)
|
||||
|
||||
// Create a mock db.
|
||||
db := wtmock.NewQueueDB[*wtdb.BackupID]()
|
||||
|
||||
// New mock logger.
|
||||
log := newMockLogger(b.Logf)
|
||||
|
||||
// Init the queue with the mock DB.
|
||||
q, err := NewDiskOverflowQueue[*wtdb.BackupID](db, 5, log)
|
||||
require.NoError(b, err)
|
||||
|
||||
// Start the queue.
|
||||
require.NoError(b, q.Start())
|
||||
|
||||
// Start filling up the queue.
|
||||
for n := 0; n < b.N; n++ {
|
||||
enqueue(q, tasks[n])
|
||||
}
|
||||
|
||||
// Pop all the items off the queue.
|
||||
for n := 0; n < b.N; n++ {
|
||||
require.Equal(b, tasks[n], getNext(q, n))
|
||||
}
|
||||
|
||||
require.NoError(b, q.Stop())
|
||||
}
|
||||
|
||||
type mockLogger struct {
|
||||
log func(string, ...any)
|
||||
|
||||
btclog.Logger
|
||||
}
|
||||
|
||||
func newMockLogger(logger func(string, ...any)) *mockLogger {
|
||||
return &mockLogger{log: logger}
|
||||
}
|
||||
|
||||
// Errorf formats message according to format specifier and writes to log.
|
||||
//
|
||||
// NOTE: this is part of the btclog.Logger interface.
|
||||
func (l *mockLogger) Errorf(format string, params ...any) {
|
||||
l.log("[ERR]: "+format, params...)
|
||||
}
|
@ -1,198 +0,0 @@
|
||||
package wtclient
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
)
|
||||
|
||||
// taskPipeline implements a reliable, in-order queue that ensures its queue
|
||||
// fully drained before exiting. Stopping the taskPipeline prevents the pipeline
|
||||
// from accepting any further tasks, and will cause the pipeline to exit after
|
||||
// all updates have been delivered to the downstream receiver. If this process
|
||||
// hangs and is unable to make progress, users can optionally call ForceQuit to
|
||||
// abandon the reliable draining of the queue in order to permit shutdown.
|
||||
type taskPipeline struct {
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
forced sync.Once
|
||||
|
||||
log btclog.Logger
|
||||
|
||||
queueMtx sync.Mutex
|
||||
queueCond *sync.Cond
|
||||
queue *list.List
|
||||
|
||||
newBackupTasks chan *wtdb.BackupID
|
||||
|
||||
quit chan struct{}
|
||||
forceQuit chan struct{}
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// newTaskPipeline initializes a new taskPipeline.
|
||||
func newTaskPipeline(log btclog.Logger) *taskPipeline {
|
||||
rq := &taskPipeline{
|
||||
log: log,
|
||||
queue: list.New(),
|
||||
newBackupTasks: make(chan *wtdb.BackupID),
|
||||
quit: make(chan struct{}),
|
||||
forceQuit: make(chan struct{}),
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
rq.queueCond = sync.NewCond(&rq.queueMtx)
|
||||
|
||||
return rq
|
||||
}
|
||||
|
||||
// Start spins up the taskPipeline, making it eligible to begin receiving backup
|
||||
// tasks and deliver them to the receiver of NewBackupTasks.
|
||||
func (q *taskPipeline) Start() {
|
||||
q.started.Do(func() {
|
||||
go q.queueManager()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop begins a graceful shutdown of the taskPipeline. This method returns once
|
||||
// all backupTasks have been delivered via NewBackupTasks, or a ForceQuit causes
|
||||
// the delivery of pending tasks to be interrupted.
|
||||
func (q *taskPipeline) Stop() {
|
||||
q.stopped.Do(func() {
|
||||
q.log.Debugf("Stopping task pipeline")
|
||||
|
||||
close(q.quit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
// Skip log if we also force quit.
|
||||
select {
|
||||
case <-q.forceQuit:
|
||||
default:
|
||||
q.log.Debugf("Task pipeline stopped successfully")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ForceQuit signals the taskPipeline to immediately exit, dropping any
|
||||
// backupTasks that have not been delivered via NewBackupTasks.
|
||||
func (q *taskPipeline) ForceQuit() {
|
||||
q.forced.Do(func() {
|
||||
q.log.Infof("Force quitting task pipeline")
|
||||
|
||||
close(q.forceQuit)
|
||||
q.signalUntilShutdown()
|
||||
|
||||
q.log.Infof("Task pipeline unclean shutdown complete")
|
||||
})
|
||||
}
|
||||
|
||||
// NewBackupTasks returns a read-only channel for enqueued backupTasks. The
|
||||
// channel will be closed after a call to Stop and all pending tasks have been
|
||||
// delivered, or if a call to ForceQuit is called before the pending entries
|
||||
// have been drained.
|
||||
func (q *taskPipeline) NewBackupTasks() <-chan *wtdb.BackupID {
|
||||
return q.newBackupTasks
|
||||
}
|
||||
|
||||
// QueueBackupTask enqueues a backupTask for reliable delivery to the consumer
|
||||
// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is
|
||||
// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be
|
||||
// delivered via NewBackupTasks unless ForceQuit is called before completion.
|
||||
func (q *taskPipeline) QueueBackupTask(task *wtdb.BackupID) error {
|
||||
q.queueCond.L.Lock()
|
||||
select {
|
||||
|
||||
// Reject new tasks after quit has been signaled.
|
||||
case <-q.quit:
|
||||
q.queueCond.L.Unlock()
|
||||
return ErrClientExiting
|
||||
|
||||
// Reject new tasks after force quit has been signaled.
|
||||
case <-q.forceQuit:
|
||||
q.queueCond.L.Unlock()
|
||||
return ErrClientExiting
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue the new task and signal the queue's condition variable to wake
|
||||
// up the queueManager for processing.
|
||||
q.queue.PushBack(task)
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
q.queueCond.Signal()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueManager processes all incoming backup requests that get added via
|
||||
// QueueBackupTask. The manager will exit once the pipeline has been stopped
// and the queue fully drained, or immediately if a force quit is signalled.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (q *taskPipeline) queueManager() {
|
||||
defer close(q.shutdown)
|
||||
defer close(q.newBackupTasks)
|
||||
|
||||
for {
|
||||
q.queueCond.L.Lock()
|
||||
for q.queue.Front() == nil {
|
||||
q.queueCond.Wait()
|
||||
|
||||
select {
|
||||
case <-q.quit:
|
||||
// Exit only after the queue has been fully
|
||||
// drained.
|
||||
if q.queue.Len() == 0 {
|
||||
q.queueCond.L.Unlock()
|
||||
q.log.Debugf("Revoked state pipeline " +
|
||||
"flushed.")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
case <-q.forceQuit:
|
||||
q.queueCond.L.Unlock()
|
||||
q.log.Debugf("Revoked state pipeline force " +
|
||||
"quit.")
|
||||
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Pop the first element from the queue.
|
||||
e := q.queue.Front()
|
||||
|
||||
//nolint:forcetypeassert
|
||||
task := q.queue.Remove(e).(*wtdb.BackupID)
|
||||
q.queueCond.L.Unlock()
|
||||
|
||||
select {
|
||||
|
||||
// Backup task submitted to dispatcher. We don't select on quit
|
||||
// to ensure that we still drain tasks while shutting down.
|
||||
case q.newBackupTasks <- task:
|
||||
|
||||
// Force quit, return immediately to allow the client to exit.
|
||||
case <-q.forceQuit:
|
||||
q.log.Debugf("Revoked state pipeline force quit.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// signalUntilShutdown strobes the queue's condition variable to ensure the
|
||||
// queueManager reliably unblocks to check for the exit condition.
|
||||
func (q *taskPipeline) signalUntilShutdown() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
q.queueCond.Signal()
|
||||
case <-q.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
@@ -94,6 +94,10 @@ var (
     // db-session-id -> last-channel-close-height
     cClosableSessionsBkt = []byte("client-closable-sessions-bucket")

+    // cTaskQueue is a top-level bucket where the disk queue may store its
+    // content.
+    cTaskQueue = []byte("client-task-queue")
+
     // ErrTowerNotFound signals that the target tower was not found in the
     // database.
     ErrTowerNotFound = errors.New("tower not found")
@@ -2060,6 +2064,15 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
     }, func() {})
 }

+// GetDBQueue returns a BackupID Queue instance under the given namespace.
+func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] {
+    return NewQueueDB[*BackupID](
+        c.db, namespace, func() *BackupID {
+            return &BackupID{}
+        },
+    )
+}
+
 // putChannelToSessionMapping adds the given session ID to a channel's
 // cChanSessions bucket.
 func putChannelToSessionMapping(chanDetails kvdb.RwBucket,

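To complement the GetDBQueue addition above, the following is a hedged sketch of driving the namespaced queue directly through the wtdb.Queue interface (Push/PopUpTo), modelled on the bolt setup used in queue_test.go; the directory, file name and namespace here are placeholders, not values used by lnd.

package main

import (
    "fmt"

    "github.com/lightningnetwork/lnd/kvdb"
    "github.com/lightningnetwork/lnd/watchtower/wtdb"
)

func main() {
    // Open a bolt backend and the watchtower client DB on top of it.
    backend, err := wtdb.NewBoltBackendCreator(
        true, "/tmp/wtclient-example", "wtclient.db",
    )(&kvdb.BoltConfig{DBTimeout: kvdb.DefaultDBTimeout})
    if err != nil {
        panic(err)
    }

    db, err := wtdb.OpenClientDB(backend)
    if err != nil {
        panic(err)
    }
    defer db.Close()

    // Each namespace gets its own disk queue; the tower client keys this
    // by the session policy's blob type.
    queue := db.GetDBQueue([]byte("example-namespace"))

    // Push two tasks onto the tail and pop them back off the head.
    err = queue.Push(
        &wtdb.BackupID{CommitHeight: 1},
        &wtdb.BackupID{CommitHeight: 2},
    )
    if err != nil {
        panic(err)
    }

    items, err := queue.PopUpTo(2)
    if err != nil {
        panic(err)
    }
    for _, item := range items {
        fmt.Println(item.CommitHeight)
    }
}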
watchtower/wtdb/queue.go (new file, 488 lines)
@@ -0,0 +1,488 @@
|
||||
package wtdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
// queueMainBkt will hold the main queue contents. It will have the
|
||||
// following structure:
|
||||
// => oldestIndexKey => oldest index
|
||||
// => nextIndexKey => newest index
|
||||
// => itemsBkt => <index> -> item
|
||||
//
|
||||
// Any items added to the queue via Push, will be added to this queue.
|
||||
// Items will only be popped from this queue if the head queue is empty.
|
||||
queueMainBkt = []byte("queue-main")
|
||||
|
||||
// queueHeadBkt will hold the items that have been pushed to the head
|
||||
// of the queue. It will have the following structure:
|
||||
// => oldestIndexKey => oldest index
|
||||
// => nextIndexKey => newest index
|
||||
// => itemsBkt => <index> -> item
|
||||
//
|
||||
// If PushHead is called with a new set of items, then first all
|
||||
// remaining items in the head queue will be popped and added to the
|
||||
// given set of items. Then, once the head queue is empty, the set of
|
||||
// items will be pushed to the queue. If this queue is not empty, then
|
||||
// Pop will pop items from this queue before popping from the main
|
||||
// queue.
|
||||
queueHeadBkt = []byte("queue-head")
|
||||
|
||||
// itemsBkt is a sub-bucket of both the main and head queue storing:
|
||||
// index -> encoded item
|
||||
itemsBkt = []byte("items")
|
||||
|
||||
// oldestIndexKey is a key of both the main and head queue storing the
|
||||
// index of the item at the head of the queue.
|
||||
oldestIndexKey = []byte("oldest-index")
|
||||
|
||||
// nextIndexKey is a key of both the main and head queue storing the
|
||||
// index of the item at the tail of the queue.
|
||||
nextIndexKey = []byte("next-index")
|
||||
|
||||
// ErrEmptyQueue is returned from Pop if there are no items left in
|
||||
// the queue.
|
||||
ErrEmptyQueue = errors.New("queue is empty")
|
||||
)
|

// Queue is an interface describing a FIFO queue for any generic type T.
type Queue[T any] interface {
    // Len returns the number of tasks in the queue.
    Len() (uint64, error)

    // Push pushes new T items to the tail of the queue.
    Push(items ...T) error

    // PopUpTo attempts to pop up to n items from the head of the queue. If
    // no more items are in the queue then ErrEmptyQueue is returned.
    PopUpTo(n int) ([]T, error)

    // PushHead pushes new T items to the head of the queue.
    PushHead(items ...T) error
}

// Serializable is an interface that must be satisfied by any type that the
// DiskQueueDB should handle.
type Serializable interface {
    Encode(w io.Writer) error
    Decode(r io.Reader) error
}

// DiskQueueDB is a generic Bolt DB implementation of the Queue interface.
type DiskQueueDB[T Serializable] struct {
    db          kvdb.Backend
    topLevelBkt []byte
    constructor func() T
}

// A compile-time check to ensure that DiskQueueDB implements the Queue
// interface.
var _ Queue[Serializable] = (*DiskQueueDB[Serializable])(nil)

// NewQueueDB constructs a new DiskQueueDB. A queueBktName must be provided so
// that the DiskQueueDB can create its own namespace in the bolt db.
func NewQueueDB[T Serializable](db kvdb.Backend, queueBktName []byte,
    constructor func() T) Queue[T] {

    return &DiskQueueDB[T]{
        db:          db,
        topLevelBkt: queueBktName,
        constructor: constructor,
    }
}

// Len returns the number of tasks in the queue.
//
// NOTE: This is part of the Queue interface.
func (d *DiskQueueDB[T]) Len() (uint64, error) {
    var res uint64
    err := kvdb.View(d.db, func(tx kvdb.RTx) error {
        var err error
        res, err = d.len(tx)

        return err
    }, func() {
        res = 0
    })
    if err != nil {
        return 0, err
    }

    return res, nil
}

// Push adds new T items to the tail of the queue.
//
// NOTE: This is part of the Queue interface.
func (d *DiskQueueDB[T]) Push(items ...T) error {
    return d.db.Update(func(tx walletdb.ReadWriteTx) error {
        for _, item := range items {
            err := d.addItem(tx, queueMainBkt, item)
            if err != nil {
                return err
            }
        }

        return nil
    }, func() {})
}

// PopUpTo attempts to pop up to n items from the queue. If the queue is empty,
// then ErrEmptyQueue is returned.
//
// NOTE: This is part of the Queue interface.
func (d *DiskQueueDB[T]) PopUpTo(n int) ([]T, error) {
    var items []T

    err := d.db.Update(func(tx walletdb.ReadWriteTx) error {
        // Get the number of items in the queue.
        l, err := d.len(tx)
        if err != nil {
            return err
        }

        // If there are no items, then we are done.
        if l == 0 {
            return ErrEmptyQueue
        }

        // If the number of items in the queue is less than the maximum
        // specified by the caller, then set the maximum to the number
        // of items that there actually are.
        num := n
        if l < uint64(n) {
            num = int(l)
        }

        // Pop the specified number of items off the queue.
        items = make([]T, 0, num)
        for i := 0; i < num; i++ {
            item, err := d.pop(tx)
            if err != nil {
                return err
            }

            items = append(items, item)
        }

        return err
    }, func() {
        items = nil
    })
    if err != nil {
        return nil, err
    }

    return items, nil
}

// PushHead pushes new T items to the head of the queue. For this
// implementation of the Queue interface, this requires popping all items
// currently in the head queue and re-adding them after first adding the given
// list of items. Care should thus be taken to never have an unbounded number
// of items in the head queue.
//
// NOTE: This is part of the Queue interface.
func (d *DiskQueueDB[T]) PushHead(items ...T) error {
    return d.db.Update(func(tx walletdb.ReadWriteTx) error {
        // Determine how many items are still in the head queue.
        numHead, err := d.numItems(tx, queueHeadBkt)
        if err != nil {
            return err
        }

        // Create a new in-memory list that will contain all the new
        // items along with the items currently in the queue.
        itemList := make([]T, 0, int(numHead)+len(items))

        // Insert all the given items into the list first since these
        // should be at the head of the queue.
        itemList = append(itemList, items...)

        // Now, read out all the items that are currently in the
        // persisted head queue and add them to the back of the list
        // of items to be added.
        for {
            t, err := d.nextItem(tx, queueHeadBkt)
            if errors.Is(err, ErrEmptyQueue) {
                break
            } else if err != nil {
                return err
            }

            itemList = append(itemList, t)
        }

        // Now that the head queue is empty, the items can be pushed to
        // the queue.
        for _, item := range itemList {
            err := d.addItem(tx, queueHeadBkt, item)
            if err != nil {
                return err
            }
        }

        return nil
    }, func() {})
}

// pop gets the next T item from the head of the queue. If no more items are in
// the queue then ErrEmptyQueue is returned.
func (d *DiskQueueDB[T]) pop(tx walletdb.ReadWriteTx) (T, error) {
    // First, check if there are items left in the head queue.
    item, err := d.nextItem(tx, queueHeadBkt)

    // No error means that an item was found in the head queue.
    if err == nil {
        return item, nil
    }

    // Else, if the error is not ErrEmptyQueue, then return the error.
    if !errors.Is(err, ErrEmptyQueue) {
        return item, err
    }

    // Otherwise, the head queue is empty, so we now check if there are
    // items in the main queue.
    return d.nextItem(tx, queueMainBkt)
}

// addItem adds the given item to the back of the given queue.
func (d *DiskQueueDB[T]) addItem(tx kvdb.RwTx, queueName []byte, item T) error {
    var (
        namespacedBkt = tx.ReadWriteBucket(d.topLevelBkt)
        err           error
    )
    if namespacedBkt == nil {
        namespacedBkt, err = tx.CreateTopLevelBucket(d.topLevelBkt)
        if err != nil {
            return err
        }
    }

    mainTasksBucket, err := namespacedBkt.CreateBucketIfNotExists(
        cTaskQueue,
    )
    if err != nil {
        return err
    }

    bucket, err := mainTasksBucket.CreateBucketIfNotExists(queueName)
    if err != nil {
        return err
    }

    // Find the index to use for placing this new item at the back of the
    // queue.
    var nextIndex uint64
    nextIndexB := bucket.Get(nextIndexKey)
    if nextIndexB != nil {
        nextIndex, err = readBigSize(nextIndexB)
        if err != nil {
            return err
        }
    } else {
        nextIndexB, err = writeBigSize(0)
        if err != nil {
            return err
        }
    }

    tasksBucket, err := bucket.CreateBucketIfNotExists(itemsBkt)
    if err != nil {
        return err
    }

    var buff bytes.Buffer
    err = item.Encode(&buff)
    if err != nil {
        return err
    }

    // Put the new task in the assigned index.
    err = tasksBucket.Put(nextIndexB, buff.Bytes())
    if err != nil {
        return err
    }

    // Increment the next-index counter.
    nextIndex++
    nextIndexB, err = writeBigSize(nextIndex)
    if err != nil {
        return err
    }

    return bucket.Put(nextIndexKey, nextIndexB)
}

// nextItem pops an item off the queue identified by the given namespace. If
// there are no items in the queue then ErrEmptyQueue is returned.
func (d *DiskQueueDB[T]) nextItem(tx kvdb.RwTx, queueName []byte) (T, error) {
    task := d.constructor()

    namespacedBkt := tx.ReadWriteBucket(d.topLevelBkt)
    if namespacedBkt == nil {
        return task, ErrEmptyQueue
    }

    mainTasksBucket := namespacedBkt.NestedReadWriteBucket(cTaskQueue)
    if mainTasksBucket == nil {
        return task, ErrEmptyQueue
    }

    bucket, err := mainTasksBucket.CreateBucketIfNotExists(queueName)
    if err != nil {
        return task, err
    }

    // Get the index of the tail of the queue.
    var nextIndex uint64
    nextIndexB := bucket.Get(nextIndexKey)
    if nextIndexB != nil {
        nextIndex, err = readBigSize(nextIndexB)
        if err != nil {
            return task, err
        }
    }

    // Get the index of the head of the queue.
    var oldestIndex uint64
    oldestIndexB := bucket.Get(oldestIndexKey)
    if oldestIndexB != nil {
        oldestIndex, err = readBigSize(oldestIndexB)
        if err != nil {
            return task, err
        }
    } else {
        oldestIndexB, err = writeBigSize(0)
        if err != nil {
            return task, err
        }
    }

    // If the head and tail are equal, then there are no items in the queue.
    if oldestIndex == nextIndex {
        // Take this opportunity to reset both indexes to zero.
        zeroIndexB, err := writeBigSize(0)
        if err != nil {
            return task, err
        }

        err = bucket.Put(oldestIndexKey, zeroIndexB)
        if err != nil {
            return task, err
        }

        err = bucket.Put(nextIndexKey, zeroIndexB)
        if err != nil {
            return task, err
        }

        return task, ErrEmptyQueue
    }

    // Otherwise, pop the item at the oldest index.
    tasksBucket := bucket.NestedReadWriteBucket(itemsBkt)
    if tasksBucket == nil {
        return task, fmt.Errorf("client-tasks bucket not found")
    }

    item := tasksBucket.Get(oldestIndexB)
    if item == nil {
        return task, fmt.Errorf("no task found under index")
    }

    err = tasksBucket.Delete(oldestIndexB)
    if err != nil {
        return task, err
    }

    // Increment the oldestIndex value so that it now points to the new
    // oldest item.
    oldestIndex++
    oldestIndexB, err = writeBigSize(oldestIndex)
    if err != nil {
        return task, err
    }

    err = bucket.Put(oldestIndexKey, oldestIndexB)
    if err != nil {
        return task, err
    }

    if err = task.Decode(bytes.NewBuffer(item)); err != nil {
        return task, err
    }

    return task, nil
}

// len returns the number of items in the queue. This will be the sum of the
// number of items in the main queue and the number in the head queue.
func (d *DiskQueueDB[T]) len(tx kvdb.RTx) (uint64, error) {
    numMain, err := d.numItems(tx, queueMainBkt)
    if err != nil {
        return 0, err
    }

    numHead, err := d.numItems(tx, queueHeadBkt)
    if err != nil {
        return 0, err
    }

    return numMain + numHead, nil
}

// numItems returns the number of items in the given queue.
func (d *DiskQueueDB[T]) numItems(tx kvdb.RTx, queueName []byte) (uint64,
    error) {

    // Get the queue bucket at the correct namespace.
    namespacedBkt := tx.ReadBucket(d.topLevelBkt)
    if namespacedBkt == nil {
        return 0, nil
    }

    mainTasksBucket := namespacedBkt.NestedReadBucket(cTaskQueue)
    if mainTasksBucket == nil {
        return 0, nil
    }

    bucket := mainTasksBucket.NestedReadBucket(queueName)
    if bucket == nil {
        return 0, nil
    }

    var (
        oldestIndex uint64
        nextIndex   uint64
        err         error
    )

    // Get the next index key.
    nextIndexB := bucket.Get(nextIndexKey)
    if nextIndexB != nil {
        nextIndex, err = readBigSize(nextIndexB)
        if err != nil {
            return 0, err
        }
    }

    // Get the oldest index.
    oldestIndexB := bucket.Get(oldestIndexKey)
    if oldestIndexB != nil {
        oldestIndex, err = readBigSize(oldestIndexB)
        if err != nil {
            return 0, err
        }
    }

    return nextIndex - oldestIndex, nil
}
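To summarize the bookkeeping above (an illustrative model only, not part of the diff): each nested queue is just an oldest-index/next-index counter pair plus an items sub-bucket, its length is always nextIndex - oldestIndex, and popping always drains the head queue before the main queue. A minimal in-memory sketch of that discipline:

    // Illustrative in-memory model of one nested queue; the real implementation
    // stores these counters as BigSize-encoded values under oldestIndexKey and
    // nextIndexKey.
    type memQueue struct {
        oldest, next uint64
        items        map[uint64][]byte
    }

    func newMemQueue() *memQueue {
        return &memQueue{items: make(map[uint64][]byte)}
    }

    func (q *memQueue) push(item []byte) {
        q.items[q.next] = item
        q.next++ // mirrors the nextIndexKey increment in addItem
    }

    func (q *memQueue) pop() ([]byte, bool) {
        if q.oldest == q.next {
            // Empty: the on-disk version resets both counters to zero here.
            q.oldest, q.next = 0, 0
            return nil, false
        }

        item := q.items[q.oldest]
        delete(q.items, q.oldest)
        q.oldest++ // mirrors the oldestIndexKey increment in nextItem

        return item, true
    }

    // Queue length == next - oldest, and DiskQueueDB.pop drains the head queue
    // before falling back to the main queue.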
144
watchtower/wtdb/queue_test.go
Normal file
@ -0,0 +1,144 @@
package wtdb_test

import (
    "testing"

    "github.com/lightningnetwork/lnd/kvdb"
    "github.com/lightningnetwork/lnd/watchtower/wtclient"
    "github.com/lightningnetwork/lnd/watchtower/wtdb"
    "github.com/lightningnetwork/lnd/watchtower/wtmock"
    "github.com/stretchr/testify/require"
)

// TestDiskQueue ensures that the ClientDB's disk queue methods behave as is
// expected of a queue.
func TestDiskQueue(t *testing.T) {
    t.Parallel()

    dbs := []struct {
        name string
        init clientDBInit
    }{
        {
            name: "bbolt",
            init: func(t *testing.T) wtclient.DB {
                dbCfg := &kvdb.BoltConfig{
                    DBTimeout: kvdb.DefaultDBTimeout,
                }

                // Construct the ClientDB.
                bdb, err := wtdb.NewBoltBackendCreator(
                    true, t.TempDir(), "wtclient.db",
                )(dbCfg)
                require.NoError(t, err)

                db, err := wtdb.OpenClientDB(bdb)
                require.NoError(t, err)

                t.Cleanup(func() {
                    err = db.Close()
                    require.NoError(t, err)
                })

                return db
            },
        },
        {
            name: "mock",
            init: func(t *testing.T) wtclient.DB {
                return wtmock.NewClientDB()
            },
        },
    }

    for _, database := range dbs {
        db := database
        t.Run(db.name, func(t *testing.T) {
            t.Parallel()

            testQueue(t, db.init(t))
        })
    }
}

func testQueue(t *testing.T, db wtclient.DB) {
    namespace := []byte("test-namespace")
    queue := db.GetDBQueue(namespace)

    addTasksToTail := func(tasks ...*wtdb.BackupID) {
        err := queue.Push(tasks...)
        require.NoError(t, err)
    }

    addTasksToHead := func(tasks ...*wtdb.BackupID) {
        err := queue.PushHead(tasks...)
        require.NoError(t, err)
    }

    assertNumTasks := func(expNum int) {
        num, err := queue.Len()
        require.NoError(t, err)
        require.EqualValues(t, expNum, num)
    }

    popAndAssert := func(expTasks ...*wtdb.BackupID) {
        tasks, err := queue.PopUpTo(len(expTasks))
        require.NoError(t, err)
        require.EqualValues(t, expTasks, tasks)
    }

    // Create a few tasks that we use throughout the test.
    task1 := &wtdb.BackupID{CommitHeight: 1}
    task2 := &wtdb.BackupID{CommitHeight: 2}
    task3 := &wtdb.BackupID{CommitHeight: 3}
    task4 := &wtdb.BackupID{CommitHeight: 4}
    task5 := &wtdb.BackupID{CommitHeight: 5}
    task6 := &wtdb.BackupID{CommitHeight: 6}

    // The queue should initially have no items.
    assertNumTasks(0)

    // Now add a few items to the tail of the queue.
    addTasksToTail(task1, task2)

    // Check that the number of tasks is now two.
    assertNumTasks(2)

    // Pop a task, check that it is task 1 and assert that the number of
    // items left is now 1.
    popAndAssert(task1)
    assertNumTasks(1)

    // Pop a task, check that it is task 2 and assert that the number of
    // items left is now 0.
    popAndAssert(task2)
    assertNumTasks(0)

    // Once again add a few tasks.
    addTasksToTail(task3, task4)

    // Now push some tasks to the head of the queue.
    addTasksToHead(task6, task5)

    // Ensure that both the disk queue lengths are added together when
    // querying the length of the queue.
    assertNumTasks(4)

    // Ensure that the order that the tasks are popped in is correct.
    popAndAssert(task6, task5, task3, task4)

    // We also want to test that the head queue works as expected. To do
    // this, we first push 4, 5 and 6 to the queue.
    addTasksToTail(task4, task5, task6)

    // Now we push 1, 2 and 3 to the head.
    addTasksToHead(task1, task2, task3)

    // Now, only pop item 1 from the queue and then re-add it to the head.
    popAndAssert(task1)
    addTasksToHead(task1)

    // This should not have changed the order of the tasks; they should
    // still appear in the correct order.
    popAndAssert(task1, task2, task3, task4, task5, task6)
}
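The PopUpTo/PushHead pairing exercised at the end of the test is the re-queue pattern the overflow queue relies on: pop a batch, and if some tasks cannot be handed off yet, push them back to the head so that FIFO order is preserved. A minimal sketch of that pattern (illustrative only, using the in-memory mock defined later in this diff):

    // Illustrative only; assumes the wtdb and wtmock packages from this diff.
    q := wtmock.NewQueueDB[*wtdb.BackupID]()

    // Two tasks waiting in the queue.
    _ = q.Push(
        &wtdb.BackupID{CommitHeight: 1},
        &wtdb.BackupID{CommitHeight: 2},
    )

    // Pop a batch of up to 10 tasks; here we get both of them back.
    batch, _ := q.PopUpTo(10)

    // Suppose only the first task could be handed off. Re-queue the remainder
    // at the head so that the next PopUpTo still sees it first.
    _ = q.PushHead(batch[1:]...)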
@ -49,6 +49,8 @@ type ClientDB struct {
    nextIndex     uint32
    indexes       map[keyIndexKey]uint32
    legacyIndexes map[wtdb.TowerID]uint32

    queues map[string]wtdb.Queue[*wtdb.BackupID]
}

// NewClientDB initializes a new mock ClientDB.
@ -68,6 +70,7 @@ func NewClientDB() *ClientDB {
        indexes:          make(map[keyIndexKey]uint32),
        legacyIndexes:    make(map[wtdb.TowerID]uint32),
        closableSessions: make(map[wtdb.SessionID]uint32),
        queues:           make(map[string]wtdb.Queue[*wtdb.BackupID]),
    }
}

@ -568,6 +571,21 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
    return wtdb.ErrCommittedUpdateNotFound
}

// GetDBQueue returns a BackupID Queue instance under the given namespace.
func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] {
    m.mu.Lock()
    defer m.mu.Unlock()

    if q, ok := m.queues[string(namespace)]; ok {
        return q
    }

    q := NewQueueDB[*wtdb.BackupID]()
    m.queues[string(namespace)] = q

    return q
}

// ListClosableSessions fetches and returns the IDs for all sessions marked as
// closable.
func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {
92
watchtower/wtmock/queue.go
Normal file
@ -0,0 +1,92 @@
package wtmock

import (
    "container/list"
    "fmt"
    "sync"

    "github.com/lightningnetwork/lnd/watchtower/wtdb"
)

// DiskQueueDB is an in-memory implementation of the wtdb.Queue interface.
type DiskQueueDB[T any] struct {
    disk *list.List
    mu   sync.RWMutex
}

// NewQueueDB constructs a new DiskQueueDB.
func NewQueueDB[T any]() wtdb.Queue[T] {
    return &DiskQueueDB[T]{
        disk: list.New(),
    }
}

// Len returns the number of tasks in the queue.
//
// NOTE: This is part of the wtdb.Queue interface.
func (d *DiskQueueDB[T]) Len() (uint64, error) {
    d.mu.Lock()
    defer d.mu.Unlock()

    return uint64(d.disk.Len()), nil
}

// Push adds new T items to the tail of the queue.
//
// NOTE: This is part of the wtdb.Queue interface.
func (d *DiskQueueDB[T]) Push(items ...T) error {
    d.mu.Lock()
    defer d.mu.Unlock()

    for _, item := range items {
        d.disk.PushBack(item)
    }

    return nil
}

// PopUpTo attempts to pop up to n items from the queue. If the queue is empty,
// then ErrEmptyQueue is returned.
//
// NOTE: This is part of the wtdb.Queue interface.
func (d *DiskQueueDB[T]) PopUpTo(n int) ([]T, error) {
    d.mu.Lock()
    defer d.mu.Unlock()

    if d.disk.Len() == 0 {
        return nil, wtdb.ErrEmptyQueue
    }

    num := n
    if d.disk.Len() < n {
        num = d.disk.Len()
    }

    tasks := make([]T, 0, num)
    for i := 0; i < num; i++ {
        e := d.disk.Front()
        task, ok := d.disk.Remove(e).(T)
        if !ok {
            return nil, fmt.Errorf("queue item not of type %T",
                task)
        }

        tasks = append(tasks, task)
    }

    return tasks, nil
}

// PushHead pushes new T items to the head of the queue.
//
// NOTE: This is part of the wtdb.Queue interface.
func (d *DiskQueueDB[T]) PushHead(items ...T) error {
    d.mu.Lock()
    defer d.mu.Unlock()

    // Iterate over the items in reverse so that the first given item ends
    // up at the very front of the queue in its original order.
    for i := len(items) - 1; i >= 0; i-- {
        d.disk.PushFront(items[i])
    }

    return nil
}
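As a final note on how these pieces are meant to compose (a hedged sketch only; the actual wtclient wiring is not part of the hunks shown here): the point of the disk queue is to act as an overflow for a small in-memory buffer of back-up tasks, spilling to disk once the buffer is full and refilling from disk as it drains. The names below (overflowQueue, memBuf) are hypothetical; only wtdb.Queue and wtdb.BackupID come from this diff, and the wtdb import is assumed.

    // Hypothetical illustration of the overflow pattern; not the actual
    // wtclient implementation.
    type overflowQueue struct {
        memBuf chan *wtdb.BackupID        // bounded in-memory buffer
        disk   wtdb.Queue[*wtdb.BackupID] // e.g. obtained via GetDBQueue
    }

    // add buffers a task in memory if there is room, and otherwise overflows
    // it to the disk queue.
    func (o *overflowQueue) add(task *wtdb.BackupID) error {
        select {
        case o.memBuf <- task:
            return nil
        default:
            return o.disk.Push(task)
        }
    }

    // next returns the next task, preferring the in-memory buffer and falling
    // back to the disk queue, which returns wtdb.ErrEmptyQueue when drained.
    func (o *overflowQueue) next() (*wtdb.BackupID, error) {
        select {
        case task := <-o.memBuf:
            return task, nil
        default:
            tasks, err := o.disk.PopUpTo(1)
            if err != nil {
                return nil, err
            }

            return tasks[0], nil
        }
    }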