From 9c70f499014ed484604a83f05611df864c43e333 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:29:42 -0700 Subject: [PATCH 01/21] watchtower/wtwire/create_session_reply: remove extra Reject from code --- watchtower/wtwire/create_session_reply.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/watchtower/wtwire/create_session_reply.go b/watchtower/wtwire/create_session_reply.go index da4867f29..9b63b08c7 100644 --- a/watchtower/wtwire/create_session_reply.go +++ b/watchtower/wtwire/create_session_reply.go @@ -14,9 +14,9 @@ const ( // reply was never received and/or processed by the client. CreateSessionCodeAlreadyExists CreateSessionCode = 60 - // CreateSessionCodeRejectRejectMaxUpdates the tower rejected the maximum + // CreateSessionCodeRejectMaxUpdates the tower rejected the maximum // number of state updates proposed by the client. - CreateSessionCodeRejectRejectMaxUpdates CreateSessionCode = 61 + CreateSessionCodeRejectMaxUpdates CreateSessionCode = 61 // CreateSessionCodeRejectRewardRate the tower rejected the reward rate // proposed by the client. From 99dbbf48aae65667282af2c3856c94448847c261 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:29:55 -0700 Subject: [PATCH 02/21] watchtower/wtwire/error_code: add human-readable descriptors --- watchtower/wtwire/error_code.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/watchtower/wtwire/error_code.go b/watchtower/wtwire/error_code.go index 6a441784a..2f4bc6bb3 100644 --- a/watchtower/wtwire/error_code.go +++ b/watchtower/wtwire/error_code.go @@ -1,5 +1,7 @@ package wtwire +import "fmt" + // ErrorCode represents a generic error code used when replying to watchtower // clients. Specific reply messages may extend the ErrorCode primitive and add // custom codes, so long as they don't collide with the generic error codes.. @@ -18,3 +20,33 @@ const ( // permanently failed, and further communication should be avoided. CodePermanentFailure ErrorCode = 50 ) + +// String returns a human-readable description of an ErrorCode. +func (c ErrorCode) String() string { + switch c { + case CodeOK: + return "CodeOK" + case CodeTemporaryFailure: + return "CodeTemporaryFailure" + case CodePermanentFailure: + return "CodePermanentFailure" + case CreateSessionCodeAlreadyExists: + return "CreateSessionCodeAlreadyExists" + case CreateSessionCodeRejectMaxUpdates: + return "CreateSessionCodeRejectMaxUpdates" + case CreateSessionCodeRejectRewardRate: + return "CreateSessionCodeRejectRewardRate" + case CreateSessionCodeRejectSweepFeeRate: + return "CreateSessionCodeRejectSweepFeeRate" + case CreateSessionCodeRejectBlobType: + return "CreateSessionCodeRejectBlobType" + case StateUpdateCodeClientBehind: + return "StateUpdateCodeClientBehind" + case StateUpdateCodeMaxUpdatesExceeded: + return "StateUpdateCodeMaxUpdatesExceeded" + case StateUpdateCodeSeqNumOutOfOrder: + return "StateUpdateCodeSeqNumOutOfOrder" + default: + return fmt.Sprintf("UnknownErrorCode: %d", c) + } +} From 247978dfe27803f14bd4dfd7f32f3bc58cca318d Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:30:09 -0700 Subject: [PATCH 03/21] watchtower/wtdb/tower: store wt pk and addrs --- watchtower/wtdb/tower.go | 65 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 watchtower/wtdb/tower.go diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go new file mode 100644 index 000000000..e7213cab4 --- /dev/null +++ b/watchtower/wtdb/tower.go @@ -0,0 +1,65 @@ +package wtdb + +import ( + "net" + "sync" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/lnwire" +) + +// Tower holds the necessary components required to connect to a remote tower. +// Communication is handled by brontide, and requires both a public key and an +// address. +type Tower struct { + // ID is a unique ID for this record assigned by the database. + ID uint64 + + // IdentityKey is the public key of the remote node, used to + // authenticate the brontide transport. + IdentityKey *btcec.PublicKey + + // Addresses is a list of possible addresses to reach the tower. + Addresses []net.Addr + + mu sync.RWMutex +} + +// AddAddress adds the given address to the tower's in-memory list of addresses. +// If the address's string is already present, the Tower will be left +// unmodified. Otherwise, the adddress is prepended to the beginning of the +// Tower's addresses, on the assumption that it is fresher than the others. +func (t *Tower) AddAddress(addr net.Addr) { + t.mu.Lock() + defer t.mu.Unlock() + + // Ensure we don't add a duplicate address. + addrStr := addr.String() + for _, existingAddr := range t.Addresses { + if existingAddr.String() == addrStr { + return + } + } + + // Add this address to the front of the list, on the assumption that it + // is a fresher address and will be tried first. + t.Addresses = append([]net.Addr{addr}, t.Addresses...) +} + +// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's +// addresses. This can be used to have a client try multiple addresses for the +// same Tower. +func (t *Tower) LNAddrs() []*lnwire.NetAddress { + t.mu.RLock() + defer t.mu.RUnlock() + + addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) + for _, addr := range t.Addresses { + addrs = append(addrs, &lnwire.NetAddress{ + IdentityKey: t.IdentityKey, + Address: addr, + }) + } + + return addrs +} From 9177358a3ce8298c0eab567669ac44c7c184738e Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:30:22 -0700 Subject: [PATCH 04/21] watchtower/wtdb/client_session: add ClientSession --- watchtower/wtdb/client_session.go | 110 ++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 watchtower/wtdb/client_session.go diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go new file mode 100644 index 000000000..f7b531fec --- /dev/null +++ b/watchtower/wtdb/client_session.go @@ -0,0 +1,110 @@ +package wtdb + +import ( + "errors" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +var ( + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrUpdateAlreadyCommitted signals that the chosen sequence number has + // already been committed to an update with a different breach hint. + ErrUpdateAlreadyCommitted = errors.New("update already committed") + + // ErrCommitUnorderedUpdate signals the client tried to commit a + // sequence number other than the next unallocated sequence number. + ErrCommitUnorderedUpdate = errors.New("update seqnum not monotonic") + + // ErrCommittedUpdateNotFound signals that the tower tried to ACK a + // sequence number that has not yet been allocated by the client. + ErrCommittedUpdateNotFound = errors.New("committed update not found") + + // ErrUnallocatedLastApplied signals that the tower tried to provide a + // LastApplied value greater than any allocated sequence number. + ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " + + "greater than allocated seqnum") +) + +// ClientSession encapsulates a SessionInfo returned from a successful +// session negotiation, and also records the tower and ephemeral secret used for +// communicating with the tower. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID SessionID + + // SeqNum is the next unallocated sequence number that can be sent to + // the tower. + SeqNum uint16 + + // TowerLastApplied the last last-applied the tower has echoed back. + TowerLastApplied uint16 + + // TowerID is the unique, db-assigned identifier that references the + // Tower with which the session is negotiated. + TowerID uint64 + + // Tower holds the pubkey and address of the watchtower. + // + // NOTE: This value is not serialized. It is recovered by looking up the + // tower with TowerID. + Tower *Tower + + // SessionKeyDesc is the key descriptor used to derive the client's + // session key so that it can authenticate with the tower to update its + // session. + SessionKeyDesc keychain.KeyLocator + + // SessionPrivKey is the ephemeral secret key used to connect to the + // watchtower. + // TODO(conner): remove after HD keys + SessionPrivKey *btcec.PrivateKey + + // Policy holds the negotiated session parameters. + Policy wtpolicy.Policy + + // RewardPkScript is the pkscript that the tower's reward will be + // deposited to if a sweep transaction confirms and the sessions + // specifies a reward output. + RewardPkScript []byte + + // CommittedUpdates is a map from allocated sequence numbers to unacked + // updates. These updates can be resent after a restart if the update + // failed to send or receive an acknowledgment. + CommittedUpdates map[uint16]*CommittedUpdate + + // AckedUpdates is a map from sequence number to backup id to record + // which revoked states were uploaded via this session. + AckedUpdates map[uint16]BackupID +} + +// BackupID identifies a particular revoked, remote commitment by channel id and +// commitment height. +type BackupID struct { + // ChanID is the channel id of the revoked commitment. + ChanID lnwire.ChannelID + + // CommitHeight is the commitment height of the revoked commitment. + CommitHeight uint64 +} + +// CommittedUpdate holds a state update sent by a client along with its +// SessionID. +type CommittedUpdate struct { + BackupID BackupID + + // Hint is the 16-byte prefix of the revoked commitment transaction ID. + Hint BreachHint + + // EncryptedBlob is a ciphertext containing the sweep information for + // exacting justice if the commitment transaction matching the breach + // hint is braodcast. + EncryptedBlob []byte +} From 04bbf39f51e4ddf02821648f45e3e2e5d7f55605 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:30:35 -0700 Subject: [PATCH 05/21] watchtower/wtclient/log: adds wtclient logging --- watchtower/wtclient/log.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 watchtower/wtclient/log.go diff --git a/watchtower/wtclient/log.go b/watchtower/wtclient/log.go new file mode 100644 index 000000000..8d2e37dda --- /dev/null +++ b/watchtower/wtclient/log.go @@ -0,0 +1,29 @@ +package wtclient + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger("WTCL", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} From b1903451d99127778874b6f23821c1c4033a6ad0 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:30:47 -0700 Subject: [PATCH 06/21] watchtower/wtclient/interface: add DB ifaces --- watchtower/wtclient/interface.go | 76 ++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 watchtower/wtclient/interface.go diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go new file mode 100644 index 000000000..5164acea8 --- /dev/null +++ b/watchtower/wtclient/interface.go @@ -0,0 +1,76 @@ +package wtclient + +import ( + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtserver" +) + +// DB abstracts the required database operations required by the watchtower +// client. +type DB interface { + // CreateTower initialize an address record used to communicate with a + // watchtower. Each Tower is assigned a unique ID, that is used to + // amortize storage costs of the public key when used by multiple + // sessions. + CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error) + + // CreateClientSession saves a newly negotiated client session to the + // client's database. This enables the session to be used across + // restarts. + CreateClientSession(*wtdb.ClientSession) error + + // ListClientSessions returns all sessions that have not yet been + // exhausted. This is used on startup to find any sessions which may + // still be able to accept state updates. + ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) + + // FetchChanPkScripts returns a map of all sweep pkscripts for + // registered channels. This is used on startup to cache the sweep + // pkscripts of registered channels in memory. + FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) + + // AddChanPkScript inserts a newly generated sweep pkscript for the + // given channel. + AddChanPkScript(lnwire.ChannelID, []byte) error + + // MarkBackupIneligible records that the state identified by the + // (channel id, commit height) tuple was ineligible for being backed up + // under the current policy. This state can be retried later under a + // different policy. + MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error + + // CommitUpdate writes the next state update for a particular + // session, so that we can be sure to resend it after a restart if it + // hasn't been ACK'd by the tower. The sequence number of the update + // should be exactly one greater than the existing entry, and less that + // or equal to the session's MaxUpdates. + CommitUpdate(id *wtdb.SessionID, seqNum uint16, + update *wtdb.CommittedUpdate) (uint16, error) + + // AckUpdate records an acknowledgment from the watchtower that the + // update identified by seqNum was received and saved. The returned + // lastApplied will be recorded. + AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error +} + +// Dial connects to an addr using the specified net and returns the connection +// object. +type Dial func(net, addr string) (net.Conn, error) + +// AuthDialer connects to a remote node using an authenticated transport, such as +// brontide. The dialer argument is used to specify a resolver, which allows +// this method to be used over Tor or clear net connections. +type AuthDialer func(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, + dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) + +// AuthDial is the watchtower client's default method of dialing. +func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, + dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) { + + return brontide.Dial(localPriv, netAddr, dialer) +} From 4642954e722f4d5b1ec1166267e39f954eb1d4a4 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:30:59 -0700 Subject: [PATCH 07/21] watchtower/wtclient/backup_task: bind to ClientSession instead of SessionInfo --- watchtower/wtclient/backup_task.go | 15 ++++++++------- watchtower/wtclient/backup_task_internal_test.go | 16 ++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index 17c4c5a7b..c88bfd0f9 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -34,9 +34,8 @@ import ( // necessary components are stripped out and encrypted before being sent to // the tower in a StateUpdate. type backupTask struct { - chanID lnwire.ChannelID - commitHeight uint64 - breachInfo *lnwallet.BreachRetribution + id wtdb.BackupID + breachInfo *lnwallet.BreachRetribution // state-dependent variables @@ -96,8 +95,10 @@ func newBackupTask(chanID *lnwire.ChannelID, } return &backupTask{ - chanID: *chanID, - commitHeight: breachInfo.RevokedStateNum, + id: wtdb.BackupID{ + ChanID: *chanID, + CommitHeight: breachInfo.RevokedStateNum, + }, breachInfo: breachInfo, toLocalInput: toLocalInput, toRemoteInput: toRemoteInput, @@ -125,7 +126,7 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input { // SessionInfo's policy. If no error is returned, the task has been bound to the // session and can be queued to upload to the tower. Otherwise, the bind failed // and should be rescheduled with a different session. -func (t *backupTask) bindSession(session *wtdb.SessionInfo) error { +func (t *backupTask) bindSession(session *wtdb.ClientSession) error { // First we'll begin by deriving a weight estimate for the justice // transaction. The final weight can be different depending on whether @@ -154,7 +155,7 @@ func (t *backupTask) bindSession(session *wtdb.SessionInfo) error { // in the current session's policy. outputs, err := session.Policy.ComputeJusticeTxOuts( t.totalAmt, int64(weightEstimate.Weight()), - t.sweepPkScript, session.RewardAddress, + t.sweepPkScript, session.RewardPkScript, ) if err != nil { return err diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index c38e5c974..2c25c9a02 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -69,7 +69,7 @@ type backupTaskTest struct { expSweepAmt int64 expRewardAmt int64 expRewardScript []byte - session *wtdb.SessionInfo + session *wtdb.ClientSession bindErr error expSweepScript []byte signer input.Signer @@ -205,13 +205,13 @@ func genTaskTest( expSweepAmt: expSweepAmt, expRewardAmt: expRewardAmt, expRewardScript: rewardScript, - session: &wtdb.SessionInfo{ + session: &wtdb.ClientSession{ Policy: wtpolicy.Policy{ BlobType: blobType, SweepFeeRate: sweepFeeRate, RewardRate: 10000, }, - RewardAddress: rewardScript, + RewardPkScript: rewardScript, }, bindErr: bindErr, expSweepScript: makeAddrSlice(22), @@ -379,7 +379,7 @@ var backupTaskTests = []backupTaskTest{ } // TestBackupTaskBind tests the initialization and binding of a backupTask to a -// SessionInfo. After a succesfful bind, all parameters of the justice +// ClientSession. After a successful bind, all parameters of the justice // transaction should be solidified, so we assert there correctness. In an // unsuccessful bind, the session-dependent parameters should be unmodified so // that the backup task can be rescheduled if necessary. Finally, we assert that @@ -401,14 +401,14 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Assert that all parameters set during initialization are properly // populated. - if task.chanID != test.chanID { + if task.id.ChanID != test.chanID { t.Fatalf("channel id mismatch, want: %s, got: %s", - test.chanID, task.chanID) + test.chanID, task.id.ChanID) } - if task.commitHeight != test.breachInfo.RevokedStateNum { + if task.id.CommitHeight != test.breachInfo.RevokedStateNum { t.Fatalf("commit height mismatch, want: %d, got: %d", - test.breachInfo.RevokedStateNum, task.commitHeight) + test.breachInfo.RevokedStateNum, task.id.CommitHeight) } if task.totalAmt != test.expTotalAmt { From b23bff62d53914a8903a34653ca497fe1214b37b Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:31:11 -0700 Subject: [PATCH 08/21] watchtower/wtclient/errors --- watchtower/wtclient/errors.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 watchtower/wtclient/errors.go diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go new file mode 100644 index 000000000..71b68cf15 --- /dev/null +++ b/watchtower/wtclient/errors.go @@ -0,0 +1,35 @@ +package wtclient + +import "errors" + +var ( + // ErrClientExiting signals that the watchtower client is shutting down. + ErrClientExiting = errors.New("watchtower client shutting down") + + // ErrTowerCandidatesExhausted signals that a TowerCandidateIterator has + // cycled through all available candidates. + ErrTowerCandidatesExhausted = errors.New("exhausted all tower " + + "candidates") + + // ErrPermanentTowerFailure signals that the tower has reported that it + // has permanently failed or the client believes this has happened based + // on the tower's behavior. + ErrPermanentTowerFailure = errors.New("permanent tower failure") + + // ErrNegotiatorExiting signals that the SessionNegotiator is shutting + // down. + ErrNegotiatorExiting = errors.New("negotiator exiting") + + // ErrNoTowerAddrs signals that the client could not be created because + // we have no addresses with which we can reach a tower. + ErrNoTowerAddrs = errors.New("no tower addresses") + + // ErrFailedNegotiation signals that the session negotiator could not + // acquire a new session as requested. + ErrFailedNegotiation = errors.New("session negotiation unsuccessful") + + // ErrUnregisteredChannel signals that the client was unable to backup a + // revoked state becuase the channel had not been previously registered + // with the client. + ErrUnregisteredChannel = errors.New("channel is not registered") +) From a8721bcedf4ba42d356ac8bd2fef4ad09d84bd8b Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:31:24 -0700 Subject: [PATCH 09/21] watchtower/wtclient/tower_candidate_iterator: linked-list iterator --- watchtower/wtclient/candidate_iterator.go | 82 +++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 watchtower/wtclient/candidate_iterator.go diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go new file mode 100644 index 000000000..aafdee1a3 --- /dev/null +++ b/watchtower/wtclient/candidate_iterator.go @@ -0,0 +1,82 @@ +package wtclient + +import ( + "container/list" + "sync" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// TowerCandidateIterator provides an abstraction for iterating through possible +// watchtower addresses when attempting to create a new session. +type TowerCandidateIterator interface { + // Reset clears any internal iterator state, making previously taken + // candidates available as long as they remain in the set. + Reset() error + + // Next returns the next candidate tower. The iterator is not required + // to return results in any particular order. If no more candidates are + // available, ErrTowerCandidatesExhausted is returned. + Next() (*wtdb.Tower, error) +} + +// towerListIterator is a linked-list backed TowerCandidateIterator. +type towerListIterator struct { + mu sync.Mutex + candidates *list.List + nextCandidate *list.Element +} + +// Compile-time constraint to ensure *towerListIterator implements the +// TowerCandidateIterator interface. +var _ TowerCandidateIterator = (*towerListIterator)(nil) + +// newTowerListIterator initializes a new towerListIterator from a variadic list +// of lnwire.NetAddresses. +func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { + iter := &towerListIterator{ + candidates: list.New(), + } + + for _, candidate := range candidates { + iter.candidates.PushBack(candidate) + } + iter.Reset() + + return iter +} + +// Reset clears the iterators state, and makes the address at the front of the +// list the next item to be returned.. +func (t *towerListIterator) Reset() error { + t.mu.Lock() + defer t.mu.Unlock() + + // Reset the next candidate to the front of the linked-list. + t.nextCandidate = t.candidates.Front() + + return nil +} + +// Next returns the next candidate tower. This iterator will always return +// candidates in the order given when the iterator was instantiated. If no more +// candidates are available, ErrTowerCandidatesExhausted is returned. +func (t *towerListIterator) Next() (*wtdb.Tower, error) { + t.mu.Lock() + defer t.mu.Unlock() + + // If the next candidate is nil, we've exhausted the list. + if t.nextCandidate == nil { + return nil, ErrTowerCandidatesExhausted + } + + // Propose the tower at the front of the list. + tower := t.nextCandidate.Value.(*wtdb.Tower) + + // Set the next candidate to the subsequent element. + t.nextCandidate = t.nextCandidate.Next() + + return tower, nil +} + +// TODO(conner): implement graph-backed candidate iterator for public towers. From 95fa7659e0405a4658a3698f7215a3d7cae06473 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:31:37 -0700 Subject: [PATCH 10/21] watchtower/wtclient/session_negotiator: add session negotiation --- watchtower/wtclient/session_negotiator.go | 451 ++++++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 watchtower/wtclient/session_negotiator.go diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go new file mode 100644 index 000000000..b62819cb3 --- /dev/null +++ b/watchtower/wtclient/session_negotiator.go @@ -0,0 +1,451 @@ +package wtclient + +import ( + "fmt" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/lightningnetwork/lnd/watchtower/wtserver" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// SessionNegotiator is an interface for asynchronously requesting new sessions. +type SessionNegotiator interface { + // RequestSession signals to the session negotiator that the client + // needs another session. Once the session is negotiated, it should be + // returned via NewSessions. + RequestSession() + + // NewSessions is a read-only channel where newly negotiated sessions + // will be delivered. + NewSessions() <-chan *wtdb.ClientSession + + // Start safely initializes the session negotiator. + Start() error + + // Stop safely shuts down the session negotiator. + Stop() error +} + +// NegotiatorConfig provides access to the resources required by a +// SessionNegotiator to faithfully carry out its duties. All nil-able field must +// be initialized. +type NegotiatorConfig struct { + // DB provides access to a persistent storage medium used by the tower + // to properly allocate session ephemeral keys and record successfully + // negotiated sessions. + DB DB + + // Candidates is an abstract set of tower candidates that the negotiator + // will traverse serially when attempting to negotiate a new session. + Candidates TowerCandidateIterator + + // Policy defines the session policy that will be proposed to towers + // when attempting to negotiate a new session. This policy will be used + // across all negotiation proposals for the lifetime of the negotiator. + Policy wtpolicy.Policy + + // Dial initiates an outbound brontide connection to the given address + // using a specified private key. The peer is returned in the event of a + // successful connection. + Dial func(*btcec.PrivateKey, *lnwire.NetAddress) (wtserver.Peer, error) + + // SendMessage writes a wtwire message to remote peer. + SendMessage func(wtserver.Peer, wtwire.Message) error + + // ReadMessage reads a message from a remote peer and returns the + // decoded wtwire message. + ReadMessage func(wtserver.Peer) (wtwire.Message, error) + + // ChainHash the genesis hash identifying the chain for any negotiated + // sessions. Any state updates sent to that session should also + // originate from this chain. + ChainHash chainhash.Hash + + // MinBackoff defines the initial backoff applied by the session + // negotiator after all tower candidates have been exhausted and + // reattempting negotiation with the same set of candidates. Subsequent + // backoff durations will grow exponentially. + MinBackoff time.Duration + + // MaxBackoff defines the maximum backoff applied by the session + // negotiator after all tower candidates have been exhausted and + // reattempting negotation with the same set of candidates. If the + // exponential backoff produces a timeout greater than this value, the + // backoff duration will be clamped to MaxBackoff. + MaxBackoff time.Duration +} + +// sessionNegotiator is concrete SessionNegotiator that is able to request new +// sessions from a set of candidate towers asynchronously and return successful +// sessions to the primary client. +type sessionNegotiator struct { + started sync.Once + stopped sync.Once + + localInit *wtwire.Init + + cfg *NegotiatorConfig + + dispatcher chan struct{} + newSessions chan *wtdb.ClientSession + successfulNegotiations chan *wtdb.ClientSession + + wg sync.WaitGroup + quit chan struct{} +} + +// Compile-time constraint to ensure a *sessionNegotiator implements the +// SessionNegotiator interface. +var _ SessionNegotiator = (*sessionNegotiator)(nil) + +// newSessionNegotiator initializes a fresh sessionNegotiator instance. +func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { + localInit := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + cfg.ChainHash, + ) + + return &sessionNegotiator{ + cfg: cfg, + localInit: localInit, + dispatcher: make(chan struct{}, 1), + newSessions: make(chan *wtdb.ClientSession), + successfulNegotiations: make(chan *wtdb.ClientSession), + quit: make(chan struct{}), + } +} + +// Start safely starts up the sessionNegotiator. +func (n *sessionNegotiator) Start() error { + n.started.Do(func() { + log.Debugf("Starting session negotiator") + + n.wg.Add(1) + go n.negotiationDispatcher() + }) + + return nil +} + +// Stop safely shutsdown the sessionNegotiator. +func (n *sessionNegotiator) Stop() error { + n.stopped.Do(func() { + log.Debugf("Stopping session negotiator") + + close(n.quit) + n.wg.Wait() + }) + + return nil +} + +// NewSessions returns a receive-only channel from which newly negotiated +// sessions will be returned. +func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession { + return n.newSessions +} + +// RequestSession sends a request to the sessionNegotiator to begin requesting a +// new session. If one is already in the process of being negotiated, the +// request will be ignored. +func (n *sessionNegotiator) RequestSession() { + select { + case n.dispatcher <- struct{}{}: + default: + } +} + +// negotiationDispatcher acts as the primary event loop for the +// sessionNegotiator, coordinating requests for more sessions and dispatching +// attempts to negotiate them from a list of candidates. +func (n *sessionNegotiator) negotiationDispatcher() { + defer n.wg.Done() + + var pendingNegotiations int + for { + select { + case <-n.dispatcher: + pendingNegotiations++ + + if pendingNegotiations > 1 { + log.Debugf("Already negotiating session, " + + "waiting for existing negotiation to " + + "complete") + continue + } + + // TODO(conner): consider reusing good towers + + log.Debugf("Dispatching session negotiation") + + n.wg.Add(1) + go n.negotiate() + + case session := <-n.successfulNegotiations: + select { + case n.newSessions <- session: + pendingNegotiations-- + case <-n.quit: + return + } + + if pendingNegotiations > 0 { + log.Debugf("Dispatching pending session " + + "negotiation") + + n.wg.Add(1) + go n.negotiate() + } + + case <-n.quit: + return + } + } +} + +// negotiate handles the process of iterating through potential tower candidates +// and attempting to negotiate a new session until a successful negotiation +// occurs. If the candidate iterator becomes exhausted because none were +// successful, this method will back off exponentially up to the configured max +// backoff. This method will continue trying until a negotiation is succesful +// before returning the negotiated session to the dispatcher via the succeed +// channel. +// +// NOTE: This method MUST be run as a goroutine. +func (n *sessionNegotiator) negotiate() { + defer n.wg.Done() + + // On the first pass, initialize the backoff to our configured min + // backoff. + backoff := n.cfg.MinBackoff + +retryWithBackoff: + // If we are retrying, wait out the delay before continuing. + if backoff > 0 { + select { + case <-time.After(backoff): + case <-n.quit: + return + } + } + + // Before attempting a bout of session negotiation, reset the candidate + // iterator to ensure the results are fresh. + n.cfg.Candidates.Reset() + for { + // Pull the next candidate from our list of addresses. + tower, err := n.cfg.Candidates.Next() + if err != nil { + // We've run out of addresses, double and clamp backoff. + backoff *= 2 + if backoff > n.cfg.MaxBackoff { + backoff = n.cfg.MaxBackoff + } + + log.Debugf("Unable to get new tower candidate, "+ + "retrying after %v -- reason: %v", backoff, err) + + goto retryWithBackoff + } + + log.Debugf("Attempting session negotiation with tower=%x", + tower.IdentityKey.SerializeCompressed()) + + // We'll now attempt the CreateSession dance with the tower to + // get a new session, trying all addresses if necessary. + err = n.createSession(tower) + if err != nil { + log.Debugf("Session negotiation with tower=%x "+ + "failed, trying again -- reason: %v", + tower.IdentityKey.SerializeCompressed(), err) + continue + } + + // Success. + return + } +} + +// createSession takes a tower an attempts to negotiate a session using any of +// its stored addresses. This method returns after the first successful +// negotiation, or after all addresses have failed with ErrFailedNegotiation. If +// the tower has no addresses, ErrNoTowerAddrs is returned. +func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error { + // If the tower has no addresses, there's nothing we can do. + if len(tower.Addresses) == 0 { + return ErrNoTowerAddrs + } + + // TODO(conner): create with hdkey at random index + sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return err + } + + // TODO(conner): write towerAddr+privkey + + for _, lnAddr := range tower.LNAddrs() { + err = n.tryAddress(sessionPrivKey, tower, lnAddr) + switch { + case err == ErrPermanentTowerFailure: + // TODO(conner): report to iterator? can then be reset + // with restart + fallthrough + + case err != nil: + log.Debugf("Request for session negotiation with "+ + "tower=%s failed, trying again -- reason: "+ + "%v", lnAddr, err) + continue + + default: + return nil + } + } + + return ErrFailedNegotiation +} + +// tryAddress executes a single create session dance using the given address. +// The address should belong to the tower's set of addresses. This method only +// returns true if all steps succeed and the new session has been persisted, and +// fails otherwise. +func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, + tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + + // Connect to the tower address using our generated session key. + conn, err := n.cfg.Dial(privKey, lnAddr) + if err != nil { + return err + } + + // Send local Init message. + err = n.cfg.SendMessage(conn, n.localInit) + if err != nil { + return fmt.Errorf("unable to send Init: %v", err) + } + + // Receive remote Init message. + remoteMsg, err := n.cfg.ReadMessage(conn) + if err != nil { + return fmt.Errorf("unable to read Init: %v", err) + } + + // Check that returned message is wtwire.Init. + remoteInit, ok := remoteMsg.(*wtwire.Init) + if !ok { + return fmt.Errorf("expected Init, got %T in reply", remoteMsg) + } + + // Verify the watchtower's remote Init message against our own. + err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + if err != nil { + return err + } + + policy := n.cfg.Policy + createSession := &wtwire.CreateSession{ + BlobType: policy.BlobType, + MaxUpdates: policy.MaxUpdates, + RewardBase: policy.RewardBase, + RewardRate: policy.RewardRate, + SweepFeeRate: policy.SweepFeeRate, + } + + // Send CreateSession message. + err = n.cfg.SendMessage(conn, createSession) + if err != nil { + return fmt.Errorf("unable to send CreateSession: %v", err) + } + + // Receive CreateSessionReply message. + remoteMsg, err = n.cfg.ReadMessage(conn) + if err != nil { + return fmt.Errorf("unable to read CreateSessionReply: %v", err) + } + + // Check that returned message is wtwire.CreateSessionReply. + createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply) + if !ok { + return fmt.Errorf("expected CreateSessionReply, got %T in "+ + "reply", remoteMsg) + } + + switch createSessionReply.Code { + case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists: + + // TODO(conner): add last-applied to create session reply to + // handle case where we lose state, session already exists, and + // we want to possibly resume using the session + + // TODO(conner): validate reward address + rewardPkScript := createSessionReply.Data + + sessionID := wtdb.NewSessionIDFromPubKey( + privKey.PubKey(), + ) + clientSession := &wtdb.ClientSession{ + TowerID: tower.ID, + Tower: tower, + SessionPrivKey: privKey, // remove after using HD keys + ID: sessionID, + Policy: n.cfg.Policy, + SeqNum: 0, + RewardPkScript: rewardPkScript, + } + + err = n.cfg.DB.CreateClientSession(clientSession) + if err != nil { + return fmt.Errorf("unable to persist ClientSession: %v", + err) + } + + log.Debugf("New session negotiated with %s, policy: %s", + lnAddr, clientSession.Policy) + + // We have a newly negotiated session, return it to the + // dispatcher so that it can update how many outstanding + // negotiation requests we have. + select { + case n.successfulNegotiations <- clientSession: + return nil + case <-n.quit: + return ErrNegotiatorExiting + } + + // TODO(conner): handle error codes properly + case wtwire.CreateSessionCodeRejectBlobType: + return fmt.Errorf("tower rejected blob type: %v", + policy.BlobType) + + case wtwire.CreateSessionCodeRejectMaxUpdates: + return fmt.Errorf("tower rejected max updates: %v", + policy.MaxUpdates) + + case wtwire.CreateSessionCodeRejectRewardRate: + // The tower rejected the session because of the reward rate. If + // we didn't request a reward session, we'll treat this as a + // permanent tower failure. + if !policy.BlobType.Has(blob.FlagReward) { + return ErrPermanentTowerFailure + } + + return fmt.Errorf("tower rejected reward rate: %v", + policy.RewardRate) + + case wtwire.CreateSessionCodeRejectSweepFeeRate: + return fmt.Errorf("tower rejected sweep fee rate: %v", + policy.SweepFeeRate) + + default: + return fmt.Errorf("received unhandled error code: %v", + createSessionReply.Code) + } +} From 65d09fca6439c62ae0b73fdd4b7b28c119666a82 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:31:50 -0700 Subject: [PATCH 11/21] watchtower/wtclient/task_pipeline: add reliable task aggregator --- watchtower/wtclient/task_pipeline.go | 185 +++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 watchtower/wtclient/task_pipeline.go diff --git a/watchtower/wtclient/task_pipeline.go b/watchtower/wtclient/task_pipeline.go new file mode 100644 index 000000000..076c93f63 --- /dev/null +++ b/watchtower/wtclient/task_pipeline.go @@ -0,0 +1,185 @@ +package wtclient + +import ( + "container/list" + "sync" + "time" +) + +// taskPipeline implements a reliable, in-order queue that ensures its queue +// fully drained before exiting. Stopping the taskPipeline prevents the pipeline +// from accepting any further tasks, and will cause the pipeline to exit after +// all updates have been delivered to the downstream receiver. If this process +// hangs and is unable to make progress, users can optionally call ForceQuit to +// abandon the reliable draining of the queue in order to permit shutdown. +type taskPipeline struct { + started sync.Once + stopped sync.Once + forced sync.Once + + queueMtx sync.Mutex + queueCond *sync.Cond + queue *list.List + + newBackupTasks chan *backupTask + + quit chan struct{} + forceQuit chan struct{} + shutdown chan struct{} +} + +// newTaskPipeline initializes a new taskPipeline. +func newTaskPipeline() *taskPipeline { + rq := &taskPipeline{ + queue: list.New(), + newBackupTasks: make(chan *backupTask), + quit: make(chan struct{}), + forceQuit: make(chan struct{}), + shutdown: make(chan struct{}), + } + rq.queueCond = sync.NewCond(&rq.queueMtx) + + return rq +} + +// Start spins up the taskPipeline, making it eligible to begin receiving backup +// tasks and deliver them to the receiver of NewBackupTasks. +func (q *taskPipeline) Start() { + q.started.Do(func() { + go q.queueManager() + }) +} + +// Stop begins a graceful shutdown of the taskPipeline. This method returns once +// all backupTasks have been delivered via NewBackupTasks, or a ForceQuit causes +// the delivery of pending tasks to be interrupted. +func (q *taskPipeline) Stop() { + q.stopped.Do(func() { + log.Debugf("Stopping task pipeline") + + close(q.quit) + q.signalUntilShutdown() + + // Skip log if we also force quit. + select { + case <-q.forceQuit: + default: + log.Debugf("Task pipeline stopped successfully") + } + }) +} + +// ForceQuit signals the taskPipeline to immediately exit, dropping any +// backupTasks that have not been delivered via NewBackupTasks. +func (q *taskPipeline) ForceQuit() { + q.forced.Do(func() { + log.Infof("Force quitting task pipeline") + + close(q.forceQuit) + q.signalUntilShutdown() + + log.Infof("Task pipeline unclean shutdown complete") + }) +} + +// NewBackupTasks returns a read-only channel for enqueue backupTasks. The +// channel will be closed after a call to Stop and all pending tasks have been +// delivered, or if a call to ForceQuit is called before the pending entries +// have been drained. +func (q *taskPipeline) NewBackupTasks() <-chan *backupTask { + return q.newBackupTasks +} + +// QueueBackupTask enqueues a backupTask for reliable delivery to the consumer +// of NewBackupTasks. If the taskPipeline is shutting down, ErrClientExiting is +// returned. Otherwise, if QueueBackupTask returns nil it is guaranteed to be +// delivered via NewBackupTasks unless ForceQuit is called before completion. +func (q *taskPipeline) QueueBackupTask(task *backupTask) error { + q.queueCond.L.Lock() + select { + + // Reject new tasks after quit has been signaled. + case <-q.quit: + q.queueCond.L.Unlock() + return ErrClientExiting + + // Reject new tasks after force quit has been signaled. + case <-q.forceQuit: + q.queueCond.L.Unlock() + return ErrClientExiting + + default: + } + + // Queue the new task and signal the queue's condition variable to wake up + // the queueManager for processing. + q.queue.PushBack(task) + q.queueCond.L.Unlock() + + q.queueCond.Signal() + + return nil +} + +// queueManager processes all incoming backup requests that get added via +// QueueBackupTask. The manager will exit +// +// NOTE: This method MUST be run as a goroutine. +func (q *taskPipeline) queueManager() { + defer close(q.shutdown) + defer close(q.newBackupTasks) + + for { + q.queueCond.L.Lock() + for q.queue.Front() == nil { + q.queueCond.Wait() + + select { + case <-q.quit: + // Exit only after the queue has been fully drained. + if q.queue.Len() == 0 { + q.queueCond.L.Unlock() + log.Debugf("Revoked state pipeline flushed.") + return + } + + case <-q.forceQuit: + q.queueCond.L.Unlock() + log.Debugf("Revoked state pipeline force quit.") + return + + default: + } + } + + // Pop the first element from the queue. + e := q.queue.Front() + task := q.queue.Remove(e).(*backupTask) + q.queueCond.L.Unlock() + + select { + + // Backup task submitted to dispatcher. We don't select on quit to + // ensure that we still drain tasks while shutting down. + case q.newBackupTasks <- task: + + // Force quit, return immediately to allow the client to exit. + case <-q.forceQuit: + log.Debugf("Revoked state pipeline force quit.") + return + } + } +} + +// signalUntilShutdown strobes the queue's condition variable to ensure the +// queueManager reliably unblocks to check for the exit condition. +func (q *taskPipeline) signalUntilShutdown() { + for { + select { + case <-time.After(time.Millisecond): + q.queueCond.Signal() + case <-q.shutdown: + return + } + } +} From aa2b21117ce05673e8401a21d6a79bbba659ed37 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:32:02 -0700 Subject: [PATCH 12/21] watchtower/wtclient/session_queue: batch upload state updates --- watchtower/wtclient/session_queue.go | 688 +++++++++++++++++++++++++++ 1 file changed, 688 insertions(+) create mode 100644 watchtower/wtclient/session_queue.go diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go new file mode 100644 index 000000000..cc946811e --- /dev/null +++ b/watchtower/wtclient/session_queue.go @@ -0,0 +1,688 @@ +package wtclient + +import ( + "container/list" + "fmt" + "sort" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtserver" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// retryInterval is the default duration we will wait between attempting to +// connect back out to a tower if the prior state update failed. +const retryInterval = 2 * time.Second + +// reserveStatus is an enum that signals how full a particular session is. +type reserveStatus uint8 + +const ( + // reserveAvailable indicates that the session has space for at least + // one more backup. + reserveAvailable reserveStatus = iota + + // reserveExhausted indicates that all slots in the session have been + // allocated. + reserveExhausted +) + +// sessionQueueConfig bundles the resources required by the sessionQueue to +// perform its duties. All entries MUST be non-nil. +type sessionQueueConfig struct { + // ClientSession provides access to the negotiated session parameters + // and updating its persistent storage. + ClientSession *wtdb.ClientSession + + // ChainHash identifies the chain for which the session's justice + // transactions are targeted. + ChainHash chainhash.Hash + + // Dial allows the client to dial the tower using it's public key and + // net address. + Dial func(*btcec.PrivateKey, + *lnwire.NetAddress) (wtserver.Peer, error) + + // SendMessage encodes, encrypts, and writes a message to the given peer. + SendMessage func(wtserver.Peer, wtwire.Message) error + + // ReadMessage receives, decypts, and decodes a message from the given + // peer. + ReadMessage func(wtserver.Peer) (wtwire.Message, error) + + // Signer facilitates signing of inputs, used to construct the witnesses + // for justice transaction inputs. + Signer input.Signer + + // DB provides access to the client's stable storage. + DB DB + + // MinBackoff defines the initial backoff applied by the session + // queue before reconnecting to the tower after a failed or partially + // successful batch is sent. Subsequent backoff durations will grow + // exponentially up until MaxBackoff. + MinBackoff time.Duration + + // MaxBackoff defines the maximum backoff applied by the session + // queue before reconnecting to the tower after a failed or partially + // successful batch is sent. If the exponential backoff produces a + // timeout greater than this value, the backoff duration will be clamped + // to MaxBackoff. + MaxBackoff time.Duration +} + +// sessionQueue implements a reliable queue that will encrypt and send accepted +// backups to the watchtower specified in the config's ClientSession. Calling +// Quit will attempt to perform a clean shutdown by receiving an ACK from the +// tower for all pending backups before exiting. The clean shutdown can be +// aborted by using ForceQuit, which will attempt to shutdown the queue +// immediately. +type sessionQueue struct { + started sync.Once + stopped sync.Once + forced sync.Once + + cfg *sessionQueueConfig + + commitQueue *list.List + pendingQueue *list.List + queueMtx sync.Mutex + queueCond *sync.Cond + + localInit *wtwire.Init + towerAddr *lnwire.NetAddress + + seqNum uint16 + + retryBackoff time.Duration + + quit chan struct{} + forceQuit chan struct{} + shutdown chan struct{} +} + +// newSessionQueue intiializes a fresh sessionQueue. +func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { + localInit := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + cfg.ChainHash, + ) + + towerAddr := &lnwire.NetAddress{ + IdentityKey: cfg.ClientSession.Tower.IdentityKey, + Address: cfg.ClientSession.Tower.Addresses[0], + } + + sq := &sessionQueue{ + cfg: cfg, + commitQueue: list.New(), + pendingQueue: list.New(), + localInit: localInit, + towerAddr: towerAddr, + seqNum: cfg.ClientSession.SeqNum, + retryBackoff: cfg.MinBackoff, + quit: make(chan struct{}), + forceQuit: make(chan struct{}), + shutdown: make(chan struct{}), + } + sq.queueCond = sync.NewCond(&sq.queueMtx) + + sq.restoreCommittedUpdates() + + return sq +} + +// Start idempotently starts the sessionQueue so that it can begin accepting +// backups. +func (q *sessionQueue) Start() { + q.started.Do(func() { + // TODO(conner): load prior committed state updates from disk an + // populate in queue. + + go q.sessionManager() + }) +} + +// Stop idempotently stops the sessionQueue by initiating a clean shutdown that +// will clear all pending tasks in the queue before returning to the caller. +func (q *sessionQueue) Stop() { + q.stopped.Do(func() { + log.Debugf("Stopping session queue %s", q.ID()) + + close(q.quit) + q.signalUntilShutdown() + + // Skip log if we also force quit. + select { + case <-q.forceQuit: + return + default: + } + + log.Debugf("Session queue %s successfully stopped", q.ID()) + }) +} + +// ForceQuit idempotently aborts any clean shutdown in progress and returns to +// he caller after all lingering goroutines have spun down. +func (q *sessionQueue) ForceQuit() { + q.forced.Do(func() { + log.Infof("Force quitting session queue %s", q.ID()) + + close(q.forceQuit) + q.signalUntilShutdown() + + log.Infof("Session queue %s unclean shutdown complete", q.ID()) + }) +} + +// ID returns the wtdb.SessionID for the queue, which can be used to uniquely +// identify this a particular queue. +func (q *sessionQueue) ID() *wtdb.SessionID { + return &q.cfg.ClientSession.ID +} + +// AcceptTask attempts to queue a backupTask for delivery to the sessionQueue's +// tower. The session will only be accepted if the queue is not already +// exhausted and the task is successfully bound to the ClientSession. +func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) { + q.queueCond.L.Lock() + + // Examine the current reserve status of the session queue. + curStatus := q.reserveStatus() + switch curStatus { + + // The session queue is exhausted, and cannot accept the task because it + // is full. Reject the task such that it can be tried against a + // different session. + case reserveExhausted: + q.queueCond.L.Unlock() + return curStatus, false + + // The session queue is not exhausted. Compute the sweep and reward + // outputs as a function of the session parameters. If the outputs are + // dusty or uneconomical to backup, the task is rejected and will not be + // tried again. + // + // TODO(conner): queue backups and retry with different session params. + case reserveAvailable: + err := task.bindSession(q.cfg.ClientSession) + if err != nil { + q.queueCond.L.Unlock() + log.Debugf("SessionQueue %s rejected backup chanid=%s "+ + "commit-height=%d: %v", q.ID(), task.id.ChanID, + task.id.CommitHeight, err) + return curStatus, false + } + } + + // The sweep and reward outputs satisfy the session's policy, queue the + // task for final signing and delivery. + q.pendingQueue.PushBack(task) + + // Finally, compute the session's *new* reserve status. This will be + // used by the client to determine if it can continue using this session + // queue, or if it should negotiate a new one. + newStatus := q.reserveStatus() + q.queueCond.L.Unlock() + + q.queueCond.Signal() + + return newStatus, true +} + +// updateWithSeqNum stores a CommittedUpdate with its assigned sequence number. +// This allows committed updates to be sorted after a restart, and added to the +// commitQueue in the proper order for delivery. +type updateWithSeqNum struct { + seqNum uint16 + update *wtdb.CommittedUpdate +} + +// restoreCommittedUpdates processes any CommittedUpdates loaded on startup by +// sorting them in ascending order of sequence numbers and adding them to the +// commitQueue. These will be sent before any pending updates are processed. +func (q *sessionQueue) restoreCommittedUpdates() { + committedUpdates := q.cfg.ClientSession.CommittedUpdates + + // Construct and unordered slice of all committed updates with their + // assigned sequence numbers. + sortedUpdates := make([]updateWithSeqNum, 0, len(committedUpdates)) + for seqNum, update := range committedUpdates { + sortedUpdates = append(sortedUpdates, updateWithSeqNum{ + seqNum: seqNum, + update: update, + }) + } + + // Sort the resulting slice by increasing sequence number. + sort.Slice(sortedUpdates, func(i, j int) bool { + return sortedUpdates[i].seqNum < sortedUpdates[j].seqNum + }) + + // Finally, add the sorted, committed updates to he commitQueue. These + // updates will be prioritized before any new tasks are assigned to the + // sessionQueue. The queue will begin uploading any tasks in the + // commitQueue as soon as it is started, e.g. during client + // initialization when detecting that this session has unacked updates. + for _, update := range sortedUpdates { + q.commitQueue.PushBack(update) + } +} + +// sessionManager is the primary event loop for the sessionQueue, and is +// responsible for encrypting and sending accepted tasks to the tower. +func (q *sessionQueue) sessionManager() { + defer close(q.shutdown) + + for { + q.queueCond.L.Lock() + for q.commitQueue.Len() == 0 && + q.pendingQueue.Len() == 0 { + + q.queueCond.Wait() + + select { + case <-q.quit: + if q.commitQueue.Len() == 0 && + q.pendingQueue.Len() == 0 { + q.queueCond.L.Unlock() + return + } + case <-q.forceQuit: + q.queueCond.L.Unlock() + return + default: + } + } + q.queueCond.L.Unlock() + + // Exit immediately if a force quit has been requested. If the + // either of the queues still has state updates to send to the + // tower, we may never exit in the above case if we are unable + // to reach the tower for some reason. + select { + case <-q.forceQuit: + return + default: + } + + // Initiate a new connection to the watchtower and attempt to + // drain all pending tasks. + q.drainBackups() + } +} + +// drainBackups attempts to send all pending updates in the queue to the tower. +func (q *sessionQueue) drainBackups() { + // First, check that we are able to dial this session's tower. + conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionPrivKey, q.towerAddr) + if err != nil { + log.Errorf("Unable to dial watchtower at %v: %v", + q.towerAddr, err) + + q.increaseBackoff() + select { + case <-time.After(q.retryBackoff): + case <-q.forceQuit: + } + return + } + defer conn.Close() + + // Begin draining the queue of pending state updates. Before the first + // update is sent, we will precede it with an Init message. If the first + // is successful, subsequent updates can be streamed without sending an + // Init. + for sendInit := true; ; sendInit = false { + // Generate the next state update to upload to the tower. This + // method will first proceed in dequeueing committed updates + // before attempting to dequeue any pending updates. + stateUpdate, isPending, err := q.nextStateUpdate() + if err != nil { + log.Errorf("Unable to get next state update: %v", err) + return + } + + // Now, send the state update to the tower and wait for a reply. + err = q.sendStateUpdate( + conn, stateUpdate, q.localInit, sendInit, isPending, + ) + if err != nil { + log.Errorf("Unable to send state update: %v", err) + + q.increaseBackoff() + select { + case <-time.After(q.retryBackoff): + case <-q.forceQuit: + } + return + } + + // If the last task was backed up successfully, we'll exit and + // continue once more tasks are added to the queue. We'll also + // clear any accumulated backoff as this batch was able to be + // sent reliably. + if stateUpdate.IsComplete == 1 { + q.resetBackoff() + return + } + + // Always apply a small delay between sends, which makes the + // unit tests more reliable. If we were requested to back off, + // when we will do so. + select { + case <-time.After(time.Millisecond): + case <-q.forceQuit: + return + } + } +} + +// nextStateUpdate returns the next wtwire.StateUpdate to upload to the tower. +// If any committed updates are present, this method will reconstruct the state +// update from the committed update using the current last applied value found +// in the database. Otherwise, it will select the next pending update, craft the +// payload, and commit an update before returning the state update to send. The +// boolean value in the response is true if the state update is taken from the +// pending queue, allowing the caller to remove the update from either the +// commit or pending queue if the update is successfully acked. +func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, error) { + var ( + seqNum uint16 + update *wtdb.CommittedUpdate + isLast bool + isPending bool + ) + + q.queueCond.L.Lock() + switch { + + // If the commit queue is non-empty, parse the next committed update. + case q.commitQueue.Len() > 0: + next := q.commitQueue.Front() + updateWithSeq := next.Value.(updateWithSeqNum) + + seqNum = updateWithSeq.seqNum + update = updateWithSeq.update + + // If this is the last item in the commit queue and no items + // exist in the pending queue, we will use the IsComplete flag + // in the StateUpdate to signal that the tower can release the + // connection after replying to free up resources. + isLast = q.commitQueue.Len() == 1 && q.pendingQueue.Len() == 0 + q.queueCond.L.Unlock() + + log.Debugf("Reprocessing committed state update for "+ + "session=%s seqnum=%d", q.ID(), seqNum) + + // Otherwise, craft and commit the next update from the pending queue. + default: + isPending = true + + // Determine the current sequence number to apply for this + // pending update. + seqNum = q.seqNum + 1 + + // Obtain the next task from the queue. + next := q.pendingQueue.Front() + task := next.Value.(*backupTask) + + // If this is the last item in the pending queue, we will use + // the IsComplete flag in the StateUpdate to signal that the + // tower can release the connection after replying to free up + // resources. + isLast = q.pendingQueue.Len() == 1 + q.queueCond.L.Unlock() + + hint, encBlob, err := task.craftSessionPayload(q.cfg.Signer) + if err != nil { + // TODO(conner): mark will not send + return nil, false, fmt.Errorf("unable to craft "+ + "session payload: %v", err) + } + // TODO(conner): special case other obscure errors + + update = &wtdb.CommittedUpdate{ + BackupID: task.id, + Hint: hint, + EncryptedBlob: encBlob, + } + + log.Debugf("Committing state update for session=%s seqnum=%d", + q.ID(), seqNum) + } + + // Before sending the task to the tower, commit the state update + // to disk using the assigned sequence number. If this task has already + // been committed, the call will succeed and only be used for the + // purpose of obtaining the last applied value to send to the tower. + // + // This step ensures that if we crash before receiving an ack that we + // will retransmit the same update. If the tower successfully received + // the update from before, it will reply with an ACK regardless of what + // we send the next time. This step ensures that if we reliably send the + // same update for a given sequence number, to prevent us from thinking + // we backed up a state when we instead backed up another. + lastApplied, err := q.cfg.DB.CommitUpdate(q.ID(), seqNum, update) + if err != nil { + // TODO(conner): mark failed/reschedule + return nil, false, fmt.Errorf("unable to commit state update "+ + "for session=%s seqnum=%d: %v", q.ID(), seqNum, err) + } + + stateUpdate := &wtwire.StateUpdate{ + SeqNum: seqNum, + LastApplied: lastApplied, + Hint: update.Hint, + EncryptedBlob: update.EncryptedBlob, + } + + // Set the IsComplete flag if this is the last queued item. + if isLast { + stateUpdate.IsComplete = 1 + } + + return stateUpdate, isPending, nil +} + +// sendStateUpdate sends a wtwire.StateUpdate to the watchtower and processes +// the ACK before returning. If sendInit is true, this method will first send +// the localInit message and verify that the tower supports our required feature +// bits. And error is returned if any part of the send fails. The boolean return +// variable indicates whether or not we should back off before attempting to +// send the next state update. +func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, + stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init, + sendInit, isPending bool) error { + + // If this is the first message being sent to the tower, we must send an + // Init message to establish that server supports the features we + // require. + if sendInit { + // Send Init to tower. + err := q.cfg.SendMessage(conn, q.localInit) + if err != nil { + return err + } + + // Receive Init from tower. + remoteMsg, err := q.cfg.ReadMessage(conn) + if err != nil { + return err + } + + remoteInit, ok := remoteMsg.(*wtwire.Init) + if !ok { + return fmt.Errorf("watchtower responded with %T to "+ + "Init", remoteMsg) + } + + // Validate Init. + err = q.localInit.CheckRemoteInit( + remoteInit, wtwire.FeatureNames, + ) + if err != nil { + return err + } + } + + // Send StateUpdate to tower. + err := q.cfg.SendMessage(conn, stateUpdate) + if err != nil { + return err + } + + // Receive StateUpdate from tower. + remoteMsg, err := q.cfg.ReadMessage(conn) + if err != nil { + return err + } + + stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) + if !ok { + return fmt.Errorf("watchtower responded with %T to StateUpdate", + remoteMsg) + } + + // Process the reply from the tower. + switch stateUpdateReply.Code { + + // The tower reported a successful update, validate the response and + // record the last applied returned. + case wtwire.CodeOK: + + // TODO(conner): handle other error cases properly, ban towers, etc. + default: + err := fmt.Errorf("received error code %s in "+ + "StateUpdateReply from tower=%x session=%s", + stateUpdateReply.Code, + conn.RemotePub().SerializeCompressed(), q.ID()) + log.Warnf("Unable to upload state update: %v", err) + return err + } + + lastApplied := stateUpdateReply.LastApplied + err = q.cfg.DB.AckUpdate(q.ID(), stateUpdate.SeqNum, lastApplied) + switch { + case err == wtdb.ErrUnallocatedLastApplied: + // TODO(conner): borked watchtower + err = fmt.Errorf("unable to ack update=%d session=%s: %v", + stateUpdate.SeqNum, q.ID(), err) + log.Errorf("Failed to ack update: %v", err) + return err + + case err == wtdb.ErrLastAppliedReversion: + // TODO(conner): borked watchtower + err = fmt.Errorf("unable to ack update=%d session=%s: %v", + stateUpdate.SeqNum, q.ID(), err) + log.Errorf("Failed to ack update: %v", err) + return err + + case err != nil: + err = fmt.Errorf("unable to ack update=%d session=%s: %v", + stateUpdate.SeqNum, q.ID(), err) + log.Errorf("Failed to ack update: %v", err) + return err + } + + log.Infof("Removing update session=%s seqnum=%d is_pending=%v "+ + "from memory", q.ID(), stateUpdate.SeqNum, isPending) + + q.queueCond.L.Lock() + if isPending { + // If a pending update was successfully sent, increment the + // sequence number and remove the item from the queue. This + // ensures the total number of backups in the session remains + // unchanged, which maintains the external view of the session's + // reserve status. + q.seqNum++ + q.pendingQueue.Remove(q.pendingQueue.Front()) + } else { + // Otherwise, simply remove the update from the committed queue. + // This has no effect on the queues reserve status since the + // update had already been committed. + q.commitQueue.Remove(q.commitQueue.Front()) + } + q.queueCond.L.Unlock() + + return nil +} + +// reserveStatus returns a reserveStatus indicating whether or not the +// sessionQueue can accept another task. reserveAvailable is returned when a +// task can be accepted, and reserveExhausted is returned if the all slots in +// the session have been allocated. +// +// NOTE: This method MUST be called with queueCond's exclusive lock held. +func (q *sessionQueue) reserveStatus() reserveStatus { + numPending := uint32(q.pendingQueue.Len()) + maxUpdates := uint32(q.cfg.ClientSession.Policy.MaxUpdates) + + log.Debugf("SessionQueue %s reserveStatus seqnum=%d pending=%d "+ + "max-updates=%d", q.ID(), q.seqNum, numPending, maxUpdates) + + if uint32(q.seqNum)+numPending < maxUpdates { + return reserveAvailable + } + + return reserveExhausted + +} + +// resetBackoff returns the connection backoff the minimum configured backoff. +func (q *sessionQueue) resetBackoff() { + q.retryBackoff = q.cfg.MinBackoff +} + +// increaseBackoff doubles the current connection backoff, clamping to the +// configured maximum backoff if it would exceed the limit. +func (q *sessionQueue) increaseBackoff() { + q.retryBackoff *= 2 + if q.retryBackoff > q.cfg.MaxBackoff { + q.retryBackoff = q.cfg.MaxBackoff + } +} + +// signalUntilShutdown strobes the sessionQueue's condition variable until the +// main event loop exits. +func (q *sessionQueue) signalUntilShutdown() { + for { + select { + case <-time.After(time.Millisecond): + q.queueCond.Signal() + case <-q.shutdown: + return + } + } +} + +// sessionQueueSet maintains a mapping of SessionIDs to their corresponding +// sessionQueue. +type sessionQueueSet map[wtdb.SessionID]*sessionQueue + +// Add inserts a sessionQueue into the sessionQueueSet. +func (s *sessionQueueSet) Add(sessionQueue *sessionQueue) { + (*s)[*sessionQueue.ID()] = sessionQueue +} + +// ApplyAndWait executes the nil-adic function returned from getApply for each +// sessionQueue in the set in parallel, then waits for all of them to finish +// before returning to the caller. +func (s *sessionQueueSet) ApplyAndWait(getApply func(*sessionQueue) func()) { + var wg sync.WaitGroup + for _, sessionq := range *s { + wg.Add(1) + go func(sq *sessionQueue) { + defer wg.Done() + getApply(sq)() + }(sessionq) + } + wg.Wait() +} From abef9e09e70c2125ed4cfb722651b6aa2b96acc1 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:32:15 -0700 Subject: [PATCH 13/21] watchtower/wtclient/stats: adds clientStats --- watchtower/wtclient/stats.go | 51 ++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 watchtower/wtclient/stats.go diff --git a/watchtower/wtclient/stats.go b/watchtower/wtclient/stats.go new file mode 100644 index 000000000..069c1c971 --- /dev/null +++ b/watchtower/wtclient/stats.go @@ -0,0 +1,51 @@ +package wtclient + +import "fmt" + +type clientStats struct { + numTasksReceived int + numTasksAccepted int + numTasksIneligible int + numSessionsAcquired int + numSessionsExhausted int +} + +// taskReceived increments the number to backup requests the client has received +// from active channels. +func (s *clientStats) taskReceived() { + s.numTasksReceived++ +} + +// taskAccepted increments the number of tasks that have been assigned to active +// session queues, and are awaiting upload to a tower. +func (s *clientStats) taskAccepted() { + s.numTasksAccepted++ +} + +// taskIneligible increments the number of tasks that were unable to satisfy the +// active session queue's policy. These can potentially be retried later, but +// typically this means that the balance created dust outputs, so it may not be +// worth backing up at all. +func (s *clientStats) taskIneligible() { + s.numTasksIneligible++ +} + +// sessionAcquired increments the number of sessions that have been successfully +// negotiated by the client during this execution. +func (s *clientStats) sessionAcquired() { + s.numSessionsAcquired++ +} + +// sessionExhausted increments the number of session that have become full as a +// result of accepting backup tasks. +func (s *clientStats) sessionExhausted() { + s.numSessionsExhausted++ +} + +// String returns a human readable summary of the client's metrics. +func (s clientStats) String() string { + return fmt.Sprintf("tasks(received=%d accepted=%d ineligible=%d) "+ + "sessions(acquired=%d exhausted=%d)", s.numTasksReceived, + s.numTasksAccepted, s.numTasksIneligible, s.numSessionsAcquired, + s.numSessionsExhausted) +} From f00b4c5e960888686e6e6dd334da89716b489bf3 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:32:27 -0700 Subject: [PATCH 14/21] watchtower/wtclient/client: hook up full client pipeline --- watchtower/wtclient/client.go | 804 ++++++++++++++++++++++++++++++++++ 1 file changed, 804 insertions(+) create mode 100644 watchtower/wtclient/client.go diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go new file mode 100644 index 000000000..9bd8e9e8a --- /dev/null +++ b/watchtower/wtclient/client.go @@ -0,0 +1,804 @@ +package wtclient + +import ( + "bytes" + "fmt" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/lightningnetwork/lnd/watchtower/wtserver" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +const ( + // DefaultReadTimeout specifies the default duration we will wait during + // a read before breaking out of a blocking read. + DefaultReadTimeout = 15 * time.Second + + // DefaultWriteTimeout specifies the default duration we will wait during + // a write before breaking out of a blocking write. + DefaultWriteTimeout = 15 * time.Second + + // DefaultStatInterval specifies the default interval between logging + // metrics about the client's operation. + DefaultStatInterval = 30 * time.Second +) + +// Client is the primary interface used by the daemon to control a client's +// lifecycle and backup revoked states. +type Client interface { + // RegisterChannel persistently initializes any channel-dependent + // parameters within the client. This should be called during link + // startup to ensure that the client is able to support the link during + // operation. + RegisterChannel(lnwire.ChannelID) error + + // BackupState initiates a request to back up a particular revoked + // state. If the method returns nil, the backup is guaranteed to be + // successful unless the client is force quit, or the justice + // transaction would create dust outputs when trying to abide by the + // negotiated policy. + BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution) error + + // Start initializes the watchtower client, allowing it process requests + // to backup revoked channel states. + Start() error + + // Stop attempts a graceful shutdown of the watchtower client. In doing + // so, it will attempt to flush the pipeline and deliver any queued + // states to the tower before exiting. + Stop() error + + // ForceQuit will forcibly shutdown the watchtower client. Calling this + // may lead to queued states being dropped. + ForceQuit() +} + +// Config provides the TowerClient with access to the resources it requires to +// perform its duty. All nillable fields must be non-nil for the tower to be +// initialized properly. +type Config struct { + // Signer provides access to the wallet so that the client can sign + // justice transactions that spend from a remote party's commitment + // transaction. + Signer input.Signer + + // NewAddress generates a new on-chain sweep pkscript. + NewAddress func() ([]byte, error) + + // SecretKeyRing is used to derive the session keys used to communicate + // with the tower. The client only stores the KeyLocators internally so + // that we never store private keys on disk. + SecretKeyRing keychain.SecretKeyRing + + // Dial connects to an addr using the specified net and returns the + // connection object. + Dial Dial + + // AuthDialer establishes a brontide connection over an onion or clear + // network. + AuthDial AuthDialer + + // DB provides access to the client's stable storage medium. + DB DB + + // Policy is the session policy the client will propose when creating + // new sessions with the tower. If the policy differs from any active + // sessions recorded in the database, those sessions will be ignored and + // new sessions will be requested immediately. + Policy wtpolicy.Policy + + // PrivateTower is the net address of a private tower. The client will + // try to create all sessions with this tower. + PrivateTower *lnwire.NetAddress + + // ChainHash identifies the chain that the client is on and for which + // the tower must be watching to monitor for breaches. + ChainHash chainhash.Hash + + // ForceQuitDelay is the duration after attempting to shutdown that the + // client will automatically abort any pending backups if an unclean + // shutdown is detected. If the value is less than or equal to zero, a + // call to Stop may block indefinitely. The client can always be + // ForceQuit externally irrespective of the chosen parameter. + ForceQuitDelay time.Duration + + // ReadTimeout is the duration we will wait during a read before + // breaking out of a blocking read. If the value is less than or equal + // to zero, the default will be used instead. + ReadTimeout time.Duration + + // WriteTimeout is the duration we will wait during a write before + // breaking out of a blocking write. If the value is less than or equal + // to zero, the default will be used instead. + WriteTimeout time.Duration + + // MinBackoff defines the initial backoff applied to connections with + // watchtowers. Subsequent backoff durations will grow exponentially up + // until MaxBackoff. + MinBackoff time.Duration + + // MaxBackoff defines the maximum backoff applied to conenctions with + // watchtowers. If the exponential backoff produces a timeout greater + // than this value, the backoff will be clamped to MaxBackoff. + MaxBackoff time.Duration +} + +// TowerClient is a concrete implementation of the Client interface, offering a +// non-blocking, reliable subsystem for backing up revoked states to a specified +// private tower. +type TowerClient struct { + started sync.Once + stopped sync.Once + forced sync.Once + + cfg *Config + + pipeline *taskPipeline + + negotiator SessionNegotiator + candidateSessions map[wtdb.SessionID]*wtdb.ClientSession + activeSessions sessionQueueSet + + sessionQueue *sessionQueue + prevTask *backupTask + + sweepPkScriptMu sync.RWMutex + sweepPkScripts map[lnwire.ChannelID][]byte + + statTicker *time.Ticker + stats clientStats + + wg sync.WaitGroup + forceQuit chan struct{} +} + +// Compile-time constraint to ensure *TowerClient implements the Client +// interface. +var _ Client = (*TowerClient)(nil) + +// New initializes a new TowerClient from the provide Config. An error is +// returned if the client could not initialized. +func New(config *Config) (*TowerClient, error) { + // Copy the config to prevent side-effects from modifying both the + // internal and external version of the Config. + cfg := new(Config) + *cfg = *config + + // Set the read timeout to the default if none was provided. + if cfg.ReadTimeout <= 0 { + cfg.ReadTimeout = DefaultReadTimeout + } + + // Set the write timeout to the default if none was provided. + if cfg.WriteTimeout <= 0 { + cfg.WriteTimeout = DefaultWriteTimeout + } + + // Record the tower in our database, also loading any addresses + // previously associated with its public key. + tower, err := cfg.DB.CreateTower(cfg.PrivateTower) + if err != nil { + return nil, err + } + + log.Infof("Using private watchtower %s, offering policy %s", + cfg.PrivateTower, cfg.Policy) + + c := &TowerClient{ + cfg: cfg, + pipeline: newTaskPipeline(), + activeSessions: make(sessionQueueSet), + statTicker: time.NewTicker(DefaultStatInterval), + forceQuit: make(chan struct{}), + } + c.negotiator = newSessionNegotiator(&NegotiatorConfig{ + DB: cfg.DB, + Policy: cfg.Policy, + ChainHash: cfg.ChainHash, + SendMessage: c.sendMessage, + ReadMessage: c.readMessage, + Dial: c.dial, + Candidates: newTowerListIterator(tower), + MinBackoff: cfg.MinBackoff, + MaxBackoff: cfg.MaxBackoff, + }) + + // Next, load all active sessions from the db into the client. We will + // use any of these session if their policies match the current policy + // of the client, otherwise they will be ignored and new sessions will + // be requested. + c.candidateSessions, err = c.cfg.DB.ListClientSessions() + if err != nil { + return nil, err + } + + // Finally, load the sweep pkscripts that have been generated for all + // previously registered channels. + c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts() + if err != nil { + return nil, err + } + + return c, nil +} + +// Start initializes the watchtower client by loading or negotiating an active +// session and then begins processing backup tasks from the request pipeline. +func (c *TowerClient) Start() error { + var err error + c.started.Do(func() { + log.Infof("Starting watchtower client") + + // First, restart a session queue for any sessions that have + // committed but unacked state updates. This ensures that these + // sessions will be able to flush the committed updates after a + // restart. + for _, session := range c.candidateSessions { + if len(session.CommittedUpdates) > 0 { + log.Infof("Starting session=%s to process "+ + "%d committed backups", session.ID, + len(session.CommittedUpdates)) + c.initActiveQueue(session) + } + } + + // Now start the session negotiator, which will allow us to + // request new session as soon as the backupDispatcher starts + // up. + err = c.negotiator.Start() + if err != nil { + return + } + + // Start the task pipeline to which new backup tasks will be + // submitted from active links. + c.pipeline.Start() + + c.wg.Add(1) + go c.backupDispatcher() + + log.Infof("Watchtower client started successfully") + }) + return err +} + +// Stop idempotently initiates a graceful shutdown of the watchtower client. +func (c *TowerClient) Stop() error { + c.stopped.Do(func() { + log.Debugf("Stopping watchtower client") + + // 1. Shutdown the backup queue, which will prevent any further + // updates from being accepted. In practice, the links should be + // shutdown before the client has been stopped, so all updates + // would have been added prior. + c.pipeline.Stop() + + // 2. To ensure we don't hang forever on shutdown due to + // unintended failures, we'll delay a call to force quit the + // pipeline if a ForceQuitDelay is specified. This will have no + // effect if the pipeline shuts down cleanly before the delay + // fires. + // + // For full safety, this can be set to 0 and wait out + // indefinitely. However for mobile clients which may have a + // limited amount of time to exit before the background process + // is killed, this offers a way to ensure the process + // terminates. + if c.cfg.ForceQuitDelay > 0 { + time.AfterFunc(c.cfg.ForceQuitDelay, c.ForceQuit) + } + + // 3. Once the backup queue has shutdown, wait for the main + // dispatcher to exit. The backup queue will signal it's + // completion to the dispatcher, which releases the wait group + // after all tasks have been assigned to session queues. + c.wg.Wait() + + // 4. Since all valid tasks have been assigned to session + // queues, we no longer need to negotiate sessions. + c.negotiator.Stop() + + log.Debugf("Waiting for active session queues to finish "+ + "draining, stats: %s", c.stats) + + // 5. Shutdown all active session queues in parallel. These will + // exit once all updates have been acked by the watchtower. + c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() { + return s.Stop + }) + + // Skip log if force quitting. + select { + case <-c.forceQuit: + return + default: + } + + log.Debugf("Client successfully stopped, stats: %s", c.stats) + }) + return nil +} + +// ForceQuit idempotently initiates an unclean shutdown of the watchtower +// client. This should only be executed if Stop is unable to exit cleanly. +func (c *TowerClient) ForceQuit() { + c.forced.Do(func() { + log.Infof("Force quitting watchtower client") + + // Cancel log message from stop. + close(c.forceQuit) + + // 1. Shutdown the backup queue, which will prevent any further + // updates from being accepted. In practice, the links should be + // shutdown before the client has been stopped, so all updates + // would have been added prior. + c.pipeline.ForceQuit() + + // 2. Once the backup queue has shutdown, wait for the main + // dispatcher to exit. The backup queue will signal it's + // completion to the dispatcher, which releases the wait group + // after all tasks have been assigned to session queues. + c.wg.Wait() + + // 3. Since all valid tasks have been assigned to session + // queues, we no longer need to negotiate sessions. + c.negotiator.Stop() + + // 4. Force quit all active session queues in parallel. These + // will exit once all updates have been acked by the watchtower. + c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() { + return s.ForceQuit + }) + + log.Infof("Watchtower client unclean shutdown complete, "+ + "stats: %s", c.stats) + }) +} + +// RegisterChannel persistently initializes any channel-dependent parameters +// within the client. This should be called during link startup to ensure that +// the client is able to support the link during operation. +func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { + c.sweepPkScriptMu.Lock() + defer c.sweepPkScriptMu.Unlock() + + // If a pkscript for this channel already exists, the channel has been + // previously registered. + if _, ok := c.sweepPkScripts[chanID]; ok { + return nil + } + + // Otherwise, generate a new sweep pkscript used to sweep funds for this + // channel. + pkScript, err := c.cfg.NewAddress() + if err != nil { + return err + } + + // Persist the sweep pkscript so that restarts will not introduce + // address inflation when the channel is reregistered after a restart. + err = c.cfg.DB.AddChanPkScript(chanID, pkScript) + if err != nil { + return err + } + + // Finally, cache the pkscript in our in-memory cache to avoid db + // lookups for the remainder of the daemon's execution. + c.sweepPkScripts[chanID] = pkScript + + return nil +} + +// BackupState initiates a request to back up a particular revoked state. If the +// method returns nil, the backup is guaranteed to be successful unless the: +// - client is force quit, +// - justice transaction would create dust outputs when trying to abide by the +// negotiated policy, or +// - breached outputs contain too little value to sweep at the target sweep fee +// rate. +func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, + breachInfo *lnwallet.BreachRetribution) error { + + // Retrieve the cached sweep pkscript used for this channel. + c.sweepPkScriptMu.RLock() + sweepPkScript, ok := c.sweepPkScripts[*chanID] + c.sweepPkScriptMu.RUnlock() + if !ok { + return ErrUnregisteredChannel + } + + task := newBackupTask(chanID, breachInfo, sweepPkScript) + + return c.pipeline.QueueBackupTask(task) +} + +// nextSessionQueue attempts to fetch an active session from our set of +// candidate sessions. Candidate sessions with a differing policy from the +// active client's advertised policy will be ignored, but may be resumed if the +// client is restarted with a matching policy. If no candidates were found, nil +// is returned to signal that we need to request a new policy. +func (c *TowerClient) nextSessionQueue() *sessionQueue { + // Select any candidate session at random, and remove it from the set of + // candidate sessions. + var candidateSession *wtdb.ClientSession + for id, sessionInfo := range c.candidateSessions { + delete(c.candidateSessions, id) + + // Skip any sessions with policies that don't match the current + // configuration. These can be used again if the client changes + // their configuration back. + if sessionInfo.Policy != c.cfg.Policy { + continue + } + + candidateSession = sessionInfo + break + } + + // If none of the sessions could be used or none were found, we'll + // return nil to signal that we need another session to be negotiated. + if candidateSession == nil { + return nil + } + + // Initialize the session queue and spin it up so it can begin handling + // updates. If the queue was already made active on startup, this will + // simply return the existing session queue from the set. + return c.getOrInitActiveQueue(candidateSession) +} + +// backupDispatcher processes events coming from the taskPipeline and is +// responsible for detecting when the client needs to renegotiate a session to +// fulfill continuing demand. The event loop exits after all tasks have been +// received from the upstream taskPipeline, or the taskPipeline is force quit. +// +// NOTE: This method MUST be run as a goroutine. +func (c *TowerClient) backupDispatcher() { + defer c.wg.Done() + + log.Tracef("Starting backup dispatcher") + defer log.Tracef("Stopping backup dispatcher") + + for { + switch { + + // No active session queue and no additional sessions. + case c.sessionQueue == nil && len(c.candidateSessions) == 0: + log.Infof("Requesting new session.") + + // Immediately request a new session. + c.negotiator.RequestSession() + + // Wait until we receive the newly negotiated session. + // All backups sent in the meantime are queued in the + // revoke queue, as we cannot process them. + select { + case session := <-c.negotiator.NewSessions(): + log.Infof("Acquired new session with id=%s", + session.ID) + c.candidateSessions[session.ID] = session + c.stats.sessionAcquired() + + case <-c.statTicker.C: + log.Infof("Client stats: %s", c.stats) + } + + // No active session queue but have additional sessions. + case c.sessionQueue == nil && len(c.candidateSessions) > 0: + // We've exhausted the prior session, we'll pop another + // from the remaining sessions and continue processing + // backup tasks. + c.sessionQueue = c.nextSessionQueue() + if c.sessionQueue != nil { + log.Debugf("Loaded next candidate session "+ + "queue id=%s", c.sessionQueue.ID()) + } + + // Have active session queue, process backups. + case c.sessionQueue != nil: + if c.prevTask != nil { + c.processTask(c.prevTask) + + // Continue to ensure the sessionQueue is + // properly initialized before attempting to + // process more tasks from the pipeline. + continue + } + + // Normal operation where new tasks are read from the + // pipeline. + select { + + // If any sessions are negotiated while we have an + // active session queue, queue them for future use. + // This shouldn't happen with the current design, so + // it doesn't hurt to select here just in case. In the + // future, we will likely allow more asynchrony so that + // we can request new sessions before the session is + // fully empty, which this case would handle. + case session := <-c.negotiator.NewSessions(): + log.Warnf("Acquired new session with id=%s", + "while processing tasks", session.ID) + c.candidateSessions[session.ID] = session + c.stats.sessionAcquired() + + case <-c.statTicker.C: + log.Infof("Client stats: %s", c.stats) + + // Process each backup task serially from the queue of + // revoked states. + case task, ok := <-c.pipeline.NewBackupTasks(): + // All backups in the pipeline have been + // processed, it is now safe to exit. + if !ok { + return + } + + log.Debugf("Processing backup task chanid=%s "+ + "commit-height=%d", task.id.ChanID, + task.id.CommitHeight) + + c.stats.taskReceived() + c.processTask(task) + } + } + } +} + +// processTask attempts to schedule the given backupTask on the active +// sessionQueue. The task will either be accepted or rejected, afterwhich the +// appropriate modifications to the client's state machine will be made. After +// every invocation of processTask, the caller should ensure that the +// sessionQueue hasn't been exhausted before proceeding to the next task. Tasks +// that are rejected because the active sessionQueue is full will be cached as +// the prevTask, and should be reprocessed after obtaining a new sessionQueue. +func (c *TowerClient) processTask(task *backupTask) { + status, accepted := c.sessionQueue.AcceptTask(task) + if accepted { + c.taskAccepted(task, status) + } else { + c.taskRejected(task, status) + } +} + +// taskAccepted processes the acceptance of a task by a sessionQueue depending +// on the state the sessionQueue is in *after* the task is added. The client's +// prevTask is always removed as a result of this call. The client's +// sessionQueue will be removed if accepting the task left the sessionQueue in +// an exhausted state. +func (c *TowerClient) taskAccepted(task *backupTask, newStatus reserveStatus) { + log.Infof("Backup chanid=%s commit-height=%d accepted successfully", + task.id.ChanID, task.id.CommitHeight) + + c.stats.taskAccepted() + + // If this task was accepted, we discard anything held in the prevTask. + // Either it was nil before, or is the task which was just accepted. + c.prevTask = nil + + switch newStatus { + + // The sessionQueue still has capacity after accepting this task. + case reserveAvailable: + + // The sessionQueue is full after accepting this task, so we will need + // to request a new one before proceeding. + case reserveExhausted: + c.stats.sessionExhausted() + + log.Debugf("Session %s exhausted", c.sessionQueue.ID()) + + // This task left the session exhausted, set it to nil and + // proceed to the next loop so we can consume another + // pre-negotiated session or request another. + c.sessionQueue = nil + } +} + +// taskRejected process the rejection of a task by a sessionQueue depending on +// the state the was in *before* the task was rejected. The client's prevTask +// will cache the task if the sessionQueue was exhausted before hand, and nil +// the sessionQueue to find a new session. If the sessionQueue was not +// exhausted, the client marks the task as ineligible, as this implies we +// couldn't construct a valid justice transaction given the session's policy. +func (c *TowerClient) taskRejected(task *backupTask, curStatus reserveStatus) { + switch curStatus { + + // The sessionQueue has available capacity but the task was rejected, + // this indicates that the task was ineligible for backup. + case reserveAvailable: + c.stats.taskIneligible() + + log.Infof("Backup chanid=%s commit-height=%d is ineligible", + task.id.ChanID, task.id.CommitHeight) + + err := c.cfg.DB.MarkBackupIneligible( + task.id.ChanID, task.id.CommitHeight, + ) + if err != nil { + log.Errorf("Unable to mark task chanid=%s "+ + "commit-height=%d ineligible: %v", + task.id.ChanID, task.id.CommitHeight, err) + + // It is safe to not handle this error, even if we could + // not persist the result. At worst, this task may be + // reprocessed on a subsequent start up, and will either + // succeed do a change in session parameters or fail in + // the same manner. + } + + // If this task was rejected *and* the session had available + // capacity, we discard anything held in the prevTask. Either it + // was nil before, or is the task which was just rejected. + c.prevTask = nil + + // The sessionQueue rejected the task because it is full, we will stash + // this task and try to add it to the next available sessionQueue. + case reserveExhausted: + c.stats.sessionExhausted() + + log.Debugf("Session %s exhausted, backup chanid=%s "+ + "commit-height=%d queued for next session", + c.sessionQueue.ID(), task.id.ChanID, + task.id.CommitHeight) + + // Cache the task that we pulled off, so that we can process it + // once a new session queue is available. + c.sessionQueue = nil + c.prevTask = task + } +} + +// dial connects the peer at addr using privKey as our secret key for the +// connection. The connection will use the configured Net's resolver to resolve +// the address for either Tor or clear net connections. +func (c *TowerClient) dial(privKey *btcec.PrivateKey, + addr *lnwire.NetAddress) (wtserver.Peer, error) { + + return c.cfg.AuthDial(privKey, addr, c.cfg.Dial) +} + +// readMessage receives and parses the next message from the given Peer. An +// error is returned if a message is not received before the server's read +// timeout, the read off the wire failed, or the message could not be +// deserialized. +func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) { + // Set a read timeout to ensure we drop the connection if nothing is + // received in a timely manner. + err := peer.SetReadDeadline(time.Now().Add(c.cfg.ReadTimeout)) + if err != nil { + err = fmt.Errorf("unable to set read deadline: %v", err) + log.Errorf("Unable to read msg: %v", err) + return nil, err + } + + // Pull the next message off the wire, + rawMsg, err := peer.ReadNextMessage() + if err != nil { + err = fmt.Errorf("unable to read message: %v", err) + log.Errorf("Unable to read msg: %v", err) + return nil, err + } + + // Parse the received message according to the watchtower wire + // specification. + msgReader := bytes.NewReader(rawMsg) + msg, err := wtwire.ReadMessage(msgReader, 0) + if err != nil { + err = fmt.Errorf("unable to parse message: %v", err) + log.Errorf("Unable to read msg: %v", err) + return nil, err + } + + logMessage(peer, msg, true) + + return msg, nil +} + +// sendMessage sends a watchtower wire message to the target peer. +func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error { + // Encode the next wire message into the buffer. + // TODO(conner): use buffer pool + var b bytes.Buffer + _, err := wtwire.WriteMessage(&b, msg, 0) + if err != nil { + err = fmt.Errorf("Unable to encode msg: %v", err) + log.Errorf("Unable to send msg: %v", err) + return err + } + + // Set the write deadline for the connection, ensuring we drop the + // connection if nothing is sent in a timely manner. + err = peer.SetWriteDeadline(time.Now().Add(c.cfg.WriteTimeout)) + if err != nil { + err = fmt.Errorf("unable to set write deadline: %v", err) + log.Errorf("Unable to send msg: %v", err) + return err + } + + logMessage(peer, msg, false) + + // Write out the full message to the remote peer. + _, err = peer.Write(b.Bytes()) + if err != nil { + log.Errorf("Unable to send msg: %v", err) + } + return err +} + +// newSessionQueue creates a sessionQueue from a ClientSession loaded from the +// database and supplying it with the resources needed by the client. +func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { + return newSessionQueue(&sessionQueueConfig{ + ClientSession: s, + ChainHash: c.cfg.ChainHash, + Dial: c.dial, + ReadMessage: c.readMessage, + SendMessage: c.sendMessage, + Signer: c.cfg.Signer, + DB: c.cfg.DB, + MinBackoff: c.cfg.MinBackoff, + MaxBackoff: c.cfg.MaxBackoff, + }) +} + +// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the +// passed ClientSession. If it exists, the active sessionQueue is returned. +// Otherwise a new sessionQueue is initialized and added to the set. +func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue { + if sq, ok := c.activeSessions[s.ID]; ok { + return sq + } + + return c.initActiveQueue(s) +} + +// initActiveQueue creates a new sessionQueue from the passed ClientSession, +// adds the sessionQueue to the activeSessions set, and starts the sessionQueue +// so that it can deliver any committed updates or begin accepting newly +// assigned tasks. +func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue { + // Initialize the session queue, providing it with all of the resources + // it requires from the client instance. + sq := c.newSessionQueue(s) + + // Add the session queue as an active session so that we remember to + // stop it on shutdown. + c.activeSessions.Add(sq) + + // Start the queue so that it can be active in processing newly assigned + // tasks or to upload previously committed updates. + sq.Start() + + return sq +} + +// logMessage writes information about a message received from a remote peer, +// using directional prepositions to signal whether the message was sent or +// received. +func logMessage(peer wtserver.Peer, msg wtwire.Message, read bool) { + var action = "Received" + var preposition = "from" + if !read { + action = "Sending" + preposition = "to" + } + + summary := wtwire.MessageSummary(msg) + if len(summary) > 0 { + summary = "(" + summary + ")" + } + + log.Debugf("%s %s%v %s %x@%s", action, msg.MsgType(), summary, + preposition, peer.RemotePub().SerializeCompressed(), + peer.RemoteAddr()) +} From 87e8700c5d65b06688a20ff565399f0d81b33131 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:32:40 -0700 Subject: [PATCH 15/21] watchtower/wtmock/client_db: add mock client db --- watchtower/wtmock/client_db.go | 223 +++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 watchtower/wtmock/client_db.go diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go new file mode 100644 index 000000000..54f9a697e --- /dev/null +++ b/watchtower/wtmock/client_db.go @@ -0,0 +1,223 @@ +package wtmock + +import ( + "fmt" + "net" + "sync" + "sync/atomic" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +type towerPK [33]byte + +// ClientDB is a mock, in-memory database or testing the watchtower client +// behavior. +type ClientDB struct { + nextTowerID uint64 // to be used atomically + + mu sync.Mutex + sweepPkScripts map[lnwire.ChannelID][]byte + activeSessions map[wtdb.SessionID]*wtdb.ClientSession + towerIndex map[towerPK]uint64 + towers map[uint64]*wtdb.Tower +} + +// NewClientDB initializes a new mock ClientDB. +func NewClientDB() *ClientDB { + return &ClientDB{ + sweepPkScripts: make(map[lnwire.ChannelID][]byte), + activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), + towerIndex: make(map[towerPK]uint64), + towers: make(map[uint64]*wtdb.Tower), + } +} + +// CreateTower initializes a database entry with the given lightning address. If +// the tower exists, the address is append to the list of all addresses used to +// that tower previously. +func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var towerPubKey towerPK + copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) + + var tower *wtdb.Tower + towerID, ok := m.towerIndex[towerPubKey] + if ok { + tower = m.towers[towerID] + tower.AddAddress(lnAddr.Address) + } else { + towerID = atomic.AddUint64(&m.nextTowerID, 1) + tower = &wtdb.Tower{ + ID: towerID, + IdentityKey: lnAddr.IdentityKey, + Addresses: []net.Addr{lnAddr.Address}, + } + } + + m.towerIndex[towerPubKey] = towerID + m.towers[towerID] = tower + + return tower, nil +} + +// MarkBackupIneligible records that particular commit height is ineligible for +// backup. This allows the client to track which updates it should not attempt +// to retry after startup. +func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { + return nil +} + +// ListClientSessions returns the set of all client sessions known to the db. +func (m *ClientDB) ListClientSessions() (map[wtdb.SessionID]*wtdb.ClientSession, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) + for _, session := range m.activeSessions { + sessions[session.ID] = session + } + + return sessions, nil +} + +// CreateClientSession records a newly negotiated client session in the set of +// active sessions. The session can be identified by its SessionID. +func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.activeSessions[session.ID] = &wtdb.ClientSession{ + TowerID: session.TowerID, + Tower: session.Tower, + SessionKeyDesc: session.SessionKeyDesc, + SessionPrivKey: session.SessionPrivKey, + ID: session.ID, + Policy: session.Policy, + SeqNum: session.SeqNum, + TowerLastApplied: session.TowerLastApplied, + RewardPkScript: session.RewardPkScript, + CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate), + AckedUpdates: make(map[uint16]wtdb.BackupID), + } + + return nil +} + +// CommitUpdate persists the CommittedUpdate provided in the slot for (session, +// seqNum). This allows the client to retransmit this update on startup. +func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16, + update *wtdb.CommittedUpdate) (uint16, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + // Fail if session doesn't exist. + session, ok := m.activeSessions[*id] + if !ok { + return 0, wtdb.ErrClientSessionNotFound + } + + // Check if an update has already been committed for this state. + dbUpdate, ok := session.CommittedUpdates[seqNum] + if ok { + // If the breach hint matches, we'll just return the last + // applied value so the client can retransmit. + if dbUpdate.Hint == update.Hint { + return session.TowerLastApplied, nil + } + + // Otherwise, fail since the breach hint doesn't match. + return 0, wtdb.ErrUpdateAlreadyCommitted + } + + // Sequence number must increment. + if seqNum != session.SeqNum+1 { + return 0, wtdb.ErrCommitUnorderedUpdate + } + + // Save the update and increment the sequence number. + session.CommittedUpdates[seqNum] = update + session.SeqNum++ + + return session.TowerLastApplied, nil +} + +// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This +// removes the update from the set of committed updates, and validates the +// lastApplied value returned from the tower. +func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Fail if session doesn't exist. + session, ok := m.activeSessions[*id] + if !ok { + return wtdb.ErrClientSessionNotFound + } + + // Retrieve the committed update, failing if none is found. We should + // only receive acks for state updates that we send. + update, ok := session.CommittedUpdates[seqNum] + if !ok { + return wtdb.ErrCommittedUpdateNotFound + } + + // Ensure the returned last applied value does not exceed the highest + // allocated sequence number. + if lastApplied > session.SeqNum { + return wtdb.ErrUnallocatedLastApplied + } + + // Ensure the last applied value isn't lower than a previous one sent by + // the tower. + if lastApplied < session.TowerLastApplied { + return wtdb.ErrLastAppliedReversion + } + + // Finally, remove the committed update from disk and mark the update as + // acked. The tower last applied value is also recorded to send along + // with the next update. + delete(session.CommittedUpdates, seqNum) + session.AckedUpdates[seqNum] = update.BackupID + session.TowerLastApplied = lastApplied + + return nil +} + +// FetchChanPkScripts returns the set of sweep pkscripts known for all channels. +// This allows the client to cache them in memory on startup. +func (m *ClientDB) FetchChanPkScripts() (map[lnwire.ChannelID][]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sweepPkScripts := make(map[lnwire.ChannelID][]byte) + for chanID, pkScript := range m.sweepPkScripts { + sweepPkScripts[chanID] = cloneBytes(pkScript) + } + + return sweepPkScripts, nil +} + +// AddChanPkScript sets a pkscript or sweeping funds from the channel or chanID. +func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.sweepPkScripts[chanID]; ok { + return fmt.Errorf("pkscript for %x already exists", pkScript) + } + + m.sweepPkScripts[chanID] = cloneBytes(pkScript) + + return nil +} + +func cloneBytes(b []byte) []byte { + bb := make([]byte, len(b)) + copy(bb, b) + return bb +} From 81497eceafa417d5164cbd627ac731ee643fbde9 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:32:53 -0700 Subject: [PATCH 16/21] watchtower/wtmock/peer: create mock net.Conn using bidi MockPeer --- watchtower/wtmock/peer.go | 72 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/watchtower/wtmock/peer.go b/watchtower/wtmock/peer.go index 9103c4990..16aeb0d8c 100644 --- a/watchtower/wtmock/peer.go +++ b/watchtower/wtmock/peer.go @@ -13,6 +13,8 @@ import ( type MockPeer struct { remotePub *btcec.PublicKey remoteAddr net.Addr + localPub *btcec.PublicKey + localAddr net.Addr IncomingMsgs chan []byte OutgoingMsgs chan []byte @@ -20,30 +22,71 @@ type MockPeer struct { writeDeadline <-chan time.Time readDeadline <-chan time.Time - Quit chan struct{} + RemoteQuit chan struct{} + Quit chan struct{} } // NewMockPeer returns a fresh MockPeer. func NewMockPeer(pk *btcec.PublicKey, addr net.Addr, bufferSize int) *MockPeer { return &MockPeer{ - remotePub: pk, - remoteAddr: addr, + remotePub: pk, + remoteAddr: addr, + localAddr: &net.TCPAddr{ + IP: net.IP{0x32, 0x31, 0x30, 0x29}, + Port: 36723, + }, IncomingMsgs: make(chan []byte, bufferSize), OutgoingMsgs: make(chan []byte, bufferSize), Quit: make(chan struct{}), } } +// NewMockConn establishes a bidirectional connection between two MockPeers. +func NewMockConn(localPk, remotePk *btcec.PublicKey, + localAddr, remoteAddr net.Addr, + bufferSize int) (*MockPeer, *MockPeer) { + + localPeer := &MockPeer{ + remotePub: remotePk, + remoteAddr: remoteAddr, + localPub: localPk, + localAddr: localAddr, + IncomingMsgs: make(chan []byte, bufferSize), + OutgoingMsgs: make(chan []byte, bufferSize), + Quit: make(chan struct{}), + } + + remotePeer := &MockPeer{ + remotePub: localPk, + remoteAddr: localAddr, + localPub: remotePk, + localAddr: remoteAddr, + IncomingMsgs: localPeer.OutgoingMsgs, + OutgoingMsgs: localPeer.IncomingMsgs, + Quit: make(chan struct{}), + } + + localPeer.RemoteQuit = remotePeer.Quit + remotePeer.RemoteQuit = localPeer.Quit + + return localPeer, remotePeer +} + // Write sends the raw bytes as the next full message read to the remote peer. // The write will fail if either party closes the connection or the write // deadline expires. The passed bytes slice is copied before sending, thus the // bytes may be reused once the method returns. func (p *MockPeer) Write(b []byte) (n int, err error) { + bb := make([]byte, len(b)) + copy(bb, b) + select { - case p.OutgoingMsgs <- b: + case p.OutgoingMsgs <- bb: return len(b), nil case <-p.writeDeadline: return 0, fmt.Errorf("write timeout expired") + case <-p.RemoteQuit: + return 0, fmt.Errorf("remote closed connected") case <-p.Quit: return 0, fmt.Errorf("connection closed") } @@ -69,6 +112,8 @@ func (p *MockPeer) ReadNextMessage() ([]byte, error) { return b, nil case <-p.readDeadline: return nil, fmt.Errorf("read timeout expired") + case <-p.RemoteQuit: + return nil, fmt.Errorf("remote closed connected") case <-p.Quit: return nil, fmt.Errorf("connection closed") } @@ -112,6 +157,25 @@ func (p *MockPeer) RemoteAddr() net.Addr { return p.remoteAddr } +// LocalAddr returns the local net address of the peer. +func (p *MockPeer) LocalAddr() net.Addr { + return p.localAddr +} + +// Read is not implemented. +func (p *MockPeer) Read(dst []byte) (int, error) { + panic("not implemented") +} + +// SetDeadline is not implemented. +func (p *MockPeer) SetDeadline(t time.Time) error { + panic("not implemented") +} + // Compile-time constraint ensuring the MockPeer implements the wserver.Peer // interface. var _ wtserver.Peer = (*MockPeer)(nil) + +// Compile-time constraint ensuring the MockPeer implements the net.Conn +// interface. +var _ net.Conn = (*MockPeer)(nil) From 8b0cc487f0a4c3244afde62adb62e49eaee9e9b1 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:33:05 -0700 Subject: [PATCH 17/21] watchtower/wtdb+wtserver: allow retransmission of last update --- watchtower/wtdb/session_info.go | 2 +- watchtower/wtserver/server_test.go | 36 ++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/watchtower/wtdb/session_info.go b/watchtower/wtdb/session_info.go index f1b2e2a81..1c7b7f0ff 100644 --- a/watchtower/wtdb/session_info.go +++ b/watchtower/wtdb/session_info.go @@ -82,7 +82,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error { return ErrSessionConsumed // Client update does not match our expected next seqnum. - case seqNum != s.LastApplied+1: + case seqNum != s.LastApplied && seqNum != s.LastApplied+1: return ErrUpdateOutOfOrder } diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 4a6ee27eb..6df7cd58d 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -418,8 +418,8 @@ var stateUpdateTests = []stateUpdateTestCase{ {Code: wtwire.CodeOK, LastApplied: 4}, }, }, - // Valid update sequence with disconnection, ensure resumes resume. - // Client doesn't echo last applied until last message. + // Valid update sequence with disconnection, resume next update. Client + // doesn't echo last applied until last message. { name: "resume after disconnect lagging lastapplied", initMsg: wtwire.NewInitMessage( @@ -448,6 +448,38 @@ var stateUpdateTests = []stateUpdateTestCase{ {Code: wtwire.CodeOK, LastApplied: 4}, }, }, + // Valid update sequence with disconnection, resume last update. Client + // doesn't echo last applied until last message. + { + name: "resume after disconnect lagging lastapplied", + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), + createMsg: &wtwire.CreateSession{ + BlobType: blob.TypeDefault, + MaxUpdates: 4, + RewardBase: 0, + RewardRate: 0, + SweepFeeRate: 1, + }, + updates: []*wtwire.StateUpdate{ + {SeqNum: 1, LastApplied: 0}, + {SeqNum: 2, LastApplied: 0}, + nil, // Wait for read timeout to drop conn, then reconnect. + {SeqNum: 2, LastApplied: 0}, + {SeqNum: 3, LastApplied: 0}, + {SeqNum: 4, LastApplied: 3}, + }, + replies: []*wtwire.StateUpdateReply{ + {Code: wtwire.CodeOK, LastApplied: 1}, + {Code: wtwire.CodeOK, LastApplied: 2}, + nil, + {Code: wtwire.CodeOK, LastApplied: 2}, + {Code: wtwire.CodeOK, LastApplied: 3}, + {Code: wtwire.CodeOK, LastApplied: 4}, + }, + }, // Send update with sequence number that exceeds MaxUpdates. { name: "seqnum exceed maxupdates", From a222a63d818668672542df975f79ff96c3c76760 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:33:20 -0700 Subject: [PATCH 18/21] watchtower/wtserver/server: no ack updates --- watchtower/wtserver/server.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index 98848fbc1..dd39af98c 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -55,6 +55,10 @@ type Config struct { // ChainHash identifies the network that the server is watching. ChainHash chainhash.Hash + + // NoAckUpdates causes the server to not acknowledge state updates, this + // should only be used for testing. + NoAckUpdates bool } // Server houses the state required to handle watchtower peers. It's primary job @@ -445,6 +449,13 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID, failCode = wtwire.CodeTemporaryFailure } + if s.cfg.NoAckUpdates { + return &connFailure{ + ID: *id, + Code: uint16(failCode), + } + } + return s.replyStateUpdate( peer, id, failCode, lastApplied, ) From e1e805d1b81228d1fa1bafb3b2a8682bc8ab3b7a Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:33:33 -0700 Subject: [PATCH 19/21] watchtower/wtserver/server: fix race condition on Stop --- watchtower/wtserver/server.go | 81 ++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index dd39af98c..b9d222757 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "sync" - "sync/atomic" "time" "github.com/btcsuite/btcd/btcec" @@ -65,8 +64,8 @@ type Config struct { // is to accept incoming connections, and dispatch processing of the client // message streams. type Server struct { - started int32 // atomic - shutdown int32 // atomic + started sync.Once + stopped sync.Once cfg *Config @@ -75,6 +74,8 @@ type Server struct { clientMtx sync.RWMutex clients map[wtdb.SessionID]Peer + newPeers chan Peer + localInit *wtwire.Init wg sync.WaitGroup @@ -93,6 +94,7 @@ func New(cfg *Config) (*Server, error) { s := &Server{ cfg: cfg, clients: make(map[wtdb.SessionID]Peer), + newPeers: make(chan Peer), localInit: localInit, quit: make(chan struct{}), } @@ -113,36 +115,31 @@ func New(cfg *Config) (*Server, error) { // Start begins listening on the server's listeners. func (s *Server) Start() error { - // Already running? - if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { - return nil - } + s.started.Do(func() { + log.Infof("Starting watchtower server") - log.Infof("Starting watchtower server") + s.wg.Add(1) + go s.peerHandler() - s.connMgr.Start() - - log.Infof("Watchtower server started successfully") + s.connMgr.Start() + log.Infof("Watchtower server started successfully") + }) return nil } // Stop shutdowns down the server's listeners and any active requests. func (s *Server) Stop() error { - // Bail if we're already shutting down. - if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { - return nil - } + s.stopped.Do(func() { + log.Infof("Stopping watchtower server") - log.Infof("Stopping watchtower server") + s.connMgr.Stop() - s.connMgr.Stop() - - close(s.quit) - s.wg.Wait() - - log.Infof("Watchtower server stopped successfully") + close(s.quit) + s.wg.Wait() + log.Infof("Watchtower server stopped successfully") + }) return nil } @@ -167,8 +164,29 @@ func (s *Server) inboundPeerConnected(c net.Conn) { // by the client. This method serves also as a public endpoint for locally // registering new clients with the server. func (s *Server) InboundPeerConnected(peer Peer) { - s.wg.Add(1) - go s.handleClient(peer) + select { + case s.newPeers <- peer: + case <-s.quit: + } +} + +// peerHandler processes newly accepted peers and spawns a client handler for +// each. The peerHandler is used to ensure that waitgrouped client handlers are +// spawned from a waitgrouped goroutine. +func (s *Server) peerHandler() { + defer s.wg.Done() + defer s.removeAllPeers() + + for { + select { + case peer := <-s.newPeers: + s.wg.Add(1) + go s.handleClient(peer) + + case <-s.quit: + return + } + } } // handleClient processes a series watchtower messages sent by a client. The @@ -625,6 +643,21 @@ func (s *Server) removePeer(id *wtdb.SessionID, addr net.Addr) { } } +// removeAllPeers iterates through the server's current set of peers and closes +// all open connections. +func (s *Server) removeAllPeers() { + s.clientMtx.Lock() + defer s.clientMtx.Unlock() + + for id, peer := range s.clients { + log.Infof("Releasing incoming peer %s@%s", id, + peer.RemoteAddr()) + + delete(s.clients, id) + peer.Close() + } +} + // logMessage writes information about a message exchanged with a remote peer, // using directional prepositions to signal whether the message was sent or // received. From 80040d9d961db9c2d40d3bd10d70792141228e75 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:33:47 -0700 Subject: [PATCH 20/21] watchtower/wtclient/client_test: adds client-server upload test --- watchtower/wtclient/client_test.go | 1118 ++++++++++++++++++++++++++++ 1 file changed, 1118 insertions(+) create mode 100644 watchtower/wtclient/client_test.go diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go new file mode 100644 index 000000000..dba4275e6 --- /dev/null +++ b/watchtower/wtclient/client_test.go @@ -0,0 +1,1118 @@ +// +build dev + +package wtclient_test + +import ( + "encoding/binary" + "net" + "sync" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtclient" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtmock" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/lightningnetwork/lnd/watchtower/wtserver" +) + +const csvDelay uint32 = 144 + +var ( + revPrivBytes = []byte{ + 0x8f, 0x4b, 0x51, 0x83, 0xa9, 0x34, 0xbd, 0x5f, + 0x74, 0x6c, 0x9d, 0x5c, 0xae, 0x88, 0x2d, 0x31, + 0x06, 0x90, 0xdd, 0x8c, 0x9b, 0x31, 0xbc, 0xd1, + 0x78, 0x91, 0x88, 0x2a, 0xf9, 0x74, 0xa0, 0xef, + } + + toLocalPrivBytes = []byte{ + 0xde, 0x17, 0xc1, 0x2f, 0xdc, 0x1b, 0xc0, 0xc6, + 0x59, 0x5d, 0xf9, 0xc1, 0x3e, 0x89, 0xbc, 0x6f, + 0x01, 0x85, 0x45, 0x76, 0x26, 0xce, 0x9c, 0x55, + 0x3b, 0xc9, 0xec, 0x3d, 0xd8, 0x8b, 0xac, 0xa8, + } + + toRemotePrivBytes = []byte{ + 0x28, 0x59, 0x6f, 0x36, 0xb8, 0x9f, 0x19, 0x5d, + 0xcb, 0x07, 0x48, 0x8a, 0xe5, 0x89, 0x71, 0x74, + 0x70, 0x4c, 0xff, 0x1e, 0x9c, 0x00, 0x93, 0xbe, + 0xe2, 0x2e, 0x68, 0x08, 0x4c, 0xb4, 0x0f, 0x4f, + } + + // addr is the server's reward address given to watchtower clients. + addr, _ = btcutil.DecodeAddress( + "mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params, + ) + + addrScript, _ = txscript.PayToAddrScript(addr) +) + +// randPrivKey generates a new secp keypair, and returns the public key. +func randPrivKey(t *testing.T) *btcec.PrivateKey { + t.Helper() + + sk, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + } + + return sk +} + +type mockNet struct { + mu sync.RWMutex + connCallback func(wtserver.Peer) +} + +func newMockNet(cb func(wtserver.Peer)) *mockNet { + return &mockNet{ + connCallback: cb, + } +} + +func (m *mockNet) Dial(network string, address string) (net.Conn, error) { + return nil, nil +} + +func (m *mockNet) LookupHost(host string) ([]string, error) { + panic("not implemented") +} + +func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) { + panic("not implemented") +} + +func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) { + panic("not implemented") +} + +func (m *mockNet) AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, + dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) { + + localPk := localPriv.PubKey() + localAddr := &net.TCPAddr{ + IP: net.IP{0x32, 0x31, 0x30, 0x29}, + Port: 36723, + } + + localPeer, remotePeer := wtmock.NewMockConn( + localPk, netAddr.IdentityKey, localAddr, netAddr.Address, 0, + ) + + m.mu.RLock() + m.connCallback(remotePeer) + m.mu.RUnlock() + + return localPeer, nil +} + +func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) { + m.mu.Lock() + defer m.mu.Unlock() + m.connCallback = cb +} + +type mockChannel struct { + mu sync.Mutex + commitHeight uint64 + retributions map[uint64]*lnwallet.BreachRetribution + localBalance lnwire.MilliSatoshi + remoteBalance lnwire.MilliSatoshi + + revSK *btcec.PrivateKey + revPK *btcec.PublicKey + revKeyLoc keychain.KeyLocator + + toRemoteSK *btcec.PrivateKey + toRemotePK *btcec.PublicKey + toRemoteKeyLoc keychain.KeyLocator + + toLocalPK *btcec.PublicKey // only need to generate to-local script + + dustLimit lnwire.MilliSatoshi + csvDelay uint32 +} + +func newMockChannel(t *testing.T, signer *wtmock.MockSigner, + localAmt, remoteAmt lnwire.MilliSatoshi) *mockChannel { + + // Generate the revocation, to-local, and to-remote keypairs. + revSK := randPrivKey(t) + revPK := revSK.PubKey() + + toLocalSK := randPrivKey(t) + toLocalPK := toLocalSK.PubKey() + + toRemoteSK := randPrivKey(t) + toRemotePK := toRemoteSK.PubKey() + + // Register the revocation secret key and the to-remote secret key with + // the signer. We will not need to sign with the to-local key, as this + // is to be known only by the counterparty. + revKeyLoc := signer.AddPrivKey(revSK) + toRemoteKeyLoc := signer.AddPrivKey(toRemoteSK) + + c := &mockChannel{ + retributions: make(map[uint64]*lnwallet.BreachRetribution), + localBalance: localAmt, + remoteBalance: remoteAmt, + revSK: revSK, + revPK: revPK, + revKeyLoc: revKeyLoc, + toLocalPK: toLocalPK, + toRemoteSK: toRemoteSK, + toRemotePK: toRemotePK, + toRemoteKeyLoc: toRemoteKeyLoc, + dustLimit: 546000, + csvDelay: 144, + } + + // Create the initial remote commitment with the initial balances. + c.createRemoteCommitTx(t) + + return c +} + +func (c *mockChannel) createRemoteCommitTx(t *testing.T) { + t.Helper() + + // Construct the to-local witness script. + toLocalScript, err := input.CommitScriptToSelf( + c.csvDelay, c.toLocalPK, c.revPK, + ) + if err != nil { + t.Fatalf("unable to create to-local script: %v", err) + } + + // Compute the to-local witness script hash. + toLocalScriptHash, err := input.WitnessScriptHash(toLocalScript) + if err != nil { + t.Fatalf("unable to create to-local witness script hash: %v", err) + } + + // Compute the to-remote witness script hash. + toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK) + if err != nil { + t.Fatalf("unable to create to-remote script: %v", err) + } + + // Construct the remote commitment txn, containing the to-local and + // to-remote outputs. The balances are flipped since the transaction is + // from the PoV of the remote party. We don't need any inputs for this + // test. We increment the version with the commit height to ensure that + // all commitment transactions are unique even if the same distribution + // of funds is used more than once. + commitTxn := &wire.MsgTx{ + Version: int32(c.commitHeight + 1), + } + + var ( + toLocalSignDesc *input.SignDescriptor + toRemoteSignDesc *input.SignDescriptor + ) + + var outputIndex int + if c.remoteBalance >= c.dustLimit { + commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{ + Value: int64(c.remoteBalance.ToSatoshis()), + PkScript: toLocalScriptHash, + }) + + // Create the sign descriptor used to sign for the to-local + // input. + toLocalSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: c.revKeyLoc, + PubKey: c.revPK, + }, + WitnessScript: toLocalScript, + Output: commitTxn.TxOut[outputIndex], + HashType: txscript.SigHashAll, + } + outputIndex++ + } + if c.localBalance >= c.dustLimit { + commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{ + Value: int64(c.localBalance.ToSatoshis()), + PkScript: toRemoteScriptHash, + }) + + // Create the sign descriptor used to sign for the to-remote + // input. + toRemoteSignDesc = &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + KeyLocator: c.toRemoteKeyLoc, + PubKey: c.toRemotePK, + }, + WitnessScript: toRemoteScriptHash, + Output: commitTxn.TxOut[outputIndex], + HashType: txscript.SigHashAll, + } + outputIndex++ + } + + txid := commitTxn.TxHash() + + var ( + toLocalOutPoint wire.OutPoint + toRemoteOutPoint wire.OutPoint + ) + + outputIndex = 0 + if toLocalSignDesc != nil { + toLocalOutPoint = wire.OutPoint{ + Hash: txid, + Index: uint32(outputIndex), + } + outputIndex++ + } + if toRemoteSignDesc != nil { + toRemoteOutPoint = wire.OutPoint{ + Hash: txid, + Index: uint32(outputIndex), + } + outputIndex++ + } + + commitKeyRing := &lnwallet.CommitmentKeyRing{ + RevocationKey: c.revPK, + NoDelayKey: c.toLocalPK, + DelayKey: c.toRemotePK, + } + + retribution := &lnwallet.BreachRetribution{ + BreachTransaction: commitTxn, + RevokedStateNum: c.commitHeight, + KeyRing: commitKeyRing, + RemoteDelay: c.csvDelay, + LocalOutpoint: toRemoteOutPoint, + LocalOutputSignDesc: toRemoteSignDesc, + RemoteOutpoint: toLocalOutPoint, + RemoteOutputSignDesc: toLocalSignDesc, + } + + c.retributions[c.commitHeight] = retribution + c.commitHeight++ +} + +// advanceState creates the next channel state and retribution without altering +// channel balances. +func (c *mockChannel) advanceState(t *testing.T) { + c.mu.Lock() + defer c.mu.Unlock() + + c.createRemoteCommitTx(t) +} + +// sendPayment creates the next channel state and retribution after transferring +// amt to the remote party. +func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliSatoshi) { + t.Helper() + + c.mu.Lock() + defer c.mu.Unlock() + + if c.localBalance < amt { + t.Fatalf("insufficient funds to send, need: %v, have: %v", + amt, c.localBalance) + } + + c.localBalance -= amt + c.remoteBalance += amt + c.createRemoteCommitTx(t) +} + +// receivePayment creates the next channel state and retribution after +// transferring amt to the local party. +func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliSatoshi) { + t.Helper() + + c.mu.Lock() + defer c.mu.Unlock() + + if c.remoteBalance < amt { + t.Fatalf("insufficient funds to recv, need: %v, have: %v", + amt, c.remoteBalance) + } + + c.localBalance += amt + c.remoteBalance -= amt + c.createRemoteCommitTx(t) +} + +// getState retrieves the channel's commitment and retribution at state i. +func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetribution) { + c.mu.Lock() + defer c.mu.Unlock() + + retribution := c.retributions[i] + + return retribution.BreachTransaction, retribution +} + +type testHarness struct { + t *testing.T + cfg harnessCfg + signer *wtmock.MockSigner + capacity lnwire.MilliSatoshi + clientDB *wtmock.ClientDB + clientCfg *wtclient.Config + client wtclient.Client + serverDB *wtdb.MockDB + serverCfg *wtserver.Config + server *wtserver.Server + net *mockNet + + mu sync.Mutex + channels map[lnwire.ChannelID]*mockChannel +} + +type harnessCfg struct { + localBalance lnwire.MilliSatoshi + remoteBalance lnwire.MilliSatoshi + policy wtpolicy.Policy + noRegisterChan0 bool +} + +func newHarness(t *testing.T, cfg harnessCfg) *testHarness { + towerAddrStr := "18.28.243.2:9911" + towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr) + if err != nil { + t.Fatalf("Unable to resolve tower TCP addr: %v", err) + } + + privKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("Unable to generate tower private key: %v", err) + } + + towerPubKey := privKey.PubKey() + + towerAddr := &lnwire.NetAddress{ + IdentityKey: towerPubKey, + Address: towerTCPAddr, + } + + const timeout = 200 * time.Millisecond + serverDB := wtdb.NewMockDB() + + serverCfg := &wtserver.Config{ + DB: serverDB, + ReadTimeout: timeout, + WriteTimeout: timeout, + NewAddress: func() (btcutil.Address, error) { + return addr, nil + }, + } + + server, err := wtserver.New(serverCfg) + if err != nil { + t.Fatalf("unable to create wtserver: %v", err) + } + + signer := wtmock.NewMockSigner() + mockNet := newMockNet(server.InboundPeerConnected) + clientDB := wtmock.NewClientDB() + + clientCfg := &wtclient.Config{ + Signer: signer, + Dial: func(string, string) (net.Conn, error) { + return nil, nil + }, + DB: clientDB, + AuthDial: mockNet.AuthDial, + PrivateTower: towerAddr, + Policy: cfg.policy, + NewAddress: func() ([]byte, error) { + return addrScript, nil + }, + ReadTimeout: timeout, + WriteTimeout: timeout, + MinBackoff: time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + } + client, err := wtclient.New(clientCfg) + if err != nil { + t.Fatalf("Unable to create wtclient: %v", err) + } + + if err := server.Start(); err != nil { + t.Fatalf("Unable to start wtserver: %v", err) + } + + if err = client.Start(); err != nil { + server.Stop() + t.Fatalf("Unable to start wtclient: %v", err) + } + + h := &testHarness{ + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + clientCfg: clientCfg, + client: client, + serverDB: serverDB, + serverCfg: serverCfg, + server: server, + net: mockNet, + channels: make(map[lnwire.ChannelID]*mockChannel), + } + + h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) + if !cfg.noRegisterChan0 { + h.registerChannel(0) + } + + return h +} + +// startServer creates a new server using the harness's current serverCfg and +// starts it after pointing the mockNet's callback to the new server. +func (h *testHarness) startServer() { + h.t.Helper() + + var err error + h.server, err = wtserver.New(h.serverCfg) + if err != nil { + h.t.Fatalf("unable to create wtserver: %v", err) + } + + h.net.setConnCallback(h.server.InboundPeerConnected) + + if err := h.server.Start(); err != nil { + h.t.Fatalf("unable to start wtserver: %v", err) + } +} + +// startClient creates a new server using the harness's current clientCf and +// starts it. +func (h *testHarness) startClient() { + h.t.Helper() + + var err error + h.client, err = wtclient.New(h.clientCfg) + if err != nil { + h.t.Fatalf("unable to create wtclient: %v", err) + } + if err := h.client.Start(); err != nil { + h.t.Fatalf("unable to start wtclient: %v", err) + } +} + +// chanIDFromInt creates a unique channel id given a unique integral id. +func chanIDFromInt(id uint64) lnwire.ChannelID { + var chanID lnwire.ChannelID + binary.BigEndian.PutUint64(chanID[:8], id) + return chanID +} + +// makeChannel creates new channel with id, using the localAmt and remoteAmt as +// the starting balances. The channel will be available by using h.channel(id). +// +// NOTE: The method fails if channel for id already exists. +func (h *testHarness) makeChannel(id uint64, + localAmt, remoteAmt lnwire.MilliSatoshi) { + + h.t.Helper() + + chanID := chanIDFromInt(id) + c := newMockChannel(h.t, h.signer, localAmt, remoteAmt) + + c.mu.Lock() + _, ok := h.channels[chanID] + if !ok { + h.channels[chanID] = c + } + c.mu.Unlock() + + if ok { + h.t.Fatalf("channel %d already created", id) + } +} + +// channel retrieves the channel corresponding to id. +// +// NOTE: The method fails if a channel for id does not exist. +func (h *testHarness) channel(id uint64) *mockChannel { + h.t.Helper() + + h.mu.Lock() + c, ok := h.channels[chanIDFromInt(id)] + h.mu.Unlock() + if !ok { + h.t.Fatalf("unable to fetch channel %d", id) + } + + return c +} + +// registerChannel registers the channel identified by id with the client. +func (h *testHarness) registerChannel(id uint64) { + h.t.Helper() + + chanID := chanIDFromInt(id) + err := h.client.RegisterChannel(chanID) + if err != nil { + h.t.Fatalf("unable to register channel %d: %v", id, err) + } +} + +// advanceChannelN calls advanceState on the channel identified by id the number +// of provided times and returns the breach hints corresponding to the new +// states. +func (h *testHarness) advanceChannelN(id uint64, n int) []wtdb.BreachHint { + h.t.Helper() + + channel := h.channel(id) + + var hints []wtdb.BreachHint + for i := uint64(0); i < uint64(n); i++ { + channel.advanceState(h.t) + commitTx, _ := h.channel(id).getState(i) + breachTxID := commitTx.TxHash() + hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + } + + return hints +} + +// backupStates instructs the channel identified by id to send backups to the +// client for states in the range [to, from). +func (h *testHarness) backupStates(id, from, to uint64, expErr error) { + h.t.Helper() + + for i := from; i < to; i++ { + h.backupState(id, i, expErr) + } +} + +// backupStates instructs the channel identified by id to send a backup for +// state i. +func (h *testHarness) backupState(id, i uint64, expErr error) { + _, retribution := h.channel(id).getState(i) + + chanID := chanIDFromInt(id) + err := h.client.BackupState(&chanID, retribution) + if err != expErr { + h.t.Fatalf("back error mismatch, want: %v, got: %v", + expErr, err) + } +} + +// sendPayments instructs the channel identified by id to send amt to the remote +// party for each state in from-to times and returns the breach hints for states +// [from, to). +func (h *testHarness) sendPayments(id, from, to uint64, + amt lnwire.MilliSatoshi) []wtdb.BreachHint { + + h.t.Helper() + + channel := h.channel(id) + + var hints []wtdb.BreachHint + for i := from; i < to; i++ { + h.channel(id).sendPayment(h.t, amt) + commitTx, _ := channel.getState(i) + breachTxID := commitTx.TxHash() + hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + } + + return hints +} + +// receivePayment instructs the channel identified by id to recv amt from the +// remote party for each state in from-to times and returns the breach hints for +// states [from, to). +func (h *testHarness) recvPayments(id, from, to uint64, + amt lnwire.MilliSatoshi) []wtdb.BreachHint { + + h.t.Helper() + + channel := h.channel(id) + + var hints []wtdb.BreachHint + for i := from; i < to; i++ { + channel.receivePayment(h.t, amt) + commitTx, _ := channel.getState(i) + breachTxID := commitTx.TxHash() + hints = append(hints, wtdb.NewBreachHintFromHash(&breachTxID)) + } + + return hints +} + +// waitServerUpdates blocks until the breach hints provided all appear in the +// watchtower's database or the timeout expires. This is used to test that the +// client in fact sends the updates to the server, even if it is offline. +func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, + timeout time.Duration) { + + h.t.Helper() + + // If no breach hints are provided, we will wait out the full timeout to + // assert that no updates appear. + wantUpdates := len(hints) > 0 + + hintSet := make(map[wtdb.BreachHint]struct{}) + for _, hint := range hints { + hintSet[hint] = struct{}{} + } + + if len(hints) != len(hintSet) { + h.t.Fatalf("breach hints are not unique, list-len: %d "+ + "set-len: %d", len(hints), len(hintSet)) + } + + // Closure to assert the server's matches are consistent with the hint + // set. + serverHasHints := func(matches []wtdb.Match) bool { + if len(hintSet) != len(matches) { + return false + } + + for _, match := range matches { + if _, ok := hintSet[match.Hint]; ok { + continue + } + + h.t.Fatalf("match %v in db is not in hint set", + match.Hint) + } + + return true + } + + failTimeout := time.After(timeout) + for { + select { + case <-time.After(time.Second): + matches, err := h.serverDB.QueryMatches(hints) + switch { + case err != nil: + h.t.Fatalf("unable to query for hints: %v", err) + + case wantUpdates && serverHasHints(matches): + return + + case wantUpdates: + h.t.Logf("Received %d/%d\n", len(matches), + len(hints)) + } + + case <-failTimeout: + matches, err := h.serverDB.QueryMatches(hints) + switch { + case err != nil: + h.t.Fatalf("unable to query for hints: %v", err) + + case serverHasHints(matches): + return + + default: + h.t.Fatalf("breach hints not received, only "+ + "got %d/%d", len(matches), len(hints)) + } + } + } +} + +const ( + localBalance = lnwire.MilliSatoshi(100000000) + remoteBalance = lnwire.MilliSatoshi(200000000) +) + +type clientTest struct { + name string + cfg harnessCfg + fn func(*testHarness) +} + +var clientTests = []clientTest{ + { + // Asserts that client will return the ErrUnregisteredChannel + // error when trying to backup states for a channel that has not + // been registered (and received it's pkscript). + name: "backup unregistered channel", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 20000, + SweepFeeRate: 1, + }, + noRegisterChan0: true, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + chanID = 0 + ) + + // Advance the channel and backup the retributions. We + // expect ErrUnregisteredChannel to be returned since + // the channel was not registered during harness + // creation. + h.advanceChannelN(chanID, numUpdates) + h.backupStates( + chanID, 0, numUpdates, + wtclient.ErrUnregisteredChannel, + ) + }, + }, + { + // Asserts that the client returns an ErrClientExiting when + // trying to backup channels after the Stop method has been + // called. + name: "backup after stop", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 20000, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + chanID = 0 + ) + + // Stop the client, subsequent backups should fail. + h.client.Stop() + + // Advance the channel and try to back up the states. We + // expect ErrClientExiting to be returned from + // BackupState. + h.advanceChannelN(chanID, numUpdates) + h.backupStates( + chanID, 0, numUpdates, + wtclient.ErrClientExiting, + ) + }, + }, + { + // Asserts that the client will continue to back up all states + // that have previously been enqueued before it finishes + // exiting. + name: "backup reliable flush", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + chanID = 0 + ) + + // Generate numUpdates retributions and back them up to + // the tower. + hints := h.advanceChannelN(chanID, numUpdates) + h.backupStates(chanID, 0, numUpdates, nil) + + // Stop the client in the background, to assert the + // pipeline is always flushed before it exits. + go h.client.Stop() + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, time.Second) + }, + }, + { + // Assert that the client will not send out backups for states + // whose justice transactions are ineligible for backup, e.g. + // creating dust outputs. + name: "backup dust ineligible", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 20000, + SweepFeeRate: 1000000, // high sweep fee creates dust + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + chanID = 0 + ) + + // Create the retributions and queue them for backup. + h.advanceChannelN(chanID, numUpdates) + h.backupStates(chanID, 0, numUpdates, nil) + + // Ensure that no updates are received by the server, + // since they should all be marked as ineligible. + h.waitServerUpdates(nil, time.Second) + }, + }, + { + // Verifies that the client will properly retransmit a committed + // state update to the watchtower after a restart if the update + // was not acked while the client was active last. + name: "committed update restart", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 20000, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + chanID = 0 + ) + + hints := h.advanceChannelN(0, numUpdates) + + var numSent uint64 + + // Add the first two states to the client's pipeline. + h.backupStates(chanID, 0, 2, nil) + numSent = 2 + + // Wait for both to be reflected in the server's + // database. + h.waitServerUpdates(hints[:numSent], time.Second) + + // Now, restart the server and prevent it from acking + // state updates. + h.server.Stop() + h.serverCfg.NoAckUpdates = true + h.startServer() + defer h.server.Stop() + + // Send the next state update to the tower. Since the + // tower isn't acking state updates, we expect this + // update to be committed and sent by the session queue, + // but it will never receive an ack. + h.backupState(chanID, numSent, nil) + numSent++ + + // Force quit the client to abort the state updates it + // has queued. The sleep ensures that the session queues + // have enough time to commit the state updates before + // the client is killed. + time.Sleep(time.Second) + h.client.ForceQuit() + + // Restart the server and allow it to ack the updates + // after the client retransmits the unacked update. + h.server.Stop() + h.serverCfg.NoAckUpdates = false + h.startServer() + defer h.server.Stop() + + // Restart the client and allow it to process the + // committed update. + h.startClient() + defer h.client.ForceQuit() + + // Wait for the committed update to be accepted by the + // tower. + h.waitServerUpdates(hints[:numSent], time.Second) + + // Finally, send the rest of the updates and wait for + // the tower to receive the remaining states. + h.backupStates(chanID, numSent, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, time.Second) + + }, + }, + { + // Asserts that the client will continue to retry sending state + // updates if it doesn't receive an ack from the server. The + // client is expected to flush everything in its in-memory + // pipeline once the server begins sending acks again. + name: "no ack from server", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 100 + chanID = 0 + ) + + // Generate the retributions that will be backed up. + hints := h.advanceChannelN(chanID, numUpdates) + + // Restart the server and prevent it from acking state + // updates. + h.server.Stop() + h.serverCfg.NoAckUpdates = true + h.startServer() + defer h.server.Stop() + + // Now, queue the retributions for backup. + h.backupStates(chanID, 0, numUpdates, nil) + + // Stop the client in the background, to assert the + // pipeline is always flushed before it exits. + go h.client.Stop() + + // Give the client time to saturate a large number of + // session queues for which the server has not acked the + // state updates that it has received. + time.Sleep(time.Second) + + // Restart the server and allow it to ack the updates + // after the client retransmits the unacked updates. + h.server.Stop() + h.serverCfg.NoAckUpdates = false + h.startServer() + defer h.server.Stop() + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 5*time.Second) + }, + }, + { + // Asserts that the client is able to send state updates to the + // tower for a full range of channel values, assuming the sweep + // fee rates permit it. We expect all of these to be successful + // since a sweep transactions spending only from one output is + // less expensive than one that sweeps both. + name: "send and recv", + cfg: harnessCfg{ + localBalance: 10000001, // ensure (% amt != 0) + remoteBalance: 20000001, // ensure (% amt != 0) + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 1000, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + var ( + capacity = h.cfg.localBalance + h.cfg.remoteBalance + paymentAmt = lnwire.MilliSatoshi(200000) + numSends = uint64(h.cfg.localBalance / paymentAmt) + numRecvs = uint64(capacity / paymentAmt) + numUpdates = numSends + numRecvs // 200 updates + chanID = uint64(0) + ) + + // Send money to the remote party until all funds are + // depleted. + sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt) + + // Now, sequentially receive the entire channel balance + // from the remote party. + recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt) + + // Collect the hints generated by both sending and + // receiving. + hints := append(sendHints, recvHints...) + + // Backup the channel's states the client. + h.backupStates(chanID, 0, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 3*time.Second) + }, + }, + { + // Asserts that the client is able to support multiple links. + name: "multiple link backup", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 5 + numChans = 10 + ) + + // Initialize and register an additional 9 channels. + for id := uint64(1); id < 10; id++ { + h.makeChannel( + id, h.cfg.localBalance, + h.cfg.remoteBalance, + ) + h.registerChannel(id) + } + + // Generate the retributions for all 10 channels and + // collect the breach hints. + var hints []wtdb.BreachHint + for id := uint64(0); id < 10; id++ { + chanHints := h.advanceChannelN(id, numUpdates) + hints = append(hints, chanHints...) + } + + // Provided all retributions to the client from all + // channels. + for id := uint64(0); id < 10; id++ { + h.backupStates(id, 0, numUpdates, nil) + } + + // Test reliable flush under multi-client scenario. + go h.client.Stop() + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 10*time.Second) + }, + }, +} + +// TestClient executes the client test suite, asserting the ability to backup +// states in a number of failure cases and it's reliability during shutdown. +func TestClient(t *testing.T) { + for _, test := range clientTests { + tc := test + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := newHarness(t, tc.cfg) + defer h.server.Stop() + defer h.client.ForceQuit() + + tc.fn(h) + }) + } +} From 05e3a7f6c09949dcde3ca43f5ffefa46871f734c Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:34:00 -0700 Subject: [PATCH 21/21] watchtower/wtmock/peer: set local pubkey --- watchtower/wtmock/peer.go | 7 +++++-- watchtower/wtserver/server_test.go | 20 +++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/watchtower/wtmock/peer.go b/watchtower/wtmock/peer.go index 16aeb0d8c..fc1ff9af1 100644 --- a/watchtower/wtmock/peer.go +++ b/watchtower/wtmock/peer.go @@ -27,14 +27,17 @@ type MockPeer struct { } // NewMockPeer returns a fresh MockPeer. -func NewMockPeer(pk *btcec.PublicKey, addr net.Addr, bufferSize int) *MockPeer { +func NewMockPeer(lpk, rpk *btcec.PublicKey, addr net.Addr, + bufferSize int) *MockPeer { + return &MockPeer{ - remotePub: pk, + remotePub: rpk, remoteAddr: addr, localAddr: &net.TCPAddr{ IP: net.IP{0x32, 0x31, 0x30, 0x29}, Port: 36723, }, + localPub: lpk, IncomingMsgs: make(chan []byte, bufferSize), OutgoingMsgs: make(chan []byte, bufferSize), Quit: make(chan struct{}), diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 6df7cd58d..cdbaa281c 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -87,10 +87,12 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) { s := initServer(t, nil, timeoutDuration) defer s.Stop() + localPub := randPubKey(t) + // Create two peers using the same session id. peerPub := randPubKey(t) - peer1 := wtmock.NewMockPeer(peerPub, nil, 0) - peer2 := wtmock.NewMockPeer(peerPub, nil, 0) + peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0) + peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0) // Serialize a Init message to be sent by both peers. init := wtwire.NewInitMessage( @@ -219,9 +221,11 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { s := initServer(t, nil, timeoutDuration) defer s.Stop() + localPub := randPubKey(t) + // Create a new client and connect to server. peerPub := randPubKey(t) - peer := wtmock.NewMockPeer(peerPub, nil, 0) + peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, i, s, peer, test.initMsg, timeoutDuration) // Send the CreateSession message, and wait for a reply. @@ -249,7 +253,7 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { // Simulate a peer with the same session id connection to the server // again. - peer = wtmock.NewMockPeer(peerPub, nil, 0) + peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, i, s, peer, test.initMsg, timeoutDuration) // Send the _same_ CreateSession message as the first attempt. @@ -559,9 +563,11 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { s := initServer(t, nil, timeoutDuration) defer s.Stop() + localPub := randPubKey(t) + // Create a new client and connect to the server. peerPub := randPubKey(t) - peer := wtmock.NewMockPeer(peerPub, nil, 0) + peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, i, s, peer, test.initMsg, timeoutDuration) // Register a session for this client to use in the subsequent tests. @@ -581,7 +587,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { // Now that the original connection has been closed, connect a new // client with the same session id. - peer = wtmock.NewMockPeer(peerPub, nil, 0) + peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, i, s, peer, test.initMsg, timeoutDuration) // Send the intended StateUpdate messages in series. @@ -592,7 +598,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { if update == nil { assertConnClosed(t, peer, 2*timeoutDuration) - peer = wtmock.NewMockPeer(peerPub, nil, 0) + peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) connect(t, i, s, peer, test.initMsg, timeoutDuration) continue