Merge pull request #7623 from ellemouton/queueBackupIDOnly

wtclient: queue backup id only
Oliver Gugger 2023-04-24 15:36:09 +02:00 committed by GitHub
commit 4355ce62d2
12 changed files with 460 additions and 258 deletions
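Taken together, the diff below changes what the watchtower client's task pipeline carries: callers now hand BackupState only a channel ID and a revoked state number, the pipeline queues bare wtdb.BackupIDs, and the full lnwallet.BreachRetribution is rebuilt on demand through a new BuildBreachRetribution callback when a task is bound to a session. A minimal, self-contained sketch of that data flow, using stand-in types rather than the real lnd packages:

```go
// A reader's sketch, not code from this PR: stand-in types showing the shape
// of the new flow. The real lnd types are lnwire.ChannelID, wtdb.BackupID,
// lnwallet.BreachRetribution and wtclient.BreachRetributionBuilder.
package main

import "fmt"

// backupID stands in for wtdb.BackupID: the only data queued per backup now.
type backupID struct {
	chanID       string
	commitHeight uint64
}

// retributionBuilder stands in for wtclient.BreachRetributionBuilder: the
// heavyweight breach data is constructed only when a session needs it.
type retributionBuilder func(chanID string, height uint64) (string, error)

type client struct {
	queue []backupID         // the task pipeline, IDs only
	build retributionBuilder // injected via the client config
}

// backupState mirrors the new TowerClient.BackupState signature: a channel ID
// and a revoked state number are all that cross this boundary.
func (c *client) backupState(chanID string, stateNum uint64) {
	c.queue = append(c.queue, backupID{chanID: chanID, commitHeight: stateNum})
}

// bind mirrors backupTask.bindSession: the retribution is built lazily here,
// right before the update is handed to a session queue.
func (c *client) bind(id backupID) (string, error) {
	return c.build(id.chanID, id.commitHeight)
}

func main() {
	c := &client{
		build: func(chanID string, height uint64) (string, error) {
			return fmt.Sprintf("retribution(%s@%d)", chanID, height), nil
		},
	}

	c.backupState("chan-0", 41)

	for _, id := range c.queue {
		br, _ := c.bind(id)
		fmt.Println("bound:", br)
	}
}
```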


@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/go-errors/errors"
mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12"
@ -654,20 +655,94 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (
// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied. Optionally an existing db tx
// can be supplied.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*OpenChannel, error) {
var (
targetChan *OpenChannel
targetChanPoint bytes.Buffer
)
var targetChanPoint bytes.Buffer
if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
return nil, err
}
targetChanPointBytes := targetChanPoint.Bytes()
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
return targetChanPointBytes, &chanPoint, nil
}
return c.channelScanner(tx, selector)
}
// FetchChannelByID attempts to locate a channel specified by the passed channel
// ID. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) (
*OpenChannel, error) {
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
var (
targetChanPointBytes []byte
targetChanPoint *wire.OutPoint
// errChanFound is used to signal that the channel has
// been found so that iteration through the DB buckets
// can stop.
errChanFound = errors.New("channel found")
)
err := chainBkt.ForEach(func(k, _ []byte) error {
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(k), &outPoint)
if err != nil {
return err
}
chanID := lnwire.NewChanIDFromOutPoint(&outPoint)
if chanID != id {
return nil
}
targetChanPoint = &outPoint
targetChanPointBytes = k
return errChanFound
})
if err != nil && !errors.Is(err, errChanFound) {
return nil, nil, err
}
if targetChanPoint == nil {
return nil, nil, ErrChannelNotFound
}
return targetChanPointBytes, targetChanPoint, nil
}
return c.channelScanner(tx, selector)
}
// channelSelector describes a function that takes a chain-hash bucket from
// within the open-channel DB and returns the wanted channel point bytes, and
// channel point. It must return the ErrChannelNotFound error if the wanted
// channel is not in the given bucket.
type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error)
// channelScanner will traverse the DB to each chain-hash bucket of each node
// pub-key bucket in the open-channel-bucket. The chanSelector will then be used
// to fetch the wanted channel outpoint from the chain bucket.
func (c *ChannelStateDB) channelScanner(tx kvdb.RTx,
chanSelect channelSelector) (*OpenChannel, error) {
var (
targetChan *OpenChannel
// errChanFound is used to signal that the channel has been
// found so that iteration through the DB buckets can stop.
errChanFound = errors.New("channel found")
)
// chanScan will traverse the following bucket structure:
// * nodePub => chainHash => chanPoint
//
@ -685,8 +760,8 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
}
// Within the node channel bucket, are the set of node pubkeys
// we have channels with, we don't know the entire set, so
// we'll check them all.
// we have channels with, we don't know the entire set, so we'll
// check them all.
return openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey,
// and also that it leads directly to a bucket.
@ -694,7 +769,9 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
return nil
}
nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
nodeChanBucket := openChanBucket.NestedReadBucket(
nodePub,
)
if nodeChanBucket == nil {
return nil
}
@ -715,20 +792,30 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash[:])
"bucket for chain=%x",
chainHash)
}
// Finally we reach the leaf bucket that stores
// Finally, we reach the leaf bucket that stores
// all the chanPoints for this node.
targetChanBytes, chanPoint, err := chanSelect(
chainBucket,
)
if errors.Is(err, ErrChannelNotFound) {
return nil
} else if err != nil {
return err
}
chanBucket := chainBucket.NestedReadBucket(
targetChanPoint.Bytes(),
targetChanBytes,
)
if chanBucket == nil {
return nil
}
channel, err := fetchOpenChannel(
chanBucket, &chanPoint,
chanBucket, chanPoint,
)
if err != nil {
return err
@ -737,7 +824,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
targetChan = channel
targetChan.Db = c
return nil
return errChanFound
})
})
}
@ -748,7 +835,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
} else {
err = chanScan(tx)
}
if err != nil {
if err != nil && !errors.Is(err, errChanFound) {
return nil, err
}
@ -757,7 +844,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
}
// If we can't find the channel, then we return with an error, as we
// have nothing to backup.
// have nothing to back up.
return nil, ErrChannelNotFound
}
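For completeness, a hedged usage sketch of the lookup added above (not part of this diff): FetchChannelByID accepts a nil kvdb.RTx, in which case it opens its own read transaction, and returns ErrChannelNotFound when no bucket matches the ID. The helper and package names below are hypothetical; FetchChannelByID and NewChanIDFromOutPoint are the APIs shown in the diff.

```go
package example

import (
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/lnwire"
)

// fetchByID is a hypothetical caller of the new API: derive the channel ID
// from the funding outpoint, then look the channel up without knowing which
// node/chain bucket it lives under.
func fetchByID(cdb *channeldb.ChannelStateDB,
	fundingOutpoint wire.OutPoint) (*channeldb.OpenChannel, error) {

	chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint)

	// A nil tx lets FetchChannelByID start its own read transaction.
	return cdb.FetchChannelByID(nil, chanID)
}
```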


@ -12,7 +12,6 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) {
// The decoded channel state should be identical to what we stored
// above.
if !reflect.DeepEqual(channelState, dbChannel) {
t.Fatalf("channel state doesn't match:: %v vs %v",
spew.Sdump(channelState), spew.Sdump(dbChannel))
}
require.Equal(t, channelState, dbChannel)
// Next, attempt to fetch the channel by its channel ID.
chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint)
dbChannel, err = cdb.FetchChannelByID(nil, chanID)
require.NoError(t, err, "unable to fetch channel")
// The decoded channel state should be identical to what we stored
// above.
require.Equal(t, channelState, dbChannel)
// If we attempt to query for a non-existent channel, then we should
// get an error.
@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) {
channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load()
_, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint)
if err == nil {
t.Fatalf("expected query to fail")
}
require.ErrorIs(t, err, ErrChannelNotFound)
chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint)
_, err = cdb.FetchChannelByID(nil, chanID2)
require.ErrorIs(t, err, ErrChannelNotFound)
}
func genRandomChannelShell() (*ChannelShell, error) {


@ -6,7 +6,15 @@
implementation](https://github.com/lightningnetwork/lnd/pull/7377) logic in
different update types.
## Watchtowers
* Let the task pipeline [only carry
wtdb.BackupIDs](https://github.com/lightningnetwork/lnd/pull/7623) instead of
the entire retribution struct. This reduces the amount of data that needs to
be held in memory.
# Contributors (Alphabetical Order)
* Elle Mouton
* Jordi Montes


@ -7,7 +7,6 @@ import (
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
@ -255,11 +254,8 @@ type TowerClient interface {
// state. If the method returns nil, the backup is guaranteed to be
// successful unless the tower is unavailable and client is force quit,
// or the justice transaction would create dust outputs when trying to
// abide by the negotiated policy. If the channel we're trying to back
// up doesn't have a tweak for the remote party's output, then
// isTweakless should be true.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
channeldb.ChannelType) error
// abide by the negotiated policy.
BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
}
// InterceptableHtlcForwarder is the interface to set the interceptor


@ -2022,11 +2022,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// We've received a revocation from the remote chain, if valid,
// this moves the remote chain forward, and expands our
// revocation window.
//
// Before advancing our remote chain, we will record the
// current commit tx, which is used by the TowerClient to
// create backups.
oldCommitTx := l.channel.State().RemoteCommitment.CommitTx
// We now process the message and advance our remote commit
// chain.
@ -2063,24 +2058,15 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// create a backup for the current state.
if l.cfg.TowerClient != nil {
state := l.channel.State()
breachInfo, err := lnwallet.NewBreachRetribution(
state, state.RemoteCommitment.CommitHeight-1, 0,
// OldCommitTx is the breaching tx at height-1.
oldCommitTx,
)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"failed to load breach info: %v", err)
return
}
chanID := l.ChanID()
err = l.cfg.TowerClient.BackupState(
&chanID, breachInfo, state.ChanType,
&chanID, state.RemoteCommitment.CommitHeight-1,
)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"unable to queue breach backup: %v", err)
"unable to queue breach backup: %v",
err)
return
}
}


@ -1523,12 +1523,37 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
)
}
// buildBreachRetribution is a call-back that can be used to
// query the BreachRetribution info and channel type given a
// channel ID and commitment height.
buildBreachRetribution := func(chanID lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error) {
channel, err := s.chanStateDB.FetchChannelByID(
nil, chanID,
)
if err != nil {
return nil, 0, err
}
br, err := lnwallet.NewBreachRetribution(
channel, commitHeight, 0, nil,
)
if err != nil {
return nil, 0, err
}
return br, channel.ChanType, nil
}
fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID
s.towerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
FetchClosedChannel: fetchClosedChannel,
BuildBreachRetribution: buildBreachRetribution,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {
@ -1558,9 +1583,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
blob.Type(blob.FlagAnchorChannel)
s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
FetchClosedChannel: fetchClosedChannel,
BuildBreachRetribution: buildBreachRetribution,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {


@ -10,7 +10,6 @@ import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
@ -40,7 +39,6 @@ import (
type backupTask struct {
id wtdb.BackupID
breachInfo *lnwallet.BreachRetribution
chanType channeldb.ChannelType
// state-dependent variables
@ -55,11 +53,79 @@ type backupTask struct {
outputs []*wire.TxOut
}
// newBackupTask initializes a new backupTask and populates all state-dependent
// variables.
func newBackupTask(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution,
sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask {
// newBackupTask initializes a new backupTask.
func newBackupTask(id wtdb.BackupID, sweepPkScript []byte) *backupTask {
return &backupTask{
id: id,
sweepPkScript: sweepPkScript,
}
}
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69 in a later stage.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) {
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
// bindSession first populates all state-dependent variables of the task. Then
// it determines if the backupTask is compatible with the passed SessionInfo's
// policy. If no error is returned, the task has been bound to the session and
// can be queued to upload to the tower. Otherwise, the bind failed and should
// be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody,
newBreachRetribution BreachRetributionBuilder) error {
breachInfo, chanType, err := newBreachRetribution(
t.id.ChanID, t.id.CommitHeight,
)
if err != nil {
return err
}
// Parse the non-dust outputs from the breach transaction,
// simultaneously computing the total amount contained in the inputs
@ -123,76 +189,11 @@ func newBackupTask(chanID *lnwire.ChannelID,
totalAmt += breachInfo.LocalOutputSignDesc.Output.Value
}
return &backupTask{
id: wtdb.BackupID{
ChanID: *chanID,
CommitHeight: breachInfo.RevokedStateNum,
},
breachInfo: breachInfo,
chanType: chanType,
toLocalInput: toLocalInput,
toRemoteInput: toRemoteInput,
totalAmt: btcutil.Amount(totalAmt),
sweepPkScript: sweepPkScript,
}
}
t.breachInfo = breachInfo
t.toLocalInput = toLocalInput
t.toRemoteInput = toRemoteInput
t.totalAmt = btcutil.Amount(totalAmt)
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) { //nolint: whitespace
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
// bindSession determines if the backupTask is compatible with the passed
// SessionInfo's policy. If no error is returned, the task has been bound to the
// session and can be queued to upload to the tower. Otherwise, the bind failed
// and should be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// First we'll begin by deriving a weight estimate for the justice
// transaction. The final weight can be different depending on whether
// the watchtower is taking a reward.
@ -208,7 +209,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// original weight estimate. For anchor channels we'll go ahead
// an use the correct penalty witness when signing our justice
// transactions.
if t.chanType.HasAnchors() {
if chanType.HasAnchors() {
weightEstimate.AddWitnessInput(
input.ToLocalPenaltyWitnessSize,
)
@ -222,8 +223,10 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// Legacy channels (both tweaked and non-tweaked) spend from
// P2WKH output. Anchor channels spend a to-remote confirmed
// P2WSH output.
if t.chanType.HasAnchors() {
weightEstimate.AddWitnessInput(input.ToRemoteConfirmedWitnessSize)
if chanType.HasAnchors() {
weightEstimate.AddWitnessInput(
input.ToRemoteConfirmedWitnessSize,
)
} else {
weightEstimate.AddWitnessInput(input.P2WKHWitnessSize)
}
@ -231,7 +234,8 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// All justice transactions will either use segwit v0 (p2wkh + p2wsh)
// or segwit v1 (p2tr).
if err := addScriptWeight(&weightEstimate, t.sweepPkScript); err != nil {
err = addScriptWeight(&weightEstimate, t.sweepPkScript)
if err != nil {
return err
}
@ -244,9 +248,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
}
}
if t.chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
if chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
log.Criticalf("Invalid task (has_anchors=%t) for session "+
"(has_anchors=%t)", t.chanType.HasAnchors(),
"(has_anchors=%t)", chanType.HasAnchors(),
session.Policy.IsAnchorChannel())
}
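A small in-package sketch of the resulting two-step lifecycle (hypothetical helper, assumed to sit inside package wtclient; newBackupTask, bindSession and BreachRetributionBuilder are the identifiers from this diff): tasks are now cheap to construct, and every state-dependent field is filled in by bindSession through the supplied builder.

```go
package wtclient

import "github.com/lightningnetwork/lnd/watchtower/wtdb"

// bindLater is a hypothetical helper, not code from this PR, summarizing the
// new lifecycle: newBackupTask only records the ID and sweep script, while
// bindSession fetches the BreachRetribution via the builder and only then
// populates breachInfo, the inputs and the total amount.
func bindLater(id wtdb.BackupID, sweepPkScript []byte,
	session *wtdb.ClientSessionBody,
	build BreachRetributionBuilder) (*backupTask, error) {

	task := newBackupTask(id, sweepPkScript)

	if err := task.bindSession(session, build); err != nil {
		// The task can be retried against a different session; a
		// failed bind leaves its state-dependent fields unset.
		return nil, err
	}

	return task, nil
}
```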


@ -483,19 +483,19 @@ func TestBackupTask(t *testing.T) {
func testBackupTask(t *testing.T, test backupTaskTest) {
// Create a new backupTask from the channel id and breach info.
task := newBackupTask(
&test.chanID, test.breachInfo, test.expSweepScript,
test.chanType,
)
id := wtdb.BackupID{
ChanID: test.chanID,
CommitHeight: test.breachInfo.RevokedStateNum,
}
task := newBackupTask(id, test.expSweepScript)
// Assert that all parameters set during initialization are properly
// populated.
require.Equal(t, test.chanID, task.id.ChanID)
require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
require.Equal(t, test.expTotalAmt, task.totalAmt)
require.Equal(t, test.breachInfo, task.breachInfo)
require.Equal(t, test.expToLocalInput, task.toLocalInput)
require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
// getBreachInfo is a helper closure that returns the breach retribution
// info and channel type for the given channel and commit height.
getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) (
*lnwallet.BreachRetribution, channeldb.ChannelType, error) {
return test.breachInfo, test.chanType, nil
}
// Reconstruct the expected input.Inputs that will be returned by the
// task's inputs() method.
@ -515,9 +515,18 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Now, bind the session to the task. If successful, this locks in the
// session's negotiated parameters and allows the backup task to derive
// the final free variables in the justice transaction.
err := task.bindSession(test.session)
err := task.bindSession(test.session, getBreachInfo)
require.ErrorIs(t, err, test.bindErr)
// Assert that all parameters set after binding the backup task
// are properly populated.
require.Equal(t, test.chanID, task.id.ChanID)
require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
require.Equal(t, test.expTotalAmt, task.totalAmt)
require.Equal(t, test.breachInfo, task.breachInfo)
require.Equal(t, test.expToLocalInput, task.toLocalInput)
require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
// Exit early if the bind was supposed to fail. But first, we check that
// all fields set during a bind are still unset. This ensures that a
// failed bind doesn't have side-effects if the task is retried with a


@ -34,8 +34,8 @@ const (
// a read before breaking out of a blocking read.
DefaultReadTimeout = 15 * time.Second
// DefaultWriteTimeout specifies the default duration we will wait during
// a write before breaking out of a blocking write.
// DefaultWriteTimeout specifies the default duration we will wait
// during a write before breaking out of a blocking write.
DefaultWriteTimeout = 15 * time.Second
// DefaultStatInterval specifies the default interval between logging
@ -136,11 +136,8 @@ type Client interface {
// state. If the method returns nil, the backup is guaranteed to be
// successful unless the client is force quit, or the justice
// transaction would create dust outputs when trying to abide by the
// negotiated policy. If the channel we're trying to back up doesn't
// have a tweak for the remote party's output, then isTweakless should
// be true.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
channeldb.ChannelType) error
// negotiated policy.
BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
// Start initializes the watchtower client, allowing it to process
// requests to back up revoked channel states.
@ -178,6 +175,11 @@ type Config struct {
// ChainNotifier can be used to subscribe to block notifications.
ChainNotifier chainntnfs.ChainNotifier
// BuildBreachRetribution is a function closure that allows the client
// to fetch the breach retribution info for a certain channel at a certain
// revoked commitment height.
BuildBreachRetribution BreachRetributionBuilder
// NewAddress generates a new on-chain sweep pkscript.
NewAddress func() ([]byte, error)
@ -240,6 +242,12 @@ type Config struct {
SessionCloseRange uint32
}
// BreachRetributionBuilder is a function that can be used to construct a
// BreachRetribution from a channel ID and a commitment height.
type BreachRetributionBuilder func(id lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error)
// newTowerMsg is an internal message we'll use within the TowerClient to signal
// that a new tower can be considered.
type newTowerMsg struct {
@ -293,7 +301,7 @@ type TowerClient struct {
activeSessions sessionQueueSet
sessionQueue *sessionQueue
prevTask *backupTask
prevTask *wtdb.BackupID
closableSessionQueue *sessionCloseMinHeap
@ -378,6 +386,9 @@ func New(config *Config) (*TowerClient, error) {
return
}
c.backupMu.Lock()
defer c.backupMu.Unlock()
// Take the highest commit height found in the session's acked
// updates.
height, ok := c.chanCommitHeights[chanID]
@ -558,8 +569,11 @@ func (c *TowerClient) Start() error {
// committed but unacked state updates. This ensures that these
// sessions will be able to flush the committed updates after a
// restart.
fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates
for _, session := range c.candidateSessions {
committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
committedUpdates, err := fetchCommittedUpdates(
&session.ID,
)
if err != nil {
returnErr = err
return
@ -788,40 +802,41 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
// - client is force quit,
// - justice transaction would create dust outputs when trying to abide by the
// negotiated policy, or
// - breached outputs contain too little value to sweep at the target sweep fee
// rate.
// - breached outputs contain too little value to sweep at the target sweep
// fee rate.
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution,
chanType channeldb.ChannelType) error {
stateNum uint64) error {
// Retrieve the cached sweep pkscript used for this channel.
// Make sure that this channel is registered with the tower client.
c.backupMu.Lock()
summary, ok := c.summaries[*chanID]
if !ok {
if _, ok := c.summaries[*chanID]; !ok {
c.backupMu.Unlock()
return ErrUnregisteredChannel
}
// Ignore backups that have already been presented to the client.
height, ok := c.chanCommitHeights[*chanID]
if ok && breachInfo.RevokedStateNum <= height {
if ok && stateNum <= height {
c.backupMu.Unlock()
c.log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d",
chanID, breachInfo.RevokedStateNum)
c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+
"height=%d", chanID, stateNum)
return nil
}
// This backup has a higher commit height than any known backup for this
// channel. We'll update our tip so that we won't accept it again if the
// link flaps.
c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum
c.chanCommitHeights[*chanID] = stateNum
c.backupMu.Unlock()
task := newBackupTask(
chanID, breachInfo, summary.SweepPkScript, chanType,
)
id := &wtdb.BackupID{
ChanID: *chanID,
CommitHeight: stateNum,
}
return c.pipeline.QueueBackupTask(task)
return c.pipeline.QueueBackupTask(id)
}
// nextSessionQueue attempts to fetch an active session from our set of
@ -1315,7 +1330,7 @@ func (c *TowerClient) backupDispatcher() {
return
}
c.log.Debugf("Processing %v", task.id)
c.log.Debugf("Processing %v", task)
c.stats.taskReceived()
c.processTask(task)
@ -1345,8 +1360,22 @@ func (c *TowerClient) backupDispatcher() {
// sessionQueue hasn't been exhausted before proceeding to the next task. Tasks
// that are rejected because the active sessionQueue is full will be cached as
// the prevTask, and should be reprocessed after obtaining a new sessionQueue.
func (c *TowerClient) processTask(task *backupTask) {
status, accepted := c.sessionQueue.AcceptTask(task)
func (c *TowerClient) processTask(task *wtdb.BackupID) {
c.backupMu.Lock()
summary, ok := c.summaries[task.ChanID]
if !ok {
c.backupMu.Unlock()
log.Infof("not processing task for unregistered channel: %s",
task.ChanID)
return
}
c.backupMu.Unlock()
backupTask := newBackupTask(*task, summary.SweepPkScript)
status, accepted := c.sessionQueue.AcceptTask(backupTask)
if accepted {
c.taskAccepted(task, status)
} else {
@ -1359,9 +1388,11 @@ func (c *TowerClient) processTask(task *backupTask) {
// prevTask is always removed as a result of this call. The client's
// sessionQueue will be removed if accepting the task left the sessionQueue in
// an exhausted state.
func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
c.log.Infof("Queued %v successfully for session %v",
task.id, c.sessionQueue.ID())
func (c *TowerClient) taskAccepted(task *wtdb.BackupID,
newStatus reserveStatus) {
c.log.Infof("Queued %v successfully for session %v", task,
c.sessionQueue.ID())
c.stats.taskAccepted()
@ -1394,7 +1425,9 @@ func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
// the sessionQueue to find a new session. If the sessionQueue was not
// exhausted, the client marks the task as ineligible, as this implies we
// couldn't construct a valid justice transaction given the session's policy.
func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
func (c *TowerClient) taskRejected(task *wtdb.BackupID,
curStatus reserveStatus) {
switch curStatus {
// The sessionQueue has available capacity but the task was rejected,
@ -1402,14 +1435,14 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
case reserveAvailable:
c.stats.taskIneligible()
c.log.Infof("Ignoring ineligible %v", task.id)
c.log.Infof("Ignoring ineligible %v", task)
err := c.cfg.DB.MarkBackupIneligible(
task.id.ChanID, task.id.CommitHeight,
task.ChanID, task.CommitHeight,
)
if err != nil {
c.log.Errorf("Unable to mark %v ineligible: %v",
task.id, err)
task, err)
// It is safe to not handle this error, even if we could
// not persist the result. At worst, this task may be
@ -1429,7 +1462,7 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
c.stats.sessionExhausted()
c.log.Debugf("Session %v exhausted, %v queued for next session",
c.sessionQueue.ID(), task.id)
c.sessionQueue.ID(), task)
// Cache the task that we pulled off, so that we can process it
// once a new session queue is available.
@ -1485,7 +1518,9 @@ func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) {
}
// sendMessage sends a watchtower wire message to the target peer.
func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error {
func (c *TowerClient) sendMessage(peer wtserver.Peer,
msg wtwire.Message) error {
// Encode the next wire message into the buffer.
// TODO(conner): use buffer pool
var b bytes.Buffer
@ -1521,16 +1556,17 @@ func (c *TowerClient) newSessionQueue(s *ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{
ClientSession: s,
ChainHash: c.cfg.ChainHash,
Dial: c.dial,
ReadMessage: c.readMessage,
SendMessage: c.sendMessage,
Signer: c.cfg.Signer,
DB: c.cfg.DB,
MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff,
Log: c.log,
ClientSession: s,
ChainHash: c.cfg.ChainHash,
Dial: c.dial,
ReadMessage: c.readMessage,
SendMessage: c.sendMessage,
Signer: c.cfg.Signer,
DB: c.cfg.DB,
MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff,
Log: c.log,
BuildBreachRetribution: c.cfg.BuildBreachRetribution,
}, updates)
}
@ -1653,7 +1689,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
// negotiations and from being used for any subsequent backups until it's added
// again. If an address is provided, then this call only serves as a way of
// removing the address from the watchtower instead.
func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey,
addr net.Addr) error {
errChan := make(chan error, 1)
select {
@ -1734,8 +1772,9 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// If our active session queue corresponds to the stale tower, we'll
// proceed to negotiate a new one.
if c.sessionQueue != nil {
activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed()
if bytes.Equal(pubKey, activeTower) {
towerKey := c.sessionQueue.tower.IdentityKey
if bytes.Equal(pubKey, towerKey.SerializeCompressed()) {
c.sessionQueue = nil
}
}


@ -36,8 +36,6 @@ import (
)
const (
csvDelay uint32 = 144
towerAddrStr = "18.28.243.2:9911"
towerAddr2Str = "19.29.244.3:9912"
)
@ -73,6 +71,16 @@ var (
addrScript, _ = txscript.PayToAddrScript(addr)
waitTime = 5 * time.Second
defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
}
highSweepRateTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1000000, // The high sweep fee creates dust.
}
)
// randPrivKey generates a new secp keypair, and returns the public key.
@ -509,6 +517,15 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
SessionCloseRange: 1,
}
h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error) {
_, retribution := h.channelFromID(id).getState(commitHeight)
return retribution, channeldb.SingleFunderBit, nil
}
if !cfg.noServerStart {
h.startServer()
t.Cleanup(h.stopServer)
@ -619,6 +636,21 @@ func (h *testHarness) channel(id uint64) *mockChannel {
return c
}
// channelFromID retrieves the channel corresponding to id.
//
// NOTE: The method fails if a channel for id does not exist.
func (h *testHarness) channelFromID(chanID lnwire.ChannelID) *mockChannel {
h.t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
c, ok := h.channels[chanID]
require.Truef(h.t, ok, "unable to fetch channel %s", chanID)
return c
}
// closeChannel marks a channel as closed.
//
// NOTE: The method fails if a channel for id does not exist.
@ -699,10 +731,9 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
_, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id)
err := h.client.BackupState(
&chanID, retribution, channeldb.SingleFunderBit,
)
require.ErrorIs(h.t, err, expErr)
err := h.client.BackupState(&chanID, retribution.RevokedStateNum)
require.ErrorIs(h.t, expErr, err)
}
// sendPayments instructs the channel identified by id to send amt to the remote
@ -823,7 +854,7 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d",
len(hints), len(matches))
// Assert that all of the matches correspond to a session with the
// Assert that all the matches correspond to a session with the
// expected policy.
for _, match := range matches {
matchPolicy := match.SessionInfo.Policy
@ -969,11 +1000,6 @@ const (
remoteBalance = lnwire.MilliSatoshi(200000000)
)
var defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
}
type clientTest struct {
name string
cfg harnessCfg
@ -1072,7 +1098,7 @@ var clientTests = []clientTest{
// pipeline is always flushed before it exits.
go h.client.Stop()
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, time.Second)
},
@ -1086,10 +1112,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1000000, // high sweep fee creates dust
},
TxPolicy: highSweepRateTxPolicy,
MaxUpdates: 20000,
},
},
@ -1177,7 +1200,7 @@ var clientTests = []clientTest{
// the tower to receive the remaining states.
h.backupStates(chanID, numSent, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, time.Second)
@ -1230,7 +1253,7 @@ var clientTests = []clientTest{
h.serverCfg.NoAckUpdates = false
h.startServer()
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, waitTime)
},
@ -1252,9 +1275,11 @@ var clientTests = []clientTest{
},
fn: func(h *testHarness) {
var (
capacity = h.cfg.localBalance + h.cfg.remoteBalance
capacity = h.cfg.localBalance +
h.cfg.remoteBalance
paymentAmt = lnwire.MilliSatoshi(2000000)
numSends = uint64(h.cfg.localBalance / paymentAmt)
numSends = uint64(h.cfg.localBalance) /
uint64(paymentAmt)
numRecvs = uint64(capacity / paymentAmt)
numUpdates = numSends + numRecvs // 200 updates
chanID = uint64(0)
@ -1262,11 +1287,15 @@ var clientTests = []clientTest{
// Send money to the remote party until all funds are
// depleted.
sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt)
sendHints := h.sendPayments(
chanID, 0, numSends, paymentAmt,
)
// Now, sequentially receive the entire channel balance
// from the remote party.
recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt)
recvHints := h.recvPayments(
chanID, numSends, numUpdates, paymentAmt,
)
// Collect the hints generated by both sending and
// receiving.
@ -1275,7 +1304,7 @@ var clientTests = []clientTest{
// Backup the channel's states the client.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 3*time.Second)
},
@ -1292,10 +1321,7 @@ var clientTests = []clientTest{
},
},
fn: func(h *testHarness) {
const (
numUpdates = 5
numChans = 10
)
const numUpdates = 5
// Initialize and register an additional 9 channels.
for id := uint64(1); id < 10; id++ {
@ -1323,7 +1349,7 @@ var clientTests = []clientTest{
// Test reliable flush under multi-client scenario.
go h.client.Stop()
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, 10*time.Second)
},
@ -1372,7 +1398,7 @@ var clientTests = []clientTest{
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, waitTime)
@ -1426,7 +1452,7 @@ var clientTests = []clientTest{
// Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, waitTime)
@ -1469,11 +1495,11 @@ var clientTests = []clientTest{
require.NoError(h.t, h.client.Stop())
// Record the policy that the first half was stored
// under. We'll expect the second half to also be stored
// under the original policy, since we are only adjusting
// the MaxUpdates. The client should detect that the
// two policies have equivalent TxPolicies and continue
// using the first.
// under. We'll expect the second half to also be
// stored under the original policy, since we are only
// adjusting the MaxUpdates. The client should detect
// that the two policies have equivalent TxPolicies and
// continue using the first.
expPolicy := h.clientCfg.Policy
// Restart the client with a new policy.
@ -1483,7 +1509,7 @@ var clientTests = []clientTest{
// Now, queue the second half of the retributions.
h.backupStates(chanID, numUpdates/2, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, waitTime)
@ -1534,7 +1560,7 @@ var clientTests = []clientTest{
// the second half should actually be sent.
h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the
// Wait for all the updates to be populated in the
// server's database.
h.waitServerUpdates(hints, waitTime)
},
@ -1599,10 +1625,7 @@ var clientTests = []clientTest{
localBalance: localBalance,
remoteBalance: remoteBalance,
policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
},
TxPolicy: defaultTxPolicy,
MaxUpdates: 5,
},
},
@ -1629,7 +1652,7 @@ var clientTests = []clientTest{
// Back up the remaining two states. Once the first is
// processed, the session will be exhausted but the
// client won't be able to regnegotiate a session for
// client won't be able to renegotiate a session for
// the final state. We'll only wait for the first five
// states to arrive at the tower.
h.backupStates(chanID, maxUpdates-1, numUpdates, nil)


@ -45,7 +45,8 @@ type sessionQueueConfig struct {
Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer,
error)
// SendMessage encodes, encrypts, and writes a message to the given peer.
// SendMessage encodes, encrypts, and writes a message to the given
// peer.
SendMessage func(wtserver.Peer, wtwire.Message) error
// ReadMessage receives, decypts, and decodes a message from the given
@ -56,6 +57,11 @@ type sessionQueueConfig struct {
// for justice transaction inputs.
Signer input.Signer
// BuildBreachRetribution is a function closure that allows the client
// to fetch the breach retribution info for a certain channel at a
// certain revoked commitment height.
BuildBreachRetribution BreachRetributionBuilder
// DB provides access to the client's stable storage.
DB DB
@ -219,7 +225,10 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
//
// TODO(conner): queue backups and retry with different session params.
case reserveAvailable:
err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody)
err := task.bindSession(
&q.cfg.ClientSession.ClientSessionBody,
q.cfg.BuildBreachRetribution,
)
if err != nil {
q.queueCond.L.Unlock()
q.log.Debugf("SessionQueue(%s) rejected %v: %v ",
@ -343,8 +352,8 @@ func (q *sessionQueue) drainBackups() {
// before attempting to dequeue any pending updates.
stateUpdate, isPending, backupID, err := q.nextStateUpdate()
if err != nil {
q.log.Errorf("SessionQueue(%v) unable to get next state "+
"update: %v", q.ID(), err)
q.log.Errorf("SessionQueue(%v) unable to get next "+
"state update: %v", q.ID(), err)
return
}


@ -6,6 +6,7 @@ import (
"time"
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)
// taskPipeline implements a reliable, in-order queue that ensures its queue
@ -25,7 +26,7 @@ type taskPipeline struct {
queueCond *sync.Cond
queue *list.List
newBackupTasks chan *backupTask
newBackupTasks chan *wtdb.BackupID
quit chan struct{}
forceQuit chan struct{}
@ -37,7 +38,7 @@ func newTaskPipeline(log btclog.Logger) *taskPipeline {
rq := &taskPipeline{
log: log,
queue: list.New(),
newBackupTasks: make(chan *backupTask),
newBackupTasks: make(chan *wtdb.BackupID),
quit: make(chan struct{}),
forceQuit: make(chan struct{}),
shutdown: make(chan struct{}),
@ -91,7 +92,7 @@ func (q *taskPipeline) ForceQuit() {
// channel will be closed after a call to Stop and all pending tasks have been
// delivered, or if a call to ForceQuit is called before the pending entries
// have been drained.
func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
func (q *taskPipeline) NewBackupTasks() <-chan *wtdb.BackupID {
return q.newBackupTasks
}
@ -99,7 +100,7 @@ func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is
// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be
// delivered via NewBackupTasks unless ForceQuit is called before completion.
func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
func (q *taskPipeline) QueueBackupTask(task *wtdb.BackupID) error {
q.queueCond.L.Lock()
select {
@ -116,8 +117,8 @@ func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
default:
}
// Queue the new task and signal the queue's condition variable to wake up
// the queueManager for processing.
// Queue the new task and signal the queue's condition variable to wake
// up the queueManager for processing.
q.queue.PushBack(task)
q.queueCond.L.Unlock()
@ -141,16 +142,21 @@ func (q *taskPipeline) queueManager() {
select {
case <-q.quit:
// Exit only after the queue has been fully drained.
// Exit only after the queue has been fully
// drained.
if q.queue.Len() == 0 {
q.queueCond.L.Unlock()
q.log.Debugf("Revoked state pipeline flushed.")
q.log.Debugf("Revoked state pipeline " +
"flushed.")
return
}
case <-q.forceQuit:
q.queueCond.L.Unlock()
q.log.Debugf("Revoked state pipeline force quit.")
q.log.Debugf("Revoked state pipeline force " +
"quit.")
return
default:
@ -159,13 +165,15 @@ func (q *taskPipeline) queueManager() {
// Pop the first element from the queue.
e := q.queue.Front()
task := q.queue.Remove(e).(*backupTask)
//nolint:forcetypeassert
task := q.queue.Remove(e).(*wtdb.BackupID)
q.queueCond.L.Unlock()
select {
// Backup task submitted to dispatcher. We don't select on quit to
// ensure that we still drain tasks while shutting down.
// Backup task submitted to dispatcher. We don't select on quit
// to ensure that we still drain tasks while shutting down.
case q.newBackupTasks <- task:
// Force quit, return immediately to allow the client to exit.
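To round this off, a hedged in-package sketch of how the pipeline's new element type is produced and consumed (hypothetical helper, assumed to sit inside package wtclient with an already-running pipeline; QueueBackupTask, NewBackupTasks and ErrClientExiting are from this diff):

```go
package wtclient

import (
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

// queueOne is a hypothetical helper, not code from this PR: callers enqueue a
// bare *wtdb.BackupID and the dispatcher receives it back, in order, from
// NewBackupTasks. QueueBackupTask returns ErrClientExiting if the pipeline is
// already shutting down.
func queueOne(p *taskPipeline, chanID lnwire.ChannelID,
	height uint64) (*wtdb.BackupID, error) {

	id := &wtdb.BackupID{ChanID: chanID, CommitHeight: height}

	if err := p.QueueBackupTask(id); err != nil {
		return nil, err
	}

	// In the real client this receive happens in backupDispatcher.
	return <-p.NewBackupTasks(), nil
}
```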