Merge pull request #7623 from ellemouton/queueBackupIDOnly

wtclient: queue backup id only
Oliver Gugger 2023-04-24 15:36:09 +02:00 (committed via GitHub)
commit 4355ce62d2
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 460 additions and 258 deletions


@@ -10,6 +10,7 @@ import (
	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/wire"
+	"github.com/btcsuite/btcwallet/walletdb"
	"github.com/go-errors/errors"
	mig "github.com/lightningnetwork/lnd/channeldb/migration"
	"github.com/lightningnetwork/lnd/channeldb/migration12"
@@ -654,20 +655,94 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (

// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
-// Optionally an existing db tx can be supplied. Optionally an existing db tx
-// can be supplied.
+// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
	*OpenChannel, error) {

-	var (
-		targetChan      *OpenChannel
-		targetChanPoint bytes.Buffer
-	)
+	var targetChanPoint bytes.Buffer

	if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
		return nil, err
	}
targetChanPointBytes := targetChanPoint.Bytes()
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
return targetChanPointBytes, &chanPoint, nil
}
return c.channelScanner(tx, selector)
}
// FetchChannelByID attempts to locate a channel specified by the passed channel
// ID. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) (
*OpenChannel, error) {
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
var (
targetChanPointBytes []byte
targetChanPoint *wire.OutPoint
// errChanFound is used to signal that the channel has
// been found so that iteration through the DB buckets
// can stop.
errChanFound = errors.New("channel found")
)
err := chainBkt.ForEach(func(k, _ []byte) error {
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(k), &outPoint)
if err != nil {
return err
}
chanID := lnwire.NewChanIDFromOutPoint(&outPoint)
if chanID != id {
return nil
}
targetChanPoint = &outPoint
targetChanPointBytes = k
return errChanFound
})
if err != nil && !errors.Is(err, errChanFound) {
return nil, nil, err
}
if targetChanPoint == nil {
return nil, nil, ErrChannelNotFound
}
return targetChanPointBytes, targetChanPoint, nil
}
return c.channelScanner(tx, selector)
}
// channelSelector describes a function that takes a chain-hash bucket from
// within the open-channel DB and returns the wanted channel point bytes, and
// channel point. It must return the ErrChannelNotFound error if the wanted
// channel is not in the given bucket.
type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error)
// channelScanner will traverse the DB to each chain-hash bucket of each node
// pub-key bucket in the open-channel-bucket. The chanSelector will then be used
// to fetch the wanted channel outpoint from the chain bucket.
func (c *ChannelStateDB) channelScanner(tx kvdb.RTx,
chanSelect channelSelector) (*OpenChannel, error) {
var (
targetChan *OpenChannel
// errChanFound is used to signal that the channel has been
// found so that iteration through the DB buckets can stop.
errChanFound = errors.New("channel found")
)
	// chanScan will traverse the following bucket structure:
	// * nodePub => chainHash => chanPoint
	//
@@ -685,8 +760,8 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
		}

		// Within the node channel bucket, are the set of node pubkeys
-		// we have channels with, we don't know the entire set, so
-		// we'll check them all.
+		// we have channels with, we don't know the entire set, so we'll
+		// check them all.
		return openChanBucket.ForEach(func(nodePub, v []byte) error {
			// Ensure that this is a key the same size as a pubkey,
			// and also that it leads directly to a bucket.
@@ -694,7 +769,9 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
				return nil
			}

-			nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
+			nodeChanBucket := openChanBucket.NestedReadBucket(
+				nodePub,
+			)
			if nodeChanBucket == nil {
				return nil
			}
@@ -715,20 +792,30 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
				)
				if chainBucket == nil {
					return fmt.Errorf("unable to read "+
-						"bucket for chain=%x", chainHash[:])
+						"bucket for chain=%x",
+						chainHash)
				}

-				// Finally we reach the leaf bucket that stores
+				// Finally, we reach the leaf bucket that stores
				// all the chanPoints for this node.
+				targetChanBytes, chanPoint, err := chanSelect(
+					chainBucket,
+				)
+				if errors.Is(err, ErrChannelNotFound) {
+					return nil
+				} else if err != nil {
+					return err
+				}
+
				chanBucket := chainBucket.NestedReadBucket(
-					targetChanPoint.Bytes(),
+					targetChanBytes,
				)
				if chanBucket == nil {
					return nil
				}

				channel, err := fetchOpenChannel(
-					chanBucket, &chanPoint,
+					chanBucket, chanPoint,
				)
				if err != nil {
					return err
@@ -737,7 +824,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
				targetChan = channel
				targetChan.Db = c

-				return nil
+				return errChanFound
			})
		})
	}
@@ -748,7 +835,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
	} else {
		err = chanScan(tx)
	}
-	if err != nil {
+	if err != nil && !errors.Is(err, errChanFound) {
		return nil, err
	}
@@ -757,7 +844,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
	}

	// If we can't find the channel, then we return with an error, as we
-	// have nothing to backup.
+	// have nothing to back up.
	return nil, ErrChannelNotFound
}
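For orientation, the following is a minimal sketch (not part of the PR; the `lookupByID` helper and its error wrapping are hypothetical) of how a caller might use the new `FetchChannelByID` API shown above:

```go
package example

import (
	"errors"
	"fmt"

	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/lnwire"
)

// lookupByID resolves a channel ID back to its OpenChannel. Passing a nil tx
// lets FetchChannelByID open its own read transaction, mirroring the existing
// FetchChannel behaviour.
func lookupByID(db *channeldb.ChannelStateDB,
	id lnwire.ChannelID) (*channeldb.OpenChannel, error) {

	channel, err := db.FetchChannelByID(nil, id)
	switch {
	case errors.Is(err, channeldb.ErrChannelNotFound):
		return nil, fmt.Errorf("no channel found for %v: %w", id, err)
	case err != nil:
		return nil, err
	}

	return channel, nil
}
```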


@@ -12,7 +12,6 @@ import (
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
-	"github.com/davecgh/go-spew/spew"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lnwire"
@@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) {
	// The decoded channel state should be identical to what we stored
	// above.
-	if !reflect.DeepEqual(channelState, dbChannel) {
-		t.Fatalf("channel state doesn't match:: %v vs %v",
-			spew.Sdump(channelState), spew.Sdump(dbChannel))
-	}
+	require.Equal(t, channelState, dbChannel)
+
+	// Next, attempt to fetch the channel by its channel ID.
+	chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint)
+	dbChannel, err = cdb.FetchChannelByID(nil, chanID)
+	require.NoError(t, err, "unable to fetch channel")
+
+	// The decoded channel state should be identical to what we stored
+	// above.
+	require.Equal(t, channelState, dbChannel)

	// If we attempt to query for a non-existent channel, then we should
	// get an error.
@@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) {
	channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load()

	_, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint)
-	if err == nil {
-		t.Fatalf("expected query to fail")
-	}
+	require.ErrorIs(t, err, ErrChannelNotFound)
+
+	chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint)
+	_, err = cdb.FetchChannelByID(nil, chanID2)
+	require.ErrorIs(t, err, ErrChannelNotFound)
}

func genRandomChannelShell() (*ChannelShell, error) {


@@ -6,7 +6,15 @@
  implementation](https://github.com/lightningnetwork/lnd/pull/7377) logic in
  different update types.
## Watchtowers
* Let the task pipeline [only carry
wtdb.BackupIDs](https://github.com/lightningnetwork/lnd/pull/7623) instead of
the entire retribution struct. This reduces the amount of data that needs to
be held in memory.
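As a concrete illustration of the bullet above (a sketch only, using the exported `wtdb.BackupID` fields visible in this diff), the pipeline now carries just this small identifier; the breach retribution is rebuilt only when a task is bound to a session:

```go
package example

import (
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

// backupIDFor shows what now flows through the task pipeline: the channel ID
// and the revoked commitment height, rather than a full BreachRetribution.
func backupIDFor(chanID lnwire.ChannelID, stateNum uint64) wtdb.BackupID {
	return wtdb.BackupID{
		ChanID:       chanID,
		CommitHeight: stateNum,
	}
}
```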
# Contributors (Alphabetical Order)
* Elle Mouton
* Jordi Montes


@@ -7,7 +7,6 @@ import (
	"github.com/lightningnetwork/lnd/invoices"
	"github.com/lightningnetwork/lnd/lnpeer"
	"github.com/lightningnetwork/lnd/lntypes"
-	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/lnwallet/chainfee"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/record"
@@ -255,11 +254,8 @@ type TowerClient interface {
	// state. If the method returns nil, the backup is guaranteed to be
	// successful unless the tower is unavailable and client is force quit,
	// or the justice transaction would create dust outputs when trying to
-	// abide by the negotiated policy. If the channel we're trying to back
-	// up doesn't have a tweak for the remote party's output, then
-	// isTweakless should be true.
-	BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
-		channeldb.ChannelType) error
+	// abide by the negotiated policy.
+	BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
}

// InterceptableHtlcForwarder is the interface to set the interceptor


@@ -2022,11 +2022,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
		// We've received a revocation from the remote chain, if valid,
		// this moves the remote chain forward, and expands our
		// revocation window.
-		//
-		// Before advancing our remote chain, we will record the
-		// current commit tx, which is used by the TowerClient to
-		// create backups.
-		oldCommitTx := l.channel.State().RemoteCommitment.CommitTx

		// We now process the message and advance our remote commit
		// chain.
@@ -2063,24 +2058,15 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
		// create a backup for the current state.
		if l.cfg.TowerClient != nil {
			state := l.channel.State()
-			breachInfo, err := lnwallet.NewBreachRetribution(
-				state, state.RemoteCommitment.CommitHeight-1, 0,
-				// OldCommitTx is the breaching tx at height-1.
-				oldCommitTx,
-			)
-			if err != nil {
-				l.fail(LinkFailureError{code: ErrInternalError},
-					"failed to load breach info: %v", err)
-				return
-			}
			chanID := l.ChanID()

			err = l.cfg.TowerClient.BackupState(
-				&chanID, breachInfo, state.ChanType,
+				&chanID, state.RemoteCommitment.CommitHeight-1,
			)
			if err != nil {
				l.fail(LinkFailureError{code: ErrInternalError},
-					"unable to queue breach backup: %v", err)
+					"unable to queue breach backup: %v",
+					err)
				return
			}
		}
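To make the new call-site contract explicit, here is a hedged sketch (the `towerBackerUpper` interface and `queueRevokedState` helper are stand-ins, not lnd code) of what the link now hands to the tower client: once the remote chain advances, the state at the previous commitment height is the revoked one that needs a backup.

```go
package example

import "github.com/lightningnetwork/lnd/lnwire"

// towerBackerUpper mirrors the trimmed-down BackupState signature from this
// diff: only a channel ID and a state number cross the boundary.
type towerBackerUpper interface {
	BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
}

// queueRevokedState follows the logic in handleUpstreamMsg above: after the
// remote commitment chain advances to remoteCommitHeight, the state at
// remoteCommitHeight-1 is the newly revoked one to back up.
func queueRevokedState(tc towerBackerUpper, chanID lnwire.ChannelID,
	remoteCommitHeight uint64) error {

	return tc.BackupState(&chanID, remoteCommitHeight-1)
}
```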


@@ -1523,12 +1523,37 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
		)
	}
// buildBreachRetribution is a call-back that can be used to
// query the BreachRetribution info and channel type given a
// channel ID and commitment height.
buildBreachRetribution := func(chanID lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error) {
channel, err := s.chanStateDB.FetchChannelByID(
nil, chanID,
)
if err != nil {
return nil, 0, err
}
br, err := lnwallet.NewBreachRetribution(
channel, commitHeight, 0, nil,
)
if err != nil {
return nil, 0, err
}
return br, channel.ChanType, nil
}
	fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID

	s.towerClient, err = wtclient.New(&wtclient.Config{
		FetchClosedChannel:     fetchClosedChannel,
+		BuildBreachRetribution: buildBreachRetribution,
		SessionCloseRange:      sessionCloseRange,
		ChainNotifier:          s.cc.ChainNotifier,
		SubscribeChannelEvents: func() (subscribe.Subscription,
			error) {
@@ -1558,9 +1583,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
			blob.Type(blob.FlagAnchorChannel)

	s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
		FetchClosedChannel:     fetchClosedChannel,
+		BuildBreachRetribution: buildBreachRetribution,
		SessionCloseRange:      sessionCloseRange,
		ChainNotifier:          s.cc.ChainNotifier,
		SubscribeChannelEvents: func() (subscribe.Subscription,
			error) {


@@ -10,7 +10,6 @@ import (
	"github.com/btcsuite/btcd/chaincfg"
	"github.com/btcsuite/btcd/txscript"
	"github.com/btcsuite/btcd/wire"
-	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/input"
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/lnwire"
@@ -40,7 +39,6 @@
type backupTask struct {
	id         wtdb.BackupID
	breachInfo *lnwallet.BreachRetribution
-	chanType   channeldb.ChannelType

	// state-dependent variables
@@ -55,11 +53,79 @@ type backupTask struct {
	outputs []*wire.TxOut
}

-// newBackupTask initializes a new backupTask and populates all state-dependent
-// variables.
-func newBackupTask(chanID *lnwire.ChannelID,
-	breachInfo *lnwallet.BreachRetribution,
-	sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask {
+// newBackupTask initializes a new backupTask.
+func newBackupTask(id wtdb.BackupID, sweepPkScript []byte) *backupTask {
+	return &backupTask{
+		id:            id,
+		sweepPkScript: sweepPkScript,
+	}
+}
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69 in a later stage.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) {
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
// bindSession first populates all state-dependent variables of the task. Then
// it determines if the backupTask is compatible with the passed SessionInfo's
// policy. If no error is returned, the task has been bound to the session and
// can be queued to upload to the tower. Otherwise, the bind failed and should
// be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody,
newBreachRetribution BreachRetributionBuilder) error {
breachInfo, chanType, err := newBreachRetribution(
t.id.ChanID, t.id.CommitHeight,
)
if err != nil {
return err
}
	// Parse the non-dust outputs from the breach transaction,
	// simultaneously computing the total amount contained in the inputs
@@ -123,76 +189,11 @@ func newBackupTask(chanID *lnwire.ChannelID,
		totalAmt += breachInfo.LocalOutputSignDesc.Output.Value
	}

-	return &backupTask{
-		id: wtdb.BackupID{
-			ChanID:       *chanID,
-			CommitHeight: breachInfo.RevokedStateNum,
-		},
-		breachInfo:    breachInfo,
-		chanType:      chanType,
-		toLocalInput:  toLocalInput,
-		toRemoteInput: toRemoteInput,
-		totalAmt:      btcutil.Amount(totalAmt),
-		sweepPkScript: sweepPkScript,
-	}
-}
+	t.breachInfo = breachInfo
+	t.toLocalInput = toLocalInput
+	t.toRemoteInput = toRemoteInput
+	t.totalAmt = btcutil.Amount(totalAmt)
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) { //nolint: whitespace
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
-// bindSession determines if the backupTask is compatible with the passed
-// SessionInfo's policy. If no error is returned, the task has been bound to the
-// session and can be queued to upload to the tower. Otherwise, the bind failed
-// and should be rescheduled with a different session.
-func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
	// First we'll begin by deriving a weight estimate for the justice
	// transaction. The final weight can be different depending on whether
	// the watchtower is taking a reward.
@@ -208,7 +209,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
		// original weight estimate. For anchor channels we'll go ahead
		// an use the correct penalty witness when signing our justice
		// transactions.
-		if t.chanType.HasAnchors() {
+		if chanType.HasAnchors() {
			weightEstimate.AddWitnessInput(
				input.ToLocalPenaltyWitnessSize,
			)
@@ -222,8 +223,10 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
		// Legacy channels (both tweaked and non-tweaked) spend from
		// P2WKH output. Anchor channels spend a to-remote confirmed
		// P2WSH output.
-		if t.chanType.HasAnchors() {
-			weightEstimate.AddWitnessInput(input.ToRemoteConfirmedWitnessSize)
+		if chanType.HasAnchors() {
+			weightEstimate.AddWitnessInput(
+				input.ToRemoteConfirmedWitnessSize,
+			)
		} else {
			weightEstimate.AddWitnessInput(input.P2WKHWitnessSize)
		}
@@ -231,7 +234,8 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
	// All justice transactions will either use segwit v0 (p2wkh + p2wsh)
	// or segwit v1 (p2tr).
-	if err := addScriptWeight(&weightEstimate, t.sweepPkScript); err != nil {
+	err = addScriptWeight(&weightEstimate, t.sweepPkScript)
+	if err != nil {
		return err
	}
@@ -244,9 +248,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
		}
	}

-	if t.chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
+	if chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
		log.Criticalf("Invalid task (has_anchors=%t) for session "+
-			"(has_anchors=%t)", t.chanType.HasAnchors(),
+			"(has_anchors=%t)", chanType.HasAnchors(),
			session.Policy.IsAnchorChannel())
	}


@@ -483,19 +483,19 @@ func TestBackupTask(t *testing.T) {
func testBackupTask(t *testing.T, test backupTaskTest) {
	// Create a new backupTask from the channel id and breach info.
-	task := newBackupTask(
-		&test.chanID, test.breachInfo, test.expSweepScript,
-		test.chanType,
-	)
+	id := wtdb.BackupID{
+		ChanID:       test.chanID,
+		CommitHeight: test.breachInfo.RevokedStateNum,
+	}
+	task := newBackupTask(id, test.expSweepScript)

-	// Assert that all parameters set during initialization are properly
-	// populated.
-	require.Equal(t, test.chanID, task.id.ChanID)
-	require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
-	require.Equal(t, test.expTotalAmt, task.totalAmt)
-	require.Equal(t, test.breachInfo, task.breachInfo)
-	require.Equal(t, test.expToLocalInput, task.toLocalInput)
-	require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
+	// getBreachInfo is a helper closure that returns the breach retribution
+	// info and channel type for the given channel and commit height.
+	getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) (
+		*lnwallet.BreachRetribution, channeldb.ChannelType, error) {
+
+		return test.breachInfo, test.chanType, nil
+	}

	// Reconstruct the expected input.Inputs that will be returned by the
	// task's inputs() method.
@@ -515,9 +515,18 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
	// Now, bind the session to the task. If successful, this locks in the
	// session's negotiated parameters and allows the backup task to derive
	// the final free variables in the justice transaction.
-	err := task.bindSession(test.session)
+	err := task.bindSession(test.session, getBreachInfo)
	require.ErrorIs(t, err, test.bindErr)

+	// Assert that all parameters set during after binding the backup task
+	// are properly populated.
+	require.Equal(t, test.chanID, task.id.ChanID)
+	require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
+	require.Equal(t, test.expTotalAmt, task.totalAmt)
+	require.Equal(t, test.breachInfo, task.breachInfo)
+	require.Equal(t, test.expToLocalInput, task.toLocalInput)
+	require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
+
	// Exit early if the bind was supposed to fail. But first, we check that
	// all fields set during a bind are still unset. This ensure that a
	// failed bind doesn't have side-effects if the task is retried with a
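Putting the test changes above together, the two-phase flow now exercised inside package wtclient looks roughly like this (a sketch; `session`, `build` and the hypothetical `bindSketch` wrapper are illustrative, while `newBackupTask`, `bindSession` and the types are the unexported identifiers from this diff):

```go
// bindSketch lives conceptually inside package wtclient.
func bindSketch(id wtdb.BackupID, sweepPkScript []byte,
	session *wtdb.ClientSessionBody,
	build BreachRetributionBuilder) error {

	// Phase 1: construction is cheap and carries no breach data.
	task := newBackupTask(id, sweepPkScript)

	// Phase 2: binding resolves the BreachRetribution (and channel type)
	// through the builder and validates the session policy.
	return task.bindSession(session, build)
}
```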


@@ -34,8 +34,8 @@ const (
	// a read before breaking out of a blocking read.
	DefaultReadTimeout = 15 * time.Second

-	// DefaultWriteTimeout specifies the default duration we will wait during
-	// a write before breaking out of a blocking write.
+	// DefaultWriteTimeout specifies the default duration we will wait
+	// during a write before breaking out of a blocking write.
	DefaultWriteTimeout = 15 * time.Second

	// DefaultStatInterval specifies the default interval between logging
@@ -136,11 +136,8 @@ type Client interface {
	// state. If the method returns nil, the backup is guaranteed to be
	// successful unless the client is force quit, or the justice
	// transaction would create dust outputs when trying to abide by the
-	// negotiated policy. If the channel we're trying to back up doesn't
-	// have a tweak for the remote party's output, then isTweakless should
-	// be true.
-	BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
-		channeldb.ChannelType) error
+	// negotiated policy.
+	BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
// Start initializes the watchtower client, allowing it process requests // Start initializes the watchtower client, allowing it process requests
// to backup revoked channel states. // to backup revoked channel states.
@ -178,6 +175,11 @@ type Config struct {
// ChainNotifier can be used to subscribe to block notifications. // ChainNotifier can be used to subscribe to block notifications.
ChainNotifier chainntnfs.ChainNotifier ChainNotifier chainntnfs.ChainNotifier
// BuildBreachRetribution is a function closure that allows the client
// fetch the breach retribution info for a certain channel at a certain
// revoked commitment height.
BuildBreachRetribution BreachRetributionBuilder
// NewAddress generates a new on-chain sweep pkscript. // NewAddress generates a new on-chain sweep pkscript.
NewAddress func() ([]byte, error) NewAddress func() ([]byte, error)
@ -240,6 +242,12 @@ type Config struct {
SessionCloseRange uint32 SessionCloseRange uint32
} }
// BreachRetributionBuilder is a function that can be used to construct a
// BreachRetribution from a channel ID and a commitment height.
type BreachRetributionBuilder func(id lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error)
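server.go above wires this type to the channel DB via `FetchChannelByID` and `lnwallet.NewBreachRetribution`; as a self-contained illustration, any closure with the matching shape satisfies it, for instance a trivial stub similar to the one in the test harness (hypothetical, not part of the PR):

```go
package example

import (
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtclient"
)

// noopBuilder satisfies wtclient.BreachRetributionBuilder without touching a
// database; a real builder looks the channel up by ID and derives the
// retribution for the requested commit height.
var noopBuilder wtclient.BreachRetributionBuilder = func(
	id lnwire.ChannelID, commitHeight uint64) (*lnwallet.BreachRetribution,
	channeldb.ChannelType, error) {

	return nil, channeldb.SingleFunderBit, nil
}
```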
// newTowerMsg is an internal message we'll use within the TowerClient to signal // newTowerMsg is an internal message we'll use within the TowerClient to signal
// that a new tower can be considered. // that a new tower can be considered.
type newTowerMsg struct { type newTowerMsg struct {
@ -293,7 +301,7 @@ type TowerClient struct {
activeSessions sessionQueueSet activeSessions sessionQueueSet
sessionQueue *sessionQueue sessionQueue *sessionQueue
-	prevTask *backupTask
+	prevTask *wtdb.BackupID
closableSessionQueue *sessionCloseMinHeap closableSessionQueue *sessionCloseMinHeap
@ -378,6 +386,9 @@ func New(config *Config) (*TowerClient, error) {
return return
} }
c.backupMu.Lock()
defer c.backupMu.Unlock()
// Take the highest commit height found in the session's acked // Take the highest commit height found in the session's acked
// updates. // updates.
height, ok := c.chanCommitHeights[chanID] height, ok := c.chanCommitHeights[chanID]
@ -558,8 +569,11 @@ func (c *TowerClient) Start() error {
// committed but unacked state updates. This ensures that these // committed but unacked state updates. This ensures that these
// sessions will be able to flush the committed updates after a // sessions will be able to flush the committed updates after a
// restart. // restart.
fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates
for _, session := range c.candidateSessions { for _, session := range c.candidateSessions {
-		committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
+		committedUpdates, err := fetchCommittedUpdates(
&session.ID,
)
if err != nil { if err != nil {
returnErr = err returnErr = err
return return
@@ -788,40 +802,41 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
// - client is force quit,
// - justice transaction would create dust outputs when trying to abide by the
// negotiated policy, or
-// - breached outputs contain too little value to sweep at the target sweep fee
-// rate.
+// - breached outputs contain too little value to sweep at the target sweep
+// fee rate.
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
-	breachInfo *lnwallet.BreachRetribution,
-	chanType channeldb.ChannelType) error {
+	stateNum uint64) error {

-	// Retrieve the cached sweep pkscript used for this channel.
+	// Make sure that this channel is registered with the tower client.
	c.backupMu.Lock()
-	summary, ok := c.summaries[*chanID]
-	if !ok {
+	if _, ok := c.summaries[*chanID]; !ok {
		c.backupMu.Unlock()
		return ErrUnregisteredChannel
	}

	// Ignore backups that have already been presented to the client.
	height, ok := c.chanCommitHeights[*chanID]
-	if ok && breachInfo.RevokedStateNum <= height {
+	if ok && stateNum <= height {
		c.backupMu.Unlock()
-		c.log.Debugf("Ignoring duplicate backup for chanid=%v at height=%d",
-			chanID, breachInfo.RevokedStateNum)
+		c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+
+			"height=%d", chanID, stateNum)
		return nil
	}

	// This backup has a higher commit height than any known backup for this
	// channel. We'll update our tip so that we won't accept it again if the
	// link flaps.
-	c.chanCommitHeights[*chanID] = breachInfo.RevokedStateNum
+	c.chanCommitHeights[*chanID] = stateNum
	c.backupMu.Unlock()

-	task := newBackupTask(
-		chanID, breachInfo, summary.SweepPkScript, chanType,
-	)
+	id := &wtdb.BackupID{
+		ChanID:       *chanID,
+		CommitHeight: stateNum,
+	}

-	return c.pipeline.QueueBackupTask(task)
+	return c.pipeline.QueueBackupTask(id)
}
// nextSessionQueue attempts to fetch an active session from our set of
@@ -1315,7 +1330,7 @@ func (c *TowerClient) backupDispatcher() {
				return
			}

-			c.log.Debugf("Processing %v", task.id)
+			c.log.Debugf("Processing %v", task)

			c.stats.taskReceived()
			c.processTask(task)
@@ -1345,8 +1360,22 @@ func (c *TowerClient) backupDispatcher() {
// sessionQueue hasn't been exhausted before proceeding to the next task. Tasks
// that are rejected because the active sessionQueue is full will be cached as
// the prevTask, and should be reprocessed after obtaining a new sessionQueue.
-func (c *TowerClient) processTask(task *backupTask) {
-	status, accepted := c.sessionQueue.AcceptTask(task)
+func (c *TowerClient) processTask(task *wtdb.BackupID) {
+	c.backupMu.Lock()
+	summary, ok := c.summaries[task.ChanID]
+	if !ok {
+		c.backupMu.Unlock()
+
+		log.Infof("not processing task for unregistered channel: %s",
+			task.ChanID)
+
+		return
+	}
+	c.backupMu.Unlock()
+
+	backupTask := newBackupTask(*task, summary.SweepPkScript)
+
+	status, accepted := c.sessionQueue.AcceptTask(backupTask)
	if accepted {
		c.taskAccepted(task, status)
	} else {
@ -1359,9 +1388,11 @@ func (c *TowerClient) processTask(task *backupTask) {
// prevTask is always removed as a result of this call. The client's // prevTask is always removed as a result of this call. The client's
// sessionQueue will be removed if accepting the task left the sessionQueue in // sessionQueue will be removed if accepting the task left the sessionQueue in
// an exhausted state. // an exhausted state.
-func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
-	c.log.Infof("Queued %v successfully for session %v",
-		task.id, c.sessionQueue.ID())
+func (c *TowerClient) taskAccepted(task *wtdb.BackupID,
+	newStatus reserveStatus) {
+
+	c.log.Infof("Queued %v successfully for session %v", task,
+		c.sessionQueue.ID())

	c.stats.taskAccepted()
@@ -1394,7 +1425,9 @@ func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) {
// the sessionQueue to find a new session. If the sessionQueue was not
// exhausted, the client marks the task as ineligible, as this implies we
// couldn't construct a valid justice transaction given the session's policy.
-func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
+func (c *TowerClient) taskRejected(task *wtdb.BackupID,
+	curStatus reserveStatus) {
	switch curStatus {

	// The sessionQueue has available capacity but the task was rejected,
	case reserveAvailable:
		c.stats.taskIneligible()

-		c.log.Infof("Ignoring ineligible %v", task.id)
+		c.log.Infof("Ignoring ineligible %v", task)

		err := c.cfg.DB.MarkBackupIneligible(
-			task.id.ChanID, task.id.CommitHeight,
+			task.ChanID, task.CommitHeight,
		)
		if err != nil {
			c.log.Errorf("Unable to mark %v ineligible: %v",
-				task.id, err)
+				task, err)

			// It is safe to not handle this error, even if we could
			// not persist the result. At worst, this task may be
@@ -1429,7 +1462,7 @@ func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) {
		c.stats.sessionExhausted()

		c.log.Debugf("Session %v exhausted, %v queued for next session",
-			c.sessionQueue.ID(), task.id)
+			c.sessionQueue.ID(), task)

		// Cache the task that we pulled off, so that we can process it
		// once a new session queue is available.
@ -1485,7 +1518,9 @@ func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) {
} }
// sendMessage sends a watchtower wire message to the target peer. // sendMessage sends a watchtower wire message to the target peer.
-func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error {
+func (c *TowerClient) sendMessage(peer wtserver.Peer,
+	msg wtwire.Message) error {
+
	// Encode the next wire message into the buffer.
	// TODO(conner): use buffer pool
	var b bytes.Buffer
@@ -1521,16 +1556,17 @@ func (c *TowerClient) newSessionQueue(s *ClientSession,
	updates []wtdb.CommittedUpdate) *sessionQueue {

	return newSessionQueue(&sessionQueueConfig{
		ClientSession:          s,
		ChainHash:              c.cfg.ChainHash,
		Dial:                   c.dial,
		ReadMessage:            c.readMessage,
		SendMessage:            c.sendMessage,
		Signer:                 c.cfg.Signer,
		DB:                     c.cfg.DB,
		MinBackoff:             c.cfg.MinBackoff,
		MaxBackoff:             c.cfg.MaxBackoff,
		Log:                    c.log,
+		BuildBreachRetribution: c.cfg.BuildBreachRetribution,
	}, updates)
}
@ -1653,7 +1689,9 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
// negotiations and from being used for any subsequent backups until it's added // negotiations and from being used for any subsequent backups until it's added
// again. If an address is provided, then this call only serves as a way of // again. If an address is provided, then this call only serves as a way of
// removing the address from the watchtower instead. // removing the address from the watchtower instead.
-func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
+func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey,
+	addr net.Addr) error {
+
	errChan := make(chan error, 1)

	select {
@ -1734,8 +1772,9 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// If our active session queue corresponds to the stale tower, we'll // If our active session queue corresponds to the stale tower, we'll
// proceed to negotiate a new one. // proceed to negotiate a new one.
if c.sessionQueue != nil { if c.sessionQueue != nil {
-		activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed()
-		if bytes.Equal(pubKey, activeTower) {
+		towerKey := c.sessionQueue.tower.IdentityKey
+
+		if bytes.Equal(pubKey, towerKey.SerializeCompressed()) {
			c.sessionQueue = nil
		}
	}


@@ -36,8 +36,6 @@
)

const (
-	csvDelay uint32 = 144
	towerAddrStr  = "18.28.243.2:9911"
	towerAddr2Str = "19.29.244.3:9912"
)
@ -73,6 +71,16 @@ var (
addrScript, _ = txscript.PayToAddrScript(addr) addrScript, _ = txscript.PayToAddrScript(addr)
waitTime = 5 * time.Second waitTime = 5 * time.Second
defaultTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
}
highSweepRateTxPolicy = wtpolicy.TxPolicy{
BlobType: blob.TypeAltruistCommit,
SweepFeeRate: 1000000, // The high sweep fee creates dust.
}
) )
// randPrivKey generates a new secp keypair, and returns the public key. // randPrivKey generates a new secp keypair, and returns the public key.
@ -509,6 +517,15 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
SessionCloseRange: 1, SessionCloseRange: 1,
} }
h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error) {
_, retribution := h.channelFromID(id).getState(commitHeight)
return retribution, channeldb.SingleFunderBit, nil
}
if !cfg.noServerStart { if !cfg.noServerStart {
h.startServer() h.startServer()
t.Cleanup(h.stopServer) t.Cleanup(h.stopServer)
@ -619,6 +636,21 @@ func (h *testHarness) channel(id uint64) *mockChannel {
return c return c
} }
// channelFromID retrieves the channel corresponding to id.
//
// NOTE: The method fails if a channel for id does not exist.
func (h *testHarness) channelFromID(chanID lnwire.ChannelID) *mockChannel {
h.t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
c, ok := h.channels[chanID]
require.Truef(h.t, ok, "unable to fetch channel %s", chanID)
return c
}
// closeChannel marks a channel as closed. // closeChannel marks a channel as closed.
// //
// NOTE: The method fails if a channel for id does not exist. // NOTE: The method fails if a channel for id does not exist.
@@ -699,10 +731,9 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
	_, retribution := h.channel(id).getState(i)

	chanID := chanIDFromInt(id)
-	err := h.client.BackupState(
-		&chanID, retribution, channeldb.SingleFunderBit,
-	)
-	require.ErrorIs(h.t, err, expErr)
+
+	err := h.client.BackupState(&chanID, retribution.RevokedStateNum)
+	require.ErrorIs(h.t, expErr, err)
}
// sendPayments instructs the channel identified by id to send amt to the remote // sendPayments instructs the channel identified by id to send amt to the remote
@ -823,7 +854,7 @@ func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d", require.Lenf(h.t, matches, len(hints), "expected: %d matches, got: %d",
len(hints), len(matches)) len(hints), len(matches))
// Assert that all of the matches correspond to a session with the // Assert that all the matches correspond to a session with the
// expected policy. // expected policy.
for _, match := range matches { for _, match := range matches {
matchPolicy := match.SessionInfo.Policy matchPolicy := match.SessionInfo.Policy
@@ -969,11 +1000,6 @@ const (
	remoteBalance = lnwire.MilliSatoshi(200000000)
)

-var defaultTxPolicy = wtpolicy.TxPolicy{
-	BlobType:     blob.TypeAltruistCommit,
-	SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
-}
-
type clientTest struct { type clientTest struct {
name string name string
cfg harnessCfg cfg harnessCfg
@ -1072,7 +1098,7 @@ var clientTests = []clientTest{
// pipeline is always flushed before it exits. // pipeline is always flushed before it exits.
go h.client.Stop() go h.client.Stop()
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, time.Second) h.waitServerUpdates(hints, time.Second)
}, },
@@ -1086,10 +1112,7 @@ var clientTests = []clientTest{
			localBalance:  localBalance,
			remoteBalance: remoteBalance,
			policy: wtpolicy.Policy{
-				TxPolicy: wtpolicy.TxPolicy{
-					BlobType:     blob.TypeAltruistCommit,
-					SweepFeeRate: 1000000, // high sweep fee creates dust
-				},
+				TxPolicy:   highSweepRateTxPolicy,
				MaxUpdates: 20000,
			},
		},
@ -1177,7 +1200,7 @@ var clientTests = []clientTest{
// the tower to receive the remaining states. // the tower to receive the remaining states.
h.backupStates(chanID, numSent, numUpdates, nil) h.backupStates(chanID, numSent, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, time.Second) h.waitServerUpdates(hints, time.Second)
@ -1230,7 +1253,7 @@ var clientTests = []clientTest{
h.serverCfg.NoAckUpdates = false h.serverCfg.NoAckUpdates = false
h.startServer() h.startServer()
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, waitTime) h.waitServerUpdates(hints, waitTime)
}, },
@ -1252,9 +1275,11 @@ var clientTests = []clientTest{
}, },
fn: func(h *testHarness) { fn: func(h *testHarness) {
var ( var (
capacity = h.cfg.localBalance + h.cfg.remoteBalance capacity = h.cfg.localBalance +
h.cfg.remoteBalance
paymentAmt = lnwire.MilliSatoshi(2000000) paymentAmt = lnwire.MilliSatoshi(2000000)
numSends = uint64(h.cfg.localBalance / paymentAmt) numSends = uint64(h.cfg.localBalance) /
uint64(paymentAmt)
numRecvs = uint64(capacity / paymentAmt) numRecvs = uint64(capacity / paymentAmt)
numUpdates = numSends + numRecvs // 200 updates numUpdates = numSends + numRecvs // 200 updates
chanID = uint64(0) chanID = uint64(0)
@ -1262,11 +1287,15 @@ var clientTests = []clientTest{
// Send money to the remote party until all funds are // Send money to the remote party until all funds are
// depleted. // depleted.
sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt) sendHints := h.sendPayments(
chanID, 0, numSends, paymentAmt,
)
// Now, sequentially receive the entire channel balance // Now, sequentially receive the entire channel balance
// from the remote party. // from the remote party.
recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt) recvHints := h.recvPayments(
chanID, numSends, numUpdates, paymentAmt,
)
// Collect the hints generated by both sending and // Collect the hints generated by both sending and
// receiving. // receiving.
@ -1275,7 +1304,7 @@ var clientTests = []clientTest{
// Backup the channel's states the client. // Backup the channel's states the client.
h.backupStates(chanID, 0, numUpdates, nil) h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, 3*time.Second) h.waitServerUpdates(hints, 3*time.Second)
}, },
@@ -1292,10 +1321,7 @@ var clientTests = []clientTest{
			},
		},
		fn: func(h *testHarness) {
-			const (
-				numUpdates = 5
-				numChans   = 10
-			)
+			const numUpdates = 5
// Initialize and register an additional 9 channels. // Initialize and register an additional 9 channels.
for id := uint64(1); id < 10; id++ { for id := uint64(1); id < 10; id++ {
@ -1323,7 +1349,7 @@ var clientTests = []clientTest{
// Test reliable flush under multi-client scenario. // Test reliable flush under multi-client scenario.
go h.client.Stop() go h.client.Stop()
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, 10*time.Second) h.waitServerUpdates(hints, 10*time.Second)
}, },
@ -1372,7 +1398,7 @@ var clientTests = []clientTest{
// Now, queue the retributions for backup. // Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil) h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, waitTime) h.waitServerUpdates(hints, waitTime)
@ -1426,7 +1452,7 @@ var clientTests = []clientTest{
// Now, queue the retributions for backup. // Now, queue the retributions for backup.
h.backupStates(chanID, 0, numUpdates, nil) h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, waitTime) h.waitServerUpdates(hints, waitTime)
@ -1469,11 +1495,11 @@ var clientTests = []clientTest{
require.NoError(h.t, h.client.Stop()) require.NoError(h.t, h.client.Stop())
// Record the policy that the first half was stored // Record the policy that the first half was stored
// under. We'll expect the second half to also be stored // under. We'll expect the second half to also be
// under the original policy, since we are only adjusting // stored under the original policy, since we are only
// the MaxUpdates. The client should detect that the // adjusting the MaxUpdates. The client should detect
// two policies have equivalent TxPolicies and continue // that the two policies have equivalent TxPolicies and
// using the first. // continue using the first.
expPolicy := h.clientCfg.Policy expPolicy := h.clientCfg.Policy
// Restart the client with a new policy. // Restart the client with a new policy.
@ -1483,7 +1509,7 @@ var clientTests = []clientTest{
// Now, queue the second half of the retributions. // Now, queue the second half of the retributions.
h.backupStates(chanID, numUpdates/2, numUpdates, nil) h.backupStates(chanID, numUpdates/2, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, waitTime) h.waitServerUpdates(hints, waitTime)
@ -1534,7 +1560,7 @@ var clientTests = []clientTest{
// the second half should actually be sent. // the second half should actually be sent.
h.backupStates(chanID, 0, numUpdates, nil) h.backupStates(chanID, 0, numUpdates, nil)
// Wait for all of the updates to be populated in the // Wait for all the updates to be populated in the
// server's database. // server's database.
h.waitServerUpdates(hints, waitTime) h.waitServerUpdates(hints, waitTime)
}, },
@@ -1599,10 +1625,7 @@ var clientTests = []clientTest{
			localBalance:  localBalance,
			remoteBalance: remoteBalance,
			policy: wtpolicy.Policy{
-				TxPolicy: wtpolicy.TxPolicy{
-					BlobType:     blob.TypeAltruistCommit,
-					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
-				},
+				TxPolicy:   defaultTxPolicy,
				MaxUpdates: 5,
			},
		},
@ -1629,7 +1652,7 @@ var clientTests = []clientTest{
// Back up the remaining two states. Once the first is // Back up the remaining two states. Once the first is
// processed, the session will be exhausted but the // processed, the session will be exhausted but the
// client won't be able to regnegotiate a session for // client won't be able to renegotiate a session for
// the final state. We'll only wait for the first five // the final state. We'll only wait for the first five
// states to arrive at the tower. // states to arrive at the tower.
h.backupStates(chanID, maxUpdates-1, numUpdates, nil) h.backupStates(chanID, maxUpdates-1, numUpdates, nil)


@@ -45,7 +45,8 @@ type sessionQueueConfig struct {
	Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer,
		error)

-	// SendMessage encodes, encrypts, and writes a message to the given peer.
+	// SendMessage encodes, encrypts, and writes a message to the given
+	// peer.
	SendMessage func(wtserver.Peer, wtwire.Message) error

	// ReadMessage receives, decypts, and decodes a message from the given
@ -56,6 +57,11 @@ type sessionQueueConfig struct {
// for justice transaction inputs. // for justice transaction inputs.
Signer input.Signer Signer input.Signer
// BuildBreachRetribution is a function closure that allows the client
// to fetch the breach retribution info for a certain channel at a
// certain revoked commitment height.
BuildBreachRetribution BreachRetributionBuilder
// DB provides access to the client's stable storage. // DB provides access to the client's stable storage.
DB DB DB DB
@@ -219,7 +225,10 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
	//
	// TODO(conner): queue backups and retry with different session params.
	case reserveAvailable:
-		err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody)
+		err := task.bindSession(
+			&q.cfg.ClientSession.ClientSessionBody,
+			q.cfg.BuildBreachRetribution,
+		)
		if err != nil {
			q.queueCond.L.Unlock()
			q.log.Debugf("SessionQueue(%s) rejected %v: %v ",
@@ -343,8 +352,8 @@ func (q *sessionQueue) drainBackups() {
	// before attempting to dequeue any pending updates.
	stateUpdate, isPending, backupID, err := q.nextStateUpdate()
	if err != nil {
-		q.log.Errorf("SessionQueue(%v) unable to get next state "+
-			"update: %v", q.ID(), err)
+		q.log.Errorf("SessionQueue(%v) unable to get next "+
+			"state update: %v", q.ID(), err)
		return
	}


@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/btcsuite/btclog" "github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
// taskPipeline implements a reliable, in-order queue that ensures its queue // taskPipeline implements a reliable, in-order queue that ensures its queue
@ -25,7 +26,7 @@ type taskPipeline struct {
queueCond *sync.Cond queueCond *sync.Cond
queue *list.List queue *list.List
-	newBackupTasks chan *backupTask
+	newBackupTasks chan *wtdb.BackupID
quit chan struct{} quit chan struct{}
forceQuit chan struct{} forceQuit chan struct{}
@ -37,7 +38,7 @@ func newTaskPipeline(log btclog.Logger) *taskPipeline {
rq := &taskPipeline{ rq := &taskPipeline{
log: log, log: log,
queue: list.New(), queue: list.New(),
-		newBackupTasks: make(chan *backupTask),
+		newBackupTasks: make(chan *wtdb.BackupID),
quit: make(chan struct{}), quit: make(chan struct{}),
forceQuit: make(chan struct{}), forceQuit: make(chan struct{}),
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
@ -91,7 +92,7 @@ func (q *taskPipeline) ForceQuit() {
// channel will be closed after a call to Stop and all pending tasks have been // channel will be closed after a call to Stop and all pending tasks have been
// delivered, or if a call to ForceQuit is called before the pending entries // delivered, or if a call to ForceQuit is called before the pending entries
// have been drained. // have been drained.
-func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
+func (q *taskPipeline) NewBackupTasks() <-chan *wtdb.BackupID {
return q.newBackupTasks return q.newBackupTasks
} }
@ -99,7 +100,7 @@ func (q *taskPipeline) NewBackupTasks() <-chan *backupTask {
// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is // of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is
// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be // returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be
// delivered via NewBackupTasks unless ForceQuit is called before completion. // delivered via NewBackupTasks unless ForceQuit is called before completion.
-func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
+func (q *taskPipeline) QueueBackupTask(task *wtdb.BackupID) error {
q.queueCond.L.Lock() q.queueCond.L.Lock()
select { select {
@ -116,8 +117,8 @@ func (q *taskPipeline) QueueBackupTask(task *backupTask) error {
default: default:
} }
-	// Queue the new task and signal the queue's condition variable to wake up
-	// the queueManager for processing.
+	// Queue the new task and signal the queue's condition variable to wake
+	// up the queueManager for processing.
q.queue.PushBack(task) q.queue.PushBack(task)
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
@ -141,16 +142,21 @@ func (q *taskPipeline) queueManager() {
select { select {
case <-q.quit: case <-q.quit:
// Exit only after the queue has been fully drained. // Exit only after the queue has been fully
// drained.
if q.queue.Len() == 0 { if q.queue.Len() == 0 {
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
q.log.Debugf("Revoked state pipeline flushed.") q.log.Debugf("Revoked state pipeline " +
"flushed.")
return return
} }
case <-q.forceQuit: case <-q.forceQuit:
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
q.log.Debugf("Revoked state pipeline force quit.") q.log.Debugf("Revoked state pipeline force " +
"quit.")
return return
default: default:
@ -159,13 +165,15 @@ func (q *taskPipeline) queueManager() {
// Pop the first element from the queue. // Pop the first element from the queue.
e := q.queue.Front() e := q.queue.Front()
task := q.queue.Remove(e).(*backupTask)
//nolint:forcetypeassert
task := q.queue.Remove(e).(*wtdb.BackupID)
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
select { select {
// Backup task submitted to dispatcher. We don't select on quit to // Backup task submitted to dispatcher. We don't select on quit
// ensure that we still drain tasks while shutting down. // to ensure that we still drain tasks while shutting down.
case q.newBackupTasks <- task: case q.newBackupTasks <- task:
// Force quit, return immediately to allow the client to exit. // Force quit, return immediately to allow the client to exit.