package wtclient_test

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/txscript"
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/chainntnfs"
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/channelnotifier"
	"github.com/lightningnetwork/lnd/input"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lntest/wait"
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/subscribe"
	"github.com/lightningnetwork/lnd/tor"
	"github.com/lightningnetwork/lnd/watchtower/blob"
	"github.com/lightningnetwork/lnd/watchtower/wtclient"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
	"github.com/lightningnetwork/lnd/watchtower/wtmock"
	"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
	"github.com/lightningnetwork/lnd/watchtower/wtserver"
	"github.com/stretchr/testify/require"
)

const (
	// towerAddrStr and towerAddr2Str are the mock network addresses used
	// for the towers in these tests.
	towerAddrStr  = "18.28.243.2:9911"
	towerAddr2Str = "19.29.244.3:9912"

	// timeout is used as the read/write timeout for both the client and
	// the server configs built by the harness.
	timeout = 200 * time.Millisecond
)

var (
	revPrivBytes = []byte{
		0x8f, 0x4b, 0x51, 0x83, 0xa9, 0x34, 0xbd, 0x5f,
		0x74, 0x6c, 0x9d, 0x5c, 0xae, 0x88, 0x2d, 0x31,
		0x06, 0x90, 0xdd, 0x8c, 0x9b, 0x31, 0xbc, 0xd1,
		0x78, 0x91, 0x88, 0x2a, 0xf9, 0x74, 0xa0, 0xef,
	}

	toLocalPrivBytes = []byte{
		0xde, 0x17, 0xc1, 0x2f, 0xdc, 0x1b, 0xc0, 0xc6,
		0x59, 0x5d, 0xf9, 0xc1, 0x3e, 0x89, 0xbc, 0x6f,
		0x01, 0x85, 0x45, 0x76, 0x26, 0xce, 0x9c, 0x55,
		0x3b, 0xc9, 0xec, 0x3d, 0xd8, 0x8b, 0xac, 0xa8,
	}

	toRemotePrivBytes = []byte{
		0x28, 0x59, 0x6f, 0x36, 0xb8, 0x9f, 0x19, 0x5d,
		0xcb, 0x07, 0x48, 0x8a, 0xe5, 0x89, 0x71, 0x74,
		0x70, 0x4c, 0xff, 0x1e, 0x9c, 0x00, 0x93, 0xbe,
		0xe2, 0x2e, 0x68, 0x08, 0x4c, 0xb4, 0x0f, 0x4f,
	}

	// addr is the server's reward address given to watchtower clients.
	addr, _ = btcutil.DecodeAddress(
		"tb1pw8gzj8clt3v5lxykpgacpju5n8xteskt7gxhmudu6pa70nwfhe6s3unsyk",
		&chaincfg.TestNet3Params,
	)

	// addrScript is the pkScript paying to addr; the harness hands it out
	// through the client's NewAddress callback.
	addrScript, _ = txscript.PayToAddrScript(addr)

	// waitTime bounds how long the mocks wait for a send to be consumed
	// and is the default timeout passed to waitForUpdates by the tests.
	waitTime = 15 * time.Second

	defaultTxPolicy = wtpolicy.TxPolicy{
		BlobType:     blob.TypeAltruistCommit,
		SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
	}

	highSweepRateTxPolicy = wtpolicy.TxPolicy{
		BlobType:     blob.TypeAltruistCommit,
		SweepFeeRate: 1000000, // The high sweep fee creates dust.
	}
)

// randPrivKey generates a new secp keypair and returns the private key. The
// test is failed if key generation errors.
func randPrivKey(t *testing.T) *btcec.PrivateKey {
	t.Helper()

	sk, err := btcec.NewPrivateKey()
	require.NoError(t, err, "unable to generate pubkey")

	return sk
}

// mockNet is an in-memory stand-in for the network layer. Towers register a
// connection callback keyed by their address string, and AuthDial hands the
// remote end of a mock connection to that callback.
type mockNet struct {
	// mu guards connCallbacks.
	mu            sync.RWMutex
	connCallbacks map[string]func(wtserver.Peer)
}

// newMockNet returns a mockNet with an empty callback registry.
func newMockNet() *mockNet {
	return &mockNet{
		connCallbacks: make(map[string]func(peer wtserver.Peer)),
	}
}

// Dial is a no-op that returns a nil connection and a nil error.
func (m *mockNet) Dial(_, _ string, _ time.Duration) (net.Conn, error) {
	return nil, nil
}

// LookupHost is not implemented and panics if called.
func (m *mockNet) LookupHost(_ string) ([]string, error) {
	panic("not implemented")
}

// LookupSRV is not implemented and panics if called.
func (m *mockNet) LookupSRV(_, _, _ string) (string, []*net.SRV, error) {
	panic("not implemented")
}

// ResolveTCPAddr is not implemented and panics if called.
func (m *mockNet) ResolveTCPAddr(_, _ string) (*net.TCPAddr, error) {
	panic("not implemented")
}

// AuthDial creates a mock connection pair between the local key and netAddr,
// passes the remote end to the callback registered for netAddr, and returns
// the local end. An error is returned if no callback is registered for the
// address.
func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
	netAddr *lnwire.NetAddress, _ tor.DialFunc) (wtserver.Peer, error) {

	localPk := local.PubKey()
	localAddr := &net.TCPAddr{
		IP:   net.IP{0x32, 0x31, 0x30, 0x29},
		Port: 36723,
	}

	localPeer, remotePeer := wtmock.NewMockConn(
		localPk, netAddr.IdentityKey, localAddr, netAddr.Address, 0,
	)

	// The read lock is held across the callback invocation, so a callback
	// cannot be removed while it is being called.
	m.mu.RLock()
	defer m.mu.RUnlock()
	cb, ok := m.connCallbacks[netAddr.String()]
	if !ok {
		return nil, fmt.Errorf("no callback registered for this peer")
	}

	cb(remotePeer)

	return localPeer, nil
}

// registerConnCallback stores cb as the handler invoked by AuthDial for
// connections to netAddr, replacing any existing handler for that address.
func (m *mockNet) registerConnCallback(netAddr *lnwire.NetAddress,
	cb func(wtserver.Peer)) {

	m.mu.Lock()
	defer m.mu.Unlock()

	m.connCallbacks[netAddr.String()] = cb
}

// removeConnCallback deletes the handler registered for netAddr, if any.
func (m *mockNet) removeConnCallback(netAddr *lnwire.NetAddress) {
	m.mu.Lock()
	defer m.mu.Unlock()

	delete(m.connCallbacks, netAddr.String())
}

// mockChannel models a channel just far enough to produce a breach
// retribution for every commitment height it has been advanced to.
type mockChannel struct {
	// mu guards the mutable state below (commit height, retributions and
	// balances).
	mu            sync.Mutex
	commitHeight  uint64
	retributions  map[uint64]*lnwallet.BreachRetribution
	localBalance  lnwire.MilliSatoshi
	remoteBalance lnwire.MilliSatoshi

	revSK     *btcec.PrivateKey
	revPK     *btcec.PublicKey
	revKeyLoc keychain.KeyLocator

	toRemoteSK     *btcec.PrivateKey
	toRemotePK     *btcec.PublicKey
	toRemoteKeyLoc keychain.KeyLocator

	toLocalPK *btcec.PublicKey // only need to generate to-local script

	dustLimit lnwire.MilliSatoshi
	csvDelay  uint32
}

// newMockChannel creates a mockChannel with fresh random keys and the given
// starting balances, registers the revocation and to-remote private keys with
// signer, and creates the initial commitment (state 0).
func newMockChannel(t *testing.T, signer *wtmock.MockSigner,
	localAmt, remoteAmt lnwire.MilliSatoshi) *mockChannel {

	// Generate the revocation, to-local, and to-remote keypairs.
	revSK := randPrivKey(t)
	revPK := revSK.PubKey()

	toLocalSK := randPrivKey(t)
	toLocalPK := toLocalSK.PubKey()

	toRemoteSK := randPrivKey(t)
	toRemotePK := toRemoteSK.PubKey()

	// Register the revocation secret key and the to-remote secret key with
	// the signer. We will not need to sign with the to-local key, as this
	// is to be known only by the counterparty.
	revKeyLoc := signer.AddPrivKey(revSK)
	toRemoteKeyLoc := signer.AddPrivKey(toRemoteSK)

	c := &mockChannel{
		retributions:   make(map[uint64]*lnwallet.BreachRetribution),
		localBalance:   localAmt,
		remoteBalance:  remoteAmt,
		revSK:          revSK,
		revPK:          revPK,
		revKeyLoc:      revKeyLoc,
		toLocalPK:      toLocalPK,
		toRemoteSK:     toRemoteSK,
		toRemotePK:     toRemotePK,
		toRemoteKeyLoc: toRemoteKeyLoc,
		dustLimit:      546000,
		csvDelay:       144,
	}

	// Create the initial remote commitment with the initial balances.
	c.createRemoteCommitTx(t)

	return c
}

// createRemoteCommitTx builds the remote commitment for the current balances,
// stores its breach retribution under the current commit height, and then
// increments the commit height. It does not take c.mu itself; advanceState,
// sendPayment and receivePayment acquire the lock before calling it.
func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
	t.Helper()

	// Construct the to-local witness script.
	toLocalScript, err := input.CommitScriptToSelf(
		c.csvDelay, c.toLocalPK, c.revPK,
	)
	require.NoError(t, err, "unable to create to-local script")

	// Compute the to-local witness script hash.
	toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript)
	require.NoError(t, err, "unable to create to-local witness script hash")

	// Compute the to-remote witness script hash.
	toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK)
	require.NoError(t, err, "unable to create to-remote script")

	// Construct the remote commitment txn, containing the to-local and
	// to-remote outputs. The balances are flipped since the transaction is
	// from the PoV of the remote party. We don't need any inputs for this
	// test. We increment the version with the commit height to ensure that
	// all commitment transactions are unique even if the same distribution
	// of funds is used more than once.
	commitTxn := &wire.MsgTx{
		Version: int32(c.commitHeight + 1),
	}

	var (
		toLocalSignDesc  *input.SignDescriptor
		toRemoteSignDesc *input.SignDescriptor
	)

	// Outputs below the dust limit are omitted, in which case the matching
	// sign descriptor stays nil.
	var outputIndex int
	if c.remoteBalance >= c.dustLimit {
		commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
			Value:    int64(c.remoteBalance.ToSatoshis()),
			PkScript: toLocalScriptHash,
		})

		// Create the sign descriptor used to sign for the to-local
		// input.
		toLocalSignDesc = &input.SignDescriptor{
			KeyDesc: keychain.KeyDescriptor{
				KeyLocator: c.revKeyLoc,
				PubKey:     c.revPK,
			},
			WitnessScript: toLocalScript,
			Output:        commitTxn.TxOut[outputIndex],
			HashType:      txscript.SigHashAll,
		}
		outputIndex++
	}
	if c.localBalance >= c.dustLimit {
		commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
			Value:    int64(c.localBalance.ToSatoshis()),
			PkScript: toRemoteScriptHash,
		})

		// Create the sign descriptor used to sign for the to-remote
		// input.
		toRemoteSignDesc = &input.SignDescriptor{
			KeyDesc: keychain.KeyDescriptor{
				KeyLocator: c.toRemoteKeyLoc,
				PubKey:     c.toRemotePK,
			},
			WitnessScript: toRemoteScriptHash,
			Output:        commitTxn.TxOut[outputIndex],
			HashType:      txscript.SigHashAll,
		}
		outputIndex++
	}

	txid := commitTxn.TxHash()

	var (
		toLocalOutPoint  wire.OutPoint
		toRemoteOutPoint wire.OutPoint
	)

	// Assign outpoint indexes in the same order the outputs were appended
	// above: to-local first (if present), then to-remote.
	outputIndex = 0
	if toLocalSignDesc != nil {
		toLocalOutPoint = wire.OutPoint{
			Hash:  txid,
			Index: uint32(outputIndex),
		}
		outputIndex++
	}
	if toRemoteSignDesc != nil {
		toRemoteOutPoint = wire.OutPoint{
			Hash:  txid,
			Index: uint32(outputIndex),
		}
		outputIndex++
	}

	// The key ring and retribution are populated from the remote party's
	// PoV, so the "local" retribution fields carry the to-remote values
	// and vice versa.
	commitKeyRing := &lnwallet.CommitmentKeyRing{
		RevocationKey: c.revPK,
		ToRemoteKey:   c.toLocalPK,
		ToLocalKey:    c.toRemotePK,
	}

	retribution := &lnwallet.BreachRetribution{
		BreachTxHash:         commitTxn.TxHash(),
		RevokedStateNum:      c.commitHeight,
		KeyRing:              commitKeyRing,
		RemoteDelay:          c.csvDelay,
		LocalOutpoint:        toRemoteOutPoint,
		LocalOutputSignDesc:  toRemoteSignDesc,
		RemoteOutpoint:       toLocalOutPoint,
		RemoteOutputSignDesc: toLocalSignDesc,
	}

	c.retributions[c.commitHeight] = retribution
	c.commitHeight++
}

// advanceState creates the next channel state and retribution without altering
// channel balances.
func (c *mockChannel) advanceState(t *testing.T) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.createRemoteCommitTx(t)
}

// sendPayment creates the next channel state and retribution after transferring
// amt to the remote party. The test is failed if the local balance is less
// than amt.
func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) {
	t.Helper()

	c.mu.Lock()
	defer c.mu.Unlock()

	require.GreaterOrEqualf(t, c.localBalance, amt, "insufficient funds "+
		"to send, need: %v, have: %v", amt, c.localBalance)

	c.localBalance -= amt
	c.remoteBalance += amt
	c.createRemoteCommitTx(t)
}

// receivePayment creates the next channel state and retribution after
// transferring amt to the local party.
func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) { t.Helper() c.mu.Lock() defer c.mu.Unlock() require.GreaterOrEqualf(t, c.remoteBalance, amt, "insufficient funds "+ "to recv, need: %v, have: %v", amt, c.remoteBalance) c.localBalance += amt c.remoteBalance -= amt c.createRemoteCommitTx(t) } // getState retrieves the channel's commitment and retribution at state i. func (c *mockChannel) getState( i uint64) (chainhash.Hash, *lnwallet.BreachRetribution) { c.mu.Lock() defer c.mu.Unlock() retribution := c.retributions[i] return retribution.BreachTxHash, retribution } type testHarness struct { t *testing.T cfg harnessCfg signer *wtmock.MockSigner capacity lnwire.MilliSatoshi clientDB *wtdb.ClientDB clientCfg *wtclient.Config client wtclient.Client server *serverHarness net *mockNet blockEvents *mockBlockSub height int32 channelEvents *mockSubscription sendUpdatesOn bool mu sync.Mutex channels map[lnwire.ChannelID]*mockChannel closedChannels map[lnwire.ChannelID]uint32 quit chan struct{} } type harnessCfg struct { localBalance lnwire.MilliSatoshi remoteBalance lnwire.MilliSatoshi policy wtpolicy.Policy noRegisterChan0 bool noAckCreateSession bool noServerStart bool } func newClientDB(t *testing.T) *wtdb.ClientDB { dbCfg := &kvdb.BoltConfig{ DBTimeout: kvdb.DefaultDBTimeout, } // Construct the ClientDB. 
dir := t.TempDir() bdb, err := wtdb.NewBoltBackendCreator(true, dir, "wtclient.db")(dbCfg) require.NoError(t, err) clientDB, err := wtdb.OpenClientDB(bdb) require.NoError(t, err) return clientDB } func newHarness(t *testing.T, cfg harnessCfg) *testHarness { signer := wtmock.NewMockSigner() mockNet := newMockNet() clientDB := newClientDB(t) server := newServerHarness( t, mockNet, towerAddrStr, func(serverCfg *wtserver.Config) { serverCfg.NoAckCreateSession = cfg.noAckCreateSession }, ) h := &testHarness{ t: t, cfg: cfg, signer: signer, capacity: cfg.localBalance + cfg.remoteBalance, clientDB: clientDB, server: server, net: mockNet, blockEvents: newMockBlockSub(t), channelEvents: newMockSubscription(t), channels: make(map[lnwire.ChannelID]*mockChannel), closedChannels: make(map[lnwire.ChannelID]uint32), quit: make(chan struct{}), } t.Cleanup(func() { close(h.quit) }) fetchChannel := func(id lnwire.ChannelID) ( *channeldb.ChannelCloseSummary, error) { h.mu.Lock() defer h.mu.Unlock() height, ok := h.closedChannels[id] if !ok { return nil, channeldb.ErrClosedChannelNotFound } return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil } h.clientCfg = &wtclient.Config{ Signer: signer, SubscribeChannelEvents: func() (subscribe.Subscription, error) { return h.channelEvents, nil }, FetchClosedChannel: fetchChannel, ChainNotifier: h.blockEvents, Dial: mockNet.Dial, DB: clientDB, AuthDial: mockNet.AuthDial, SecretKeyRing: wtmock.NewSecretKeyRing(), Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, ReadTimeout: timeout, WriteTimeout: timeout, MinBackoff: time.Millisecond, MaxBackoff: time.Second, SessionCloseRange: 1, MaxTasksInMemQueue: 2, } h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID, commitHeight uint64) (*lnwallet.BreachRetribution, channeldb.ChannelType, error) { _, retribution := h.channelFromID(id).getState(commitHeight) return retribution, channeldb.SingleFunderBit, nil } if !cfg.noServerStart { h.server.start() 
t.Cleanup(h.server.stop) } h.startClient() t.Cleanup(func() { require.NoError(t, h.client.Stop()) require.NoError(t, h.clientDB.Close()) }) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) if !cfg.noRegisterChan0 { h.registerChannel(0) } return h } // mine mimics the mining of new blocks by sending new block notifications. func (h *testHarness) mine(numBlocks int) { h.t.Helper() for i := 0; i < numBlocks; i++ { h.height++ h.blockEvents.sendNewBlock(h.height) } } // startClient creates a new server using the harness's current clientCf and // starts it. func (h *testHarness) startClient() { h.t.Helper() towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr) require.NoError(h.t, err) towerAddr := &lnwire.NetAddress{ IdentityKey: h.server.cfg.NodeKeyECDH.PubKey(), Address: towerTCPAddr, } h.client, err = wtclient.New(h.clientCfg) require.NoError(h.t, err) require.NoError(h.t, h.client.Start()) require.NoError(h.t, h.client.AddTower(towerAddr)) } // chanIDFromInt creates a unique channel id given a unique integral id. func chanIDFromInt(id uint64) lnwire.ChannelID { var chanID lnwire.ChannelID binary.BigEndian.PutUint64(chanID[:8], id) return chanID } // makeChannel creates new channel with id, using the localAmt and remoteAmt as // the starting balances. The channel will be available by using h.channel(id). // // NOTE: The method fails if channel for id already exists. func (h *testHarness) makeChannel(id uint64, localAmt, remoteAmt lnwire.MilliSatoshi) { h.t.Helper() chanID := chanIDFromInt(id) c := newMockChannel(h.t, h.signer, localAmt, remoteAmt) c.mu.Lock() _, ok := h.channels[chanID] if !ok { h.channels[chanID] = c } c.mu.Unlock() require.Falsef(h.t, ok, "channel %d already created", id) } // channel retrieves the channel corresponding to id. // // NOTE: The method fails if a channel for id does not exist. 
func (h *testHarness) channel(id uint64) *mockChannel { h.t.Helper() h.mu.Lock() c, ok := h.channels[chanIDFromInt(id)] h.mu.Unlock() require.Truef(h.t, ok, "unable to fetch channel %d", id) return c } // channelFromID retrieves the channel corresponding to id. // // NOTE: The method fails if a channel for id does not exist. func (h *testHarness) channelFromID(chanID lnwire.ChannelID) *mockChannel { h.t.Helper() h.mu.Lock() defer h.mu.Unlock() c, ok := h.channels[chanID] require.Truef(h.t, ok, "unable to fetch channel %s", chanID) return c } // closeChannel marks a channel as closed. // // NOTE: The method fails if a channel for id does not exist. func (h *testHarness) closeChannel(id uint64, height uint32) { h.t.Helper() h.mu.Lock() defer h.mu.Unlock() chanID := chanIDFromInt(id) _, ok := h.channels[chanID] require.Truef(h.t, ok, "unable to fetch channel %d", id) h.closedChannels[chanID] = height delete(h.channels, chanID) chanPointHash, err := chainhash.NewHash(chanID[:]) require.NoError(h.t, err) if !h.sendUpdatesOn { return } h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{ CloseSummary: &channeldb.ChannelCloseSummary{ ChanPoint: wire.OutPoint{ Hash: *chanPointHash, Index: 0, }, CloseHeight: height, }, }) } // registerChannel registers the channel identified by id with the client. func (h *testHarness) registerChannel(id uint64) { h.t.Helper() chanID := chanIDFromInt(id) err := h.client.RegisterChannel(chanID) require.NoError(h.t, err) } // advanceChannelN calls advanceState on the channel identified by id the number // of provided times and returns the breach hints corresponding to the new // states. 
func (h *testHarness) advanceChannelN(id uint64, n int) []blob.BreachHint { h.t.Helper() channel := h.channel(id) var hints []blob.BreachHint for i := uint64(0); i < uint64(n); i++ { channel.advanceState(h.t) breachTxID, _ := h.channel(id).getState(i) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints } // backupStates instructs the channel identified by id to send backups to the // client for states in the range [to, from). func (h *testHarness) backupStates(id, from, to uint64, expErr error) { h.t.Helper() for i := from; i < to; i++ { h.backupState(id, i, expErr) } } // backupStates instructs the channel identified by id to send a backup for // state i. func (h *testHarness) backupState(id, i uint64, expErr error) { h.t.Helper() _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) err := h.client.BackupState(&chanID, retribution.RevokedStateNum) require.ErrorIs(h.t, expErr, err) } // sendPayments instructs the channel identified by id to send amt to the remote // party for each state in from-to times and returns the breach hints for states // [from, to). func (h *testHarness) sendPayments(id, from, to uint64, amt lnwire.MilliSatoshi) []blob.BreachHint { h.t.Helper() channel := h.channel(id) var hints []blob.BreachHint for i := from; i < to; i++ { h.channel(id).sendPayment(h.t, amt) breachTxID, _ := channel.getState(i) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints } // receivePayment instructs the channel identified by id to recv amt from the // remote party for each state in from-to times and returns the breach hints for // states [from, to). 
func (h *testHarness) recvPayments(id, from, to uint64, amt lnwire.MilliSatoshi) []blob.BreachHint { h.t.Helper() channel := h.channel(id) var hints []blob.BreachHint for i := from; i < to; i++ { channel.receivePayment(h.t, amt) breachTxID, _ := channel.getState(i) hints = append(hints, blob.NewBreachHintFromHash(&breachTxID)) } return hints } // addTower adds a tower found at `addr` to the client. func (h *testHarness) addTower(addr *lnwire.NetAddress) { h.t.Helper() err := h.client.AddTower(addr) require.NoError(h.t, err) } // removeTower removes a tower from the client. If `addr` is specified, then the // only said address is removed from the tower. func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { h.t.Helper() err := h.client.RemoveTower(pubKey, addr) require.NoError(h.t, err) } // relevantSessions returns a list of session IDs that have acked updates for // the given channel ID. func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID { h.t.Helper() var ( sessionIDs []wtdb.SessionID cID = chanIDFromInt(chanID) ) collectSessions := wtdb.WithPerNumAckedUpdates( func(session *wtdb.ClientSession, id lnwire.ChannelID, _ uint16) { if !bytes.Equal(id[:], cID[:]) { return } sessionIDs = append(sessionIDs, session.ID) }, ) _, err := h.clientDB.ListClientSessions(nil, collectSessions) require.NoError(h.t, err) return sessionIDs } // isSessionClosable returns true if the given session has been marked as // closable in the DB. func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool { h.t.Helper() cs, err := h.clientDB.ListClosableSessions() require.NoError(h.t, err) _, ok := cs[id] return ok } // mockSubscription is a mock subscription client that blocks on sends into the // updates channel. type mockSubscription struct { t *testing.T updates chan interface{} // Embed the subscription interface in this mock so that we satisfy it. subscribe.Subscription } // newMockSubscription creates a mock subscription. 
func newMockSubscription(t *testing.T) *mockSubscription { t.Helper() return &mockSubscription{ t: t, updates: make(chan interface{}), } } // sendUpdate sends an update into our updates channel, mocking the dispatch of // an update from a subscription server. This call will fail the test if the // update is not consumed within our timeout. func (m *mockSubscription) sendUpdate(update interface{}) { select { case m.updates <- update: case <-time.After(waitTime): m.t.Fatalf("update: %v timeout", update) } } // Updates returns the updates channel for the mock. func (m *mockSubscription) Updates() <-chan interface{} { return m.updates } // mockBlockSub mocks out the ChainNotifier. type mockBlockSub struct { t *testing.T events chan *chainntnfs.BlockEpoch chainntnfs.ChainNotifier } // newMockBlockSub creates a new mockBlockSub. func newMockBlockSub(t *testing.T) *mockBlockSub { t.Helper() return &mockBlockSub{ t: t, events: make(chan *chainntnfs.BlockEpoch), } } // RegisterBlockEpochNtfn returns a channel that can be used to listen for new // blocks. func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) ( *chainntnfs.BlockEpochEvent, error) { return &chainntnfs.BlockEpochEvent{ Epochs: m.events, }, nil } // sendNewBlock will send a new block on the notification channel. func (m *mockBlockSub) sendNewBlock(height int32) { select { case m.events <- &chainntnfs.BlockEpoch{Height: height}: case <-time.After(waitTime): m.t.Fatalf("timed out sending block: %d", height) } } // serverHarness represents a mock watchtower server. type serverHarness struct { t *testing.T net *mockNet cfg *wtserver.Config addr *lnwire.NetAddress db *wtmock.TowerDB server *wtserver.Server } // newServerHarness constructs a new mock watchtower server. 
func newServerHarness(t *testing.T, mockNet *mockNet, netAddr string, opt func(cfg *wtserver.Config)) *serverHarness { towerTCPAddr, err := net.ResolveTCPAddr("tcp", netAddr) require.NoError(t, err, "Unable to resolve tower TCP addr") privKey, err := btcec.NewPrivateKey() require.NoError(t, err, "Unable to generate tower private key") privKeyECDH := &keychain.PrivKeyECDH{PrivKey: privKey} towerPubKey := privKey.PubKey() towerAddr := &lnwire.NetAddress{ IdentityKey: towerPubKey, Address: towerTCPAddr, } db := wtmock.NewTowerDB() cfg := &wtserver.Config{ DB: db, ReadTimeout: timeout, WriteTimeout: timeout, NodeKeyECDH: privKeyECDH, NewAddress: func() (btcutil.Address, error) { return addr, nil }, } if opt != nil { opt(cfg) } server, err := wtserver.New(cfg) require.NoError(t, err, "unable to create wtserver") return &serverHarness{ t: t, net: mockNet, cfg: cfg, db: db, addr: towerAddr, server: server, } } // start creates a new server using the harness's current server cfg and starts // it after registering its Dial callback with the mockNet. func (s *serverHarness) start() { s.t.Helper() var err error s.server, err = wtserver.New(s.cfg) require.NoError(s.t, err) s.net.registerConnCallback(s.addr, s.server.InboundPeerConnected) require.NoError(s.t, s.server.Start()) } // stop halts the server and removes its Dial callback from the mockNet. func (s *serverHarness) stop() { s.t.Helper() require.NoError(s.t, s.server.Stop()) s.net.removeConnCallback(s.addr) } // restart stops the server, applies any given config tweaks and then starts the // server again. func (s *serverHarness) restart(op func(cfg *wtserver.Config)) { s.stop() defer s.start() if op == nil { return } op(s.cfg) } // waitForUpdates blocks until the breach hints provided all appear in the // watchtower's database or the timeout expires. This is used to test that the // client in fact sends the updates to the server, even if it is offline. 
func (s *serverHarness) waitForUpdates(hints []blob.BreachHint,
	timeout time.Duration) {

	s.t.Helper()

	// If no breach hints are provided, we will wait out the full timeout to
	// assert that no updates appear.
	wantUpdates := len(hints) > 0

	// Build a set of the hints so duplicates in the input can be detected
	// and matches can be checked for membership.
	hintSet := make(map[blob.BreachHint]struct{})
	for _, hint := range hints {
		hintSet[hint] = struct{}{}
	}

	require.Lenf(s.t, hints, len(hintSet), "breach hints are not unique, "+
		"list-len: %d set-len: %d", len(hints), len(hintSet))

	// Closure to assert the server's matches are consistent with the hint
	// set. It returns false while the match count differs from the hint
	// count, and fails the test if a match is not in the hint set.
	serverHasHints := func(matches []wtdb.Match) bool {
		if len(hintSet) != len(matches) {
			return false
		}

		for _, match := range matches {
			_, ok := hintSet[match.Hint]
			require.Truef(s.t, ok, "match %v in db is not in "+
				"hint set", match.Hint)
		}

		return true
	}

	// The DB is polled once per second below, so the timeout must leave
	// room for at least one poll.
	require.Truef(s.t, timeout.Seconds() > 1, "timeout must be set to "+
		"greater than 1 second")

	failTimeout := time.After(timeout)
	for {
		select {
		// Periodic poll: return early once all wanted hints are present.
		case <-time.After(time.Second):
			matches, err := s.db.QueryMatches(hints)
			require.NoError(s.t, err, "unable to query for hints")

			if wantUpdates && serverHasHints(matches) {
				return
			}

			if wantUpdates {
				s.t.Logf("Received %d/%d\n", len(matches),
					len(hints))
			}

		// Deadline reached: perform a final check. With no hints this
		// asserts that the query returned no matches.
		case <-failTimeout:
			matches, err := s.db.QueryMatches(hints)
			require.NoError(s.t, err, "unable to query for hints")
			require.Truef(s.t, serverHasHints(matches), "breach "+
				"hints not received, only got %d/%d",
				len(matches), len(hints))
			return
		}
	}
}

// assertUpdatesForPolicy queries the server db for matches using the provided
// breach hints, then asserts that each match has a session with the expected
// policy.
func (s *serverHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
	expPolicy wtpolicy.Policy) {

	// Query for matches on the provided hints.
	matches, err := s.db.QueryMatches(hints)
	require.NoError(s.t, err)

	// Assert that the number of matches is exactly the number of provided
	// hints.
	require.Lenf(s.t, matches, len(hints), "expected: %d matches, got: %d",
		len(hints), len(matches))

	// Assert that all the matches correspond to a session with the
	// expected policy.
	for _, match := range matches {
		matchPolicy := match.SessionInfo.Policy
		require.Equal(s.t, expPolicy, matchPolicy)
	}
}

const (
	// localBalance and remoteBalance are the default starting balances
	// used by the test cases below.
	localBalance  = lnwire.MilliSatoshi(100000000)
	remoteBalance = lnwire.MilliSatoshi(200000000)
)

// clientTest describes a single client test case: its name, the harness
// configuration to build, and the test body run against that harness.
type clientTest struct {
	name string
	cfg  harnessCfg
	fn   func(*testHarness)
}

var clientTests = []clientTest{
	{
		// Asserts that client will return the ErrUnregisteredChannel
		// error when trying to backup states for a channel that has not
		// been registered (and received its pkscript).
		name: "backup unregistered channel",
		cfg: harnessCfg{
			localBalance:  localBalance,
			remoteBalance: remoteBalance,
			policy: wtpolicy.Policy{
				TxPolicy:   defaultTxPolicy,
				MaxUpdates: 20000,
			},
			noRegisterChan0: true,
		},
		fn: func(h *testHarness) {
			const (
				numUpdates = 5
				chanID     = 0
			)

			// Advance the channel and backup the retributions. We
			// expect ErrUnregisteredChannel to be returned since
			// the channel was not registered during harness
			// creation.
			h.advanceChannelN(chanID, numUpdates)
			h.backupStates(
				chanID, 0, numUpdates,
				wtclient.ErrUnregisteredChannel,
			)
		},
	},
	{
		// Asserts that the client returns an ErrClientExiting when
		// trying to backup channels after the Stop method has been
		// called.
		name: "backup after stop",
		cfg: harnessCfg{
			localBalance:  localBalance,
			remoteBalance: remoteBalance,
			policy: wtpolicy.Policy{
				TxPolicy:   defaultTxPolicy,
				MaxUpdates: 20000,
			},
		},
		fn: func(h *testHarness) {
			const (
				numUpdates = 5
				chanID     = 0
			)

			// Stop the client, subsequent backups should fail.
			h.client.Stop()

			// Advance the channel and try to back up the states. We
			// expect ErrClientExiting to be returned from
			// BackupState.
h.advanceChannelN(chanID, numUpdates) h.backupStates( chanID, 0, numUpdates, wtclient.ErrClientExiting, ) }, }, { // Asserts that the client will continue to back up all states // that have previously been enqueued before it finishes // exiting. name: "backup reliable flush", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions and back them up to // the tower. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates, nil) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) }, }, { // Assert that the client will not send out backups for states // whose justice transactions are ineligible for backup, e.g. // creating dust outputs. name: "backup dust ineligible", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: highSweepRateTxPolicy, MaxUpdates: 20000, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Create the retributions and queue them for backup. h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates, nil) // Ensure that no updates are received by the server, // since they should all be marked as ineligible. h.server.waitForUpdates(nil, waitTime) }, }, { // Verifies that the client will properly retransmit a committed // state update to the watchtower after a restart if the update // was not acked while the client was active last. 
name: "committed update restart", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) hints := h.advanceChannelN(0, numUpdates) var numSent uint64 // Add the first two states to the client's pipeline. h.backupStates(chanID, 0, 2, nil) numSent = 2 // Wait for both to be reflected in the server's // database. h.server.waitForUpdates(hints[:numSent], waitTime) // Now, restart the server and prevent it from acking // state updates. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = true }) // Send the next state update to the tower. Since the // tower isn't acking state updates, we expect this // update to be committed and sent by the session queue, // but it will never receive an ack. h.backupState(chanID, numSent, nil) numSent++ // Stop the client to abort the state updates it has // queued. require.NoError(h.t, h.client.Stop()) // Restart the server and allow it to ack the updates // after the client retransmits the unacked update. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = false }) // Restart the client and allow it to process the // committed update. h.startClient() // Wait for the committed update to be accepted by the // tower. h.server.waitForUpdates(hints[:numSent], waitTime) // Finally, send the rest of the updates and wait for // the tower to receive the remaining states. h.backupStates(chanID, numSent, numUpdates, nil) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) }, }, { // Asserts that the client will continue to retry sending state // updates if it doesn't receive an ack from the server. The // client is expected to flush everything in its in-memory // pipeline once the server begins sending acks again. 
name: "no ack from server", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 100 chanID = 0 ) // Generate the retributions that will be backed up. hints := h.advanceChannelN(chanID, numUpdates) // Restart the server and prevent it from acking state // updates. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = true }) // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) // Give the client time to saturate a large number of // session queues for which the server has not acked the // state updates that it has received. time.Sleep(time.Second) // Restart the server and allow it to ack the updates // after the client retransmits the unacked updates. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = false }) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) }, }, { // Asserts that the client is able to send state updates to the // tower for a full range of channel values, assuming the sweep // fee rates permit it. We expect all of these to be successful // since a sweep transactions spending only from one output is // less expensive than one that sweeps both. name: "send and recv", cfg: harnessCfg{ localBalance: 100000001, // ensure (% amt != 0) remoteBalance: 200000001, // ensure (% amt != 0) policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 1000, }, }, fn: func(h *testHarness) { var ( capacity = h.cfg.localBalance + h.cfg.remoteBalance paymentAmt = lnwire.MilliSatoshi(2000000) numSends = uint64(h.cfg.localBalance) / uint64(paymentAmt) numRecvs = uint64(capacity / paymentAmt) numUpdates = numSends + numRecvs // 200 updates chanID = uint64(0) ) // Send money to the remote party until all funds are // depleted. 
sendHints := h.sendPayments( chanID, 0, numSends, paymentAmt, ) // Now, sequentially receive the entire channel balance // from the remote party. recvHints := h.recvPayments( chanID, numSends, numUpdates, paymentAmt, ) // Collect the hints generated by both sending and // receiving. hints := append(sendHints, recvHints...) // Backup the channel's states the client. h.backupStates(chanID, 0, numUpdates, nil) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) }, }, { // Asserts that the client is able to support multiple links. name: "multiple link backup", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const numUpdates = 5 // Initialize and register an additional 9 channels. for id := uint64(1); id < 10; id++ { h.makeChannel( id, h.cfg.localBalance, h.cfg.remoteBalance, ) h.registerChannel(id) } // Generate the retributions for all 10 channels and // collect the breach hints. var hints []blob.BreachHint for id := uint64(0); id < 10; id++ { chanHints := h.advanceChannelN(id, numUpdates) hints = append(hints, chanHints...) } // Provided all retributions to the client from all // channels. for id := uint64(0); id < 10; id++ { h.backupStates(id, 0, numUpdates, nil) } // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, 10*time.Second) }, }, { name: "create session no ack", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 3 ) // Generate the retributions that will be backed up. hints := h.advanceChannelN(chanID, numUpdates) // Now, queue the retributions for backup. 
h.backupStates(chanID, 0, numUpdates, nil) // Since the client is unable to create a session, the // server should have no updates. h.server.waitForUpdates(nil, waitTime) // Stop the client since it has queued backups. require.NoError(h.t, h.client.Stop()) // Restart the server and allow it to ack session // creation. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckCreateSession = false }) // Restart the client with the same policy, which will // immediately try to overwrite the old session with an // identical one. h.startClient() // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. h.server.assertUpdatesForPolicy( hints, h.clientCfg.Policy, ) }, }, { name: "create session no ack change policy", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 3 ) // Generate the retributions that will be backed up. hints := h.advanceChannelN(chanID, numUpdates) // Now, queue the retributions for backup. h.backupStates(chanID, 0, numUpdates, nil) // Since the client is unable to create a session, the // server should have no updates. h.server.waitForUpdates(nil, waitTime) // Stop the client since it has queued backups. require.NoError(h.t, h.client.Stop()) // Restart the server and allow it to ack session // creation. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckCreateSession = false }) // Restart the client with a new policy, which will // immediately try to overwrite the prior session with // the old policy. h.clientCfg.Policy.SweepFeeRate *= 2 h.startClient() // Wait for all the updates to be populated in the // server's database. 
h.server.waitForUpdates(hints, waitTime) // Assert that the server has updates for the clients // most recent policy. h.server.assertUpdatesForPolicy( hints, h.clientCfg.Policy, ) }, }, { // Asserts that the client will not request a new session if // already has an existing session with the same TxPolicy. This // permits the client to continue using policies that differ in // operational parameters, but don't manifest in different // justice transactions. name: "create session change policy same txpolicy", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 10, }, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 6 ) // Generate the retributions that will be backed up. hints := h.advanceChannelN(chanID, numUpdates) // Now, queue the first half of the retributions. h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for the server to collect the first half. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Stop the client, which should have no more backups. require.NoError(h.t, h.client.Stop()) // Record the policy that the first half was stored // under. We'll expect the second half to also be // stored under the original policy, since we are only // adjusting the MaxUpdates. The client should detect // that the two policies have equivalent TxPolicies and // continue using the first. expPolicy := h.clientCfg.Policy // Restart the client with a new policy. h.clientCfg.Policy.MaxUpdates = 20 h.startClient() // Now, queue the second half of the retributions. h.backupStates(chanID, numUpdates/2, numUpdates, nil) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) // Assert that the server has updates for the client's // original policy. 
h.server.assertUpdatesForPolicy(hints, expPolicy) }, }, { // Asserts that the client will deduplicate backups presented by // a channel both in memory and after a restart. The client // should only accept backups with a commit height greater than // any processed already processed for a given policy. name: "dedup backups", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 10 chanID = 0 ) // Generate the retributions that will be backed up. hints := h.advanceChannelN(chanID, numUpdates) // Queue the first half of the retributions twice, the // second batch should be entirely deduped by the // client's in-memory tracking. h.backupStates(chanID, 0, numUpdates/2, nil) h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for the first half of the updates to be // populated in the server's database. h.server.waitForUpdates(hints[:len(hints)/2], waitTime) // Restart the client, so we can ensure the deduping is // maintained across restarts. require.NoError(h.t, h.client.Stop()) h.startClient() // Try to back up the full range of retributions. Only // the second half should actually be sent. h.backupStates(chanID, 0, numUpdates, nil) // Wait for all the updates to be populated in the // server's database. h.server.waitForUpdates(hints, waitTime) }, }, { // Asserts that the client can continue making backups to a // tower that's been re-added after it's been removed. name: "re-add removed tower", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 4 ) // Create four channel updates and only back up the // first two. 
hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Fully remove the tower, causing its existing sessions // to be marked inactive. h.removeTower(h.server.addr.IdentityKey, nil) // Back up the remaining states. Since the tower has // been removed, it shouldn't receive any updates. h.backupStates(chanID, numUpdates/2, numUpdates, nil) h.server.waitForUpdates(nil, waitTime) // Re-add the tower. We prevent the tower from acking // session creation to ensure the inactive sessions are // not used. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckCreateSession = true }) h.addTower(h.server.addr) h.server.waitForUpdates(nil, waitTime) // Finally, allow the tower to ack session creation, // allowing the state updates to be sent through the new // session. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckCreateSession = false }) h.server.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, { // Assert that if a client changes the address for a server and // then tries to back up updates then the client will switch to // the new address. name: "change address of existing session", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 6 maxUpdates = 5 ) // Advance the channel to create all states. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for the first half of the updates to be // populated in the server's database. h.server.waitForUpdates(hints[:len(hints)/2], waitTime) // Stop the server. h.server.stop() // Change the address of the server. 
towerTCPAddr, err := net.ResolveTCPAddr( "tcp", towerAddr2Str, ) require.NoError(h.t, err) oldAddr := h.server.addr.Address towerAddr := &lnwire.NetAddress{ IdentityKey: h.server.addr.IdentityKey, Address: towerTCPAddr, } h.server.addr = towerAddr // Add the new tower address to the client. err = h.client.AddTower(towerAddr) require.NoError(h.t, err) // Remove the old tower address from the client. err = h.client.RemoveTower( towerAddr.IdentityKey, oldAddr, ) require.NoError(h.t, err) // Restart the server. h.server.start() // Now attempt to back up the rest of the updates. h.backupStates(chanID, numUpdates/2, maxUpdates, nil) // Assert that the server does receive the updates. h.server.waitForUpdates(hints[:maxUpdates], waitTime) }, }, { // Assert that a user is able to remove a tower address during // session negotiation as long as the address in question is not // currently being used. name: "removing a tower during session negotiation", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noServerStart: true, }, fn: func(h *testHarness) { // The server has not started yet and so no session // negotiation with the server will be in progress, so // the client should be able to remove the server. err := wait.NoError(func() error { return h.client.RemoveTower( h.server.addr.IdentityKey, nil, ) }, waitTime) require.NoError(h.t, err) // Set the server up so that its Dial function hangs // when the client calls it. This will force the client // to remain in the state where it has locked the // address of the server. h.server.server, err = wtserver.New(h.server.cfg) require.NoError(h.t, err) cancel := make(chan struct{}) h.net.registerConnCallback( h.server.addr, func(peer wtserver.Peer) { select { case <-h.quit: case <-cancel: } }, ) // Also add a new tower address. 
towerTCPAddr, err := net.ResolveTCPAddr( "tcp", towerAddr2Str, ) require.NoError(h.t, err) towerAddr := &lnwire.NetAddress{ IdentityKey: h.server.addr.IdentityKey, Address: towerTCPAddr, } // Register the new address in the mock-net. h.net.registerConnCallback( towerAddr, h.server.server.InboundPeerConnected, ) // Now start the server. require.NoError(h.t, h.server.server.Start()) // Re-add the server to the client err = h.client.AddTower(h.server.addr) require.NoError(h.t, err) // Also add the new tower address. err = h.client.AddTower(towerAddr) require.NoError(h.t, err) // Assert that if the client attempts to remove the // tower's first address, then it will error due to // address currently being locked for session // negotiation. err = wait.Predicate(func() bool { err = h.client.RemoveTower( h.server.addr.IdentityKey, h.server.addr.Address, ) return errors.Is(err, wtclient.ErrAddrInUse) }, waitTime) require.NoError(h.t, err) // Assert that the second address can be removed since // it is not being used for session negotiation. err = wait.NoError(func() error { return h.client.RemoveTower( h.server.addr.IdentityKey, towerTCPAddr, ) }, waitTime) require.NoError(h.t, err) // Allow the dial to the first address to stop hanging. close(cancel) // Assert that the client can now remove the first // address. err = wait.NoError(func() error { return h.client.RemoveTower( h.server.addr.IdentityKey, nil, ) }, waitTime) require.NoError(h.t, err) }, }, { name: "assert that sessions are correctly marked as closable", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const numUpdates = 5 // In this test we assert that a channel is correctly // marked as closed and that sessions are also correctly // marked as closable. // We start with the sendUpdatesOn parameter set to // false so that we can test that channels are correctly // evaluated at startup. 
h.sendUpdatesOn = false // Advance channel 0 to create all states and back them // all up. This will saturate the session with updates // for channel 0 which means that the session should be // considered closable when channel 0 is closed. hints := h.advanceChannelN(0, numUpdates) h.backupStates(0, 0, numUpdates, nil) h.server.waitForUpdates(hints, waitTime) // We expect only 1 session to have updates for this // channel. sessionIDs := h.relevantSessions(0) require.Len(h.t, sessionIDs, 1) // Since channel 0 is still open, the session should not // yet be closable. require.False(h.t, h.isSessionClosable(sessionIDs[0])) // Close the channel. h.closeChannel(0, 1) // Since updates are currently not being sent, we expect // the session to still not be marked as closable. require.False(h.t, h.isSessionClosable(sessionIDs[0])) // Restart the client. require.NoError(h.t, h.client.Stop()) h.startClient() // The session should now have been marked as closable. err := wait.Predicate(func() bool { return h.isSessionClosable(sessionIDs[0]) }, waitTime) require.NoError(h.t, err) // Now we set sendUpdatesOn to true and do the same with // a new channel. A restart should now not be necessary // anymore. h.sendUpdatesOn = true h.makeChannel( 1, h.cfg.localBalance, h.cfg.remoteBalance, ) h.registerChannel(1) hints = h.advanceChannelN(1, numUpdates) h.backupStates(1, 0, numUpdates, nil) h.server.waitForUpdates(hints, waitTime) // Determine the ID of the session of interest. sessionIDs = h.relevantSessions(1) // We expect only 1 session to have updates for this // channel. require.Len(h.t, sessionIDs, 1) // Assert that the session is not yet closable since // the channel is still open. require.False(h.t, h.isSessionClosable(sessionIDs[0])) // Now close the channel. h.closeChannel(1, 1) // Since the updates have been turned on, the session // should now show up as closable. 
err = wait.Predicate(func() bool { return h.isSessionClosable(sessionIDs[0]) }, waitTime) require.NoError(h.t, err) // Now we test that a session must be exhausted with all // channels closed before it is seen as closable. h.makeChannel( 2, h.cfg.localBalance, h.cfg.remoteBalance, ) h.registerChannel(2) // Fill up only half of the session updates. hints = h.advanceChannelN(2, numUpdates) h.backupStates(2, 0, numUpdates/2, nil) h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Determine the ID of the session of interest. sessionIDs = h.relevantSessions(2) // We expect only 1 session to have updates for this // channel. require.Len(h.t, sessionIDs, 1) // Now close the channel. h.closeChannel(2, 1) // The session should _not_ be closable due to it not // being exhausted yet. require.False(h.t, h.isSessionClosable(sessionIDs[0])) // Create a new channel. h.makeChannel( 3, h.cfg.localBalance, h.cfg.remoteBalance, ) h.registerChannel(3) hints = h.advanceChannelN(3, numUpdates) h.backupStates(3, 0, numUpdates, nil) h.server.waitForUpdates(hints, waitTime) // Close it. h.closeChannel(3, 1) // Now the session should be closable. err = wait.Predicate(func() bool { return h.isSessionClosable(sessionIDs[0]) }, waitTime) require.NoError(h.t, err) // Now we will mine a few blocks. This will cause the // necessary session-close-range to be exceeded meaning // that the client should send the DeleteSession message // to the server. We will assert that both the client // and server have deleted the appropriate sessions and // channel info. // Before we mine blocks, assert that the client // currently has 3 closable sessions. closableSess, err := h.clientDB.ListClosableSessions() require.NoError(h.t, err) require.Len(h.t, closableSess, 3) // Assert that the server is also aware of all of these // sessions. 
for sid := range closableSess { _, err := h.server.db.GetSessionInfo(&sid) require.NoError(h.t, err) } // Also make a note of the total number of sessions the // client has. sessions, err := h.clientDB.ListClientSessions(nil) require.NoError(h.t, err) require.Len(h.t, sessions, 4) h.mine(3) // The client should no longer have any closable // sessions and the total list of client sessions should // no longer include the three that it previously had // marked as closable. The server should also no longer // have these sessions in its DB. err = wait.Predicate(func() bool { sess, err := h.clientDB.ListClientSessions(nil) require.NoError(h.t, err) cs, err := h.clientDB.ListClosableSessions() require.NoError(h.t, err) if len(sess) != 1 || len(cs) != 0 { return false } for sid := range closableSess { _, ok := sess[sid] if ok { return false } _, err := h.server.db.GetSessionInfo( &sid, ) if !errors.Is( err, wtdb.ErrSessionNotFound, ) { return false } } return true }, waitTime) require.NoError(h.t, err) }, }, { // Demonstrate that the client is able to recover after // deleting its database by skipping through key indices until // it gets to one that does not result in the // CreateSessionCodeAlreadyExists error code being returned from // the server. name: "continue after client database deletion", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions. hints := h.advanceChannelN(chanID, numUpdates) // Back half of the states up. h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for the updates to be populated in the server's // database. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now stop the client and reset its database. require.NoError(h.t, h.client.Stop()) db := newClientDB(h.t) h.clientDB = db h.clientCfg.DB = db // Restart the client. 
h.startClient() // We need to re-register the channel due to the client // db being reset. h.registerChannel(0) // Attempt to back up the remaining tasks. h.backupStates(chanID, numUpdates/2, numUpdates, nil) // Show that the server does get the remaining updates. h.server.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, { // This test demonstrates that if there is no active session, // the updates are persisted to disk on restart and reliably // sent. name: "in-mem updates not lost on restart", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, // noServerStart ensures that the server does not // automatically start on creation of the test harness. // This ensures that the client does not initially have // any active sessions. noServerStart: true, }, fn: func(h *testHarness) { const ( chanID = 0 numUpdates = 5 ) // Try back up the first few states of the client's // channel. Since the server has not yet started, the // client should have no active session yet and so these // updates will just be kept in an in-memory queue. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) // Restart the Client. And also now start the server. require.NoError(h.t, h.client.Stop()) h.server.start() h.startClient() // Back up a few more states. h.backupStates(chanID, numUpdates/2, numUpdates, nil) // Assert that the server does receive ALL the updates. h.server.waitForUpdates(hints[0:numUpdates], waitTime) }, }, { // Assert that the client is able to switch to a new tower if // the primary one goes down. name: "switch to new tower", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions and back a few of // them up to the main tower. 
hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for all the backed up updates to be populated in // the server's database. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now we add a new tower. server2 := newServerHarness( h.t, h.net, towerAddr2Str, nil, ) server2.start() h.addTower(server2.addr) // Stop the old tower and remove it from the client. h.server.stop() h.removeTower(h.server.addr.IdentityKey, nil) // Back up the remaining states. h.backupStates(chanID, numUpdates/2, numUpdates, nil) // Assert that the new tower has the remaining states. server2.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, { // Show that if a client switches to a new tower _after_ backup // tasks have been bound to the session with the first old tower // then these updates are replayed onto the new tower. name: "switch to new tower after tasks are bound", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions and back a few of // them up to the main tower. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for all these updates to be populated in the // server's database. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now stop the server. h.server.stop() // Back up a few more tasks. This will bind the // backup tasks to the session with the old server. h.backupStates(chanID, numUpdates/2, numUpdates-1, nil) // Now we add a new tower. server2 := newServerHarness( h.t, h.net, towerAddr2Str, nil, ) server2.start() h.addTower(server2.addr) // Now we can remove the old one. err := wait.Predicate(func() bool { err := h.client.RemoveTower( h.server.addr.IdentityKey, nil, ) return err == nil }, waitTime) require.NoError(h.t, err) // Back up the final task. 
h.backupStates(chanID, numUpdates-1, numUpdates, nil) // Show that all the backups (the ones added while no // towers were online and the one added after adding the // second tower) are backed up to the second tower. server2.waitForUpdates( hints[numUpdates/2:numUpdates], waitTime, ) }, }, { // Assert that a client is able to remove a tower if there are // persisted un-acked updates. This tests the case where the // client is not-restarted meaning that the un-acked updates // will still be in the pending queue. name: "can remove due to un-acked updates (no client restart)", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions and back a few of // them up to the main tower. hints := h.advanceChannelN(chanID, numUpdates) h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for all these updates to be populated in the // server's database. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now stop the server and restart it with the // NoAckUpdates set to true. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = true }) // Back up the remaining tasks. This will bind the // backup tasks to the session with the server. The // client will also persist the updates. h.backupStates(chanID, numUpdates/2, numUpdates, nil) tower, err := h.clientDB.LoadTower( h.server.addr.IdentityKey, ) require.NoError(h.t, err) // Wait till the updates have been persisted. err = wait.Predicate(func() bool { var numCommittedUpdates int countUpdates := func(_ *wtdb.ClientSession, update *wtdb.CommittedUpdate) { numCommittedUpdates++ } _, err := h.clientDB.ListClientSessions( &tower.ID, wtdb.WithPerCommittedUpdate( countUpdates, ), ) require.NoError(h.t, err) return numCommittedUpdates == 1 }, waitTime) require.NoError(h.t, err) // Now remove the tower. 
err = h.client.RemoveTower( h.server.addr.IdentityKey, nil, ) require.NoError(h.t, err) // Add a new tower. server2 := newServerHarness( h.t, h.net, towerAddr2Str, nil, ) server2.start() h.addTower(server2.addr) // Now we assert that the backups are backed up to the // new tower. server2.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, { // Assert that a client is able to remove a tower if there are // persisted un-acked updates _and_ the client is restarted // before the tower is removed. name: "can remove tower with un-acked updates (with restart)", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 5 chanID = 0 ) // Generate numUpdates retributions. hints := h.advanceChannelN(chanID, numUpdates) // Back half of the states up. h.backupStates(chanID, 0, numUpdates/2, nil) // Wait for the updates to be populated in the server's // database. h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now stop the server and restart it with the // NoAckUpdates set to true. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = true }) // Back up the remaining tasks. This will bind the // backup tasks to the session with the server. The // client will also attempt to get the ack for one // update which will cause a CommittedUpdate to be // persisted. h.backupStates(chanID, numUpdates/2, numUpdates, nil) tower, err := h.clientDB.LoadTower( h.server.addr.IdentityKey, ) require.NoError(h.t, err) // Wait till the updates have been persisted. 
err = wait.Predicate(func() bool { var numCommittedUpdates int countUpdates := func(_ *wtdb.ClientSession, update *wtdb.CommittedUpdate) { numCommittedUpdates++ } _, err := h.clientDB.ListClientSessions( &tower.ID, wtdb.WithPerCommittedUpdate( countUpdates, ), ) require.NoError(h.t, err) return numCommittedUpdates == 1 }, waitTime) require.NoError(h.t, err) // Now restart the client. This ensures that the // updates are no longer in the pending queue. require.NoError(h.t, h.client.Stop()) h.startClient() // Now remove the tower. err = h.client.RemoveTower( h.server.addr.IdentityKey, nil, ) require.NoError(h.t, err) // Add a new tower. server2 := newServerHarness( h.t, h.net, towerAddr2Str, nil, ) server2.start() h.addTower(server2.addr) // Now we assert that the backups are backed up to the // new tower. server2.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, { // This test shows that if a channel is closed while an update // for that channel still exists in an in-memory queue // somewhere then it is handled correctly by treating it as a // rogue update. name: "channel closed while update is un-acked", cfg: harnessCfg{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, fn: func(h *testHarness) { const ( numUpdates = 10 chanIDInt = 0 ) h.sendUpdatesOn = true // Advance the channel with a few updates. hints := h.advanceChannelN(chanIDInt, numUpdates) // Backup a few these updates and wait for them to // arrive at the server. Note that we back up enough // updates to saturate the session so that the session // is considered closable when the channel is deleted. h.backupStates(chanIDInt, 0, numUpdates/2, nil) h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now, restart the server in a state where it will not // ack updates. This will allow us to wait for an // update to be un-acked and persisted. 
h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = true }) // Backup a few more of the update. These should remain // in the client as un-acked. h.backupStates( chanIDInt, numUpdates/2, numUpdates-1, nil, ) // Wait for the tasks to be bound to sessions. fetchSessions := h.clientDB.FetchSessionCommittedUpdates err := wait.Predicate(func() bool { sessions, err := h.clientDB.ListClientSessions( nil, ) require.NoError(h.t, err) var updates []wtdb.CommittedUpdate for id := range sessions { updates, err = fetchSessions(&id) require.NoError(h.t, err) if len(updates) != numUpdates-1 { return true } } return false }, waitTime) require.NoError(h.t, err) // Now we close this channel while the update for it has // not yet been acked. h.closeChannel(chanIDInt, 1) // Closable sessions should now be one. err = wait.Predicate(func() bool { cs, err := h.clientDB.ListClosableSessions() require.NoError(h.t, err) return len(cs) == 1 }, waitTime) require.NoError(h.t, err) // Now, restart the server and allow it to ack updates // again. h.server.restart(func(cfg *wtserver.Config) { cfg.NoAckUpdates = false }) // Mine a few blocks so that the session close range is // surpassed. h.mine(3) // Wait for there to be no more closable sessions on the // client side. err = wait.Predicate(func() bool { cs, err := h.clientDB.ListClosableSessions() require.NoError(h.t, err) return len(cs) == 0 }, waitTime) require.NoError(h.t, err) // Wait for channel to be "unregistered". chanID := chanIDFromInt(chanIDInt) err = wait.Predicate(func() bool { err := h.client.BackupState(&chanID, 0) return errors.Is( err, wtclient.ErrUnregisteredChannel, ) }, waitTime) require.NoError(h.t, err) // Show that the committed update for the closed channel // is cleared from the DB. 
err = wait.Predicate(func() bool { sessions, err := h.clientDB.ListClientSessions( nil, ) require.NoError(h.t, err) var updates []wtdb.CommittedUpdate for id := range sessions { updates, err = fetchSessions(&id) require.NoError(h.t, err) if len(updates) != 0 { return false } } return true }, waitTime) require.NoError(h.t, err) }, }, } // TestClient executes the client test suite, asserting the ability to backup // states in a number of failure cases and it's reliability during shutdown. func TestClient(t *testing.T) { for _, test := range clientTests { tc := test t.Run(tc.name, func(t *testing.T) { t.Parallel() h := newHarness(t, tc.cfg) tc.fn(h) }) } }