diff --git a/watchtower/wtdb/mock.go b/watchtower/wtdb/mock.go index c47fe2b1c..fa53d6cf7 100644 --- a/watchtower/wtdb/mock.go +++ b/watchtower/wtdb/mock.go @@ -70,6 +70,33 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error { return nil } +func (db *MockDB) DeleteSession(target SessionID) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Fail if the session doesn't exit. + if _, ok := db.sessions[target]; !ok { + return ErrSessionNotFound + } + + // Remove the target session. + delete(db.sessions, target) + + // Remove the state updates for any blobs stored under the target + // session identifier. + for hint, sessionUpdates := range db.blobs { + delete(sessionUpdates, target) + + //If this was the last state update, we can also remove the hint + //that would map to an empty set. + if len(sessionUpdates) == 0 { + delete(db.blobs, hint) + } + } + + return nil +} + func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { db.mu.Lock() defer db.mu.Unlock() diff --git a/watchtower/wtserver/create_session.go b/watchtower/wtserver/create_session.go new file mode 100644 index 000000000..e948d8b9a --- /dev/null +++ b/watchtower/wtserver/create_session.go @@ -0,0 +1,136 @@ +package wtserver + +import ( + "github.com/btcsuite/btcd/txscript" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// handleCreateSession processes a CreateSession message from the peer, and returns +// a CreateSessionReply in response. This method will only succeed if no existing +// session info is known about the session id. If an existing session is found, +// the reward address is returned in case the client lost our reply. +func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, + req *wtwire.CreateSession) error { + + // TODO(conner): validate accept against policy + + // Query the db for session info belonging to the client's session id. + existingInfo, err := s.cfg.DB.GetSessionInfo(id) + switch { + + // We already have a session corresponding to this session id, return an + // error signaling that it already exists in our database. We return the + // reward address to the client in case they were not able to process + // our reply earlier. + case err == nil: + log.Debugf("Already have session for %s", id) + return s.replyCreateSession( + peer, id, wtwire.CreateSessionCodeAlreadyExists, + existingInfo.RewardAddress, + ) + + // Some other database error occurred, return a temporary failure. + case err != wtdb.ErrSessionNotFound: + log.Errorf("unable to load session info for %s", id) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, nil, + ) + } + + // Now that we've established that this session does not exist in the + // database, retrieve the sweep address that will be given to the + // client. This address is to be included by the client when signing + // sweep transactions destined for this tower, if its negotiated output + // is not dust. + rewardAddress, err := s.cfg.NewAddress() + if err != nil { + log.Errorf("unable to generate reward addr for %s", id) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, nil, + ) + } + + // Construct the pkscript the client should pay to when signing justice + // transactions for this session. + rewardScript, err := txscript.PayToAddrScript(rewardAddress) + if err != nil { + log.Errorf("unable to generate reward script for %s", id) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, nil, + ) + } + + // Ensure that the requested blob type is supported by our tower. + if !blob.IsSupportedType(req.BlobType) { + log.Debugf("Rejecting CreateSession from %s, unsupported blob "+ + "type %s", id, req.BlobType) + return s.replyCreateSession( + peer, id, wtwire.CreateSessionCodeRejectBlobType, nil, + ) + } + + // TODO(conner): create invoice for upfront payment + + // Assemble the session info using the agreed upon parameters, reward + // address, and session id. + info := wtdb.SessionInfo{ + ID: *id, + Policy: wtpolicy.Policy{ + BlobType: req.BlobType, + MaxUpdates: req.MaxUpdates, + RewardBase: req.RewardBase, + RewardRate: req.RewardRate, + SweepFeeRate: req.SweepFeeRate, + }, + RewardAddress: rewardScript, + } + + // Insert the session info into the watchtower's database. If + // successful, the session will now be ready for use. + err = s.cfg.DB.InsertSessionInfo(&info) + if err != nil { + log.Errorf("unable to create session for %s", id) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, nil, + ) + } + + log.Infof("Accepted session for %s", id) + + return s.replyCreateSession( + peer, id, wtwire.CodeOK, rewardScript, + ) +} + +// replyCreateSession sends a response to a CreateSession from a client. If the +// status code in the reply is OK, the error from the write will be bubbled up. +// Otherwise, this method returns a connection error to ensure we don't continue +// communication with the client. +func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID, + code wtwire.ErrorCode, data []byte) error { + + msg := &wtwire.CreateSessionReply{ + Code: code, + Data: data, + } + + err := s.sendMessage(peer, msg) + if err != nil { + log.Errorf("unable to send CreateSessionReply to %s", id) + } + + // Return the write error if the request succeeded. + if code == wtwire.CodeOK { + return err + } + + // Otherwise the request failed, return a connection failure to + // disconnect the client. + return &connFailure{ + ID: *id, + Code: uint16(code), + } +} diff --git a/watchtower/wtserver/delete_session.go b/watchtower/wtserver/delete_session.go new file mode 100644 index 000000000..a5b517c12 --- /dev/null +++ b/watchtower/wtserver/delete_session.go @@ -0,0 +1,57 @@ +package wtserver + +import ( + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// handleDeleteSession processes a DeleteSession request for a client with given +// SessionID. The id is assumed to have been previously authenticated by the +// brontide connection. +func (s *Server) handleDeleteSession(peer Peer, id *wtdb.SessionID) error { + var failCode wtwire.DeleteSessionCode + + // Delete all session data associated with id. + err := s.cfg.DB.DeleteSession(*id) + switch { + case err == nil: + failCode = wtwire.CodeOK + + log.Debugf("Session %s deleted", id) + + case err == wtdb.ErrSessionNotFound: + failCode = wtwire.DeleteSessionCodeNotFound + + default: + failCode = wtwire.CodeTemporaryFailure + } + + return s.replyDeleteSession(peer, id, failCode) +} + +// replyDeleteSession sends a DeleteSessionReply back to the peer containing the +// error code resulting from processes a DeleteSession request. +func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID, + code wtwire.DeleteSessionCode) error { + + msg := &wtwire.DeleteSessionReply{ + Code: code, + } + + err := s.sendMessage(peer, msg) + if err != nil { + log.Errorf("Unable to send DeleteSessionReply to %s", id) + } + + // Return the write error if the request succeeded. + if code == wtwire.CodeOK { + return err + } + + // Otherwise the request failed, return a connection failure to + // disconnect the client. + return &connFailure{ + ID: *id, + Code: uint16(code), + } +} diff --git a/watchtower/wtserver/interface.go b/watchtower/wtserver/interface.go index 69d4ec9cf..d23645431 100644 --- a/watchtower/wtserver/interface.go +++ b/watchtower/wtserver/interface.go @@ -63,4 +63,8 @@ type DB interface { // validates the update against the current SessionInfo stored under the // update's session id.. InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error) + + // DeleteSession removes all data associated with a particular session + // id from the tower's database. + DeleteSession(wtdb.SessionID) error } diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index b9d222757..4a49c3ad4 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -11,12 +11,9 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" - "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtwire" ) @@ -24,6 +21,10 @@ var ( // ErrPeerAlreadyConnected signals that a peer with the same session id // is already active within the server. ErrPeerAlreadyConnected = errors.New("peer already connected") + + // ErrServerExiting signals that a request could not be processed + // because the server has been requested to shut down. + ErrServerExiting = errors.New("server shutting down") ) // Config abstracts the primary components and dependencies of the server. @@ -242,241 +243,40 @@ func (s *Server) handleClient(peer Peer) { return } - // stateUpdateOnlyMode will become true if the client's first message is - // a StateUpdate. If instead, it is a CreateSession, this method will exit - // immediately after replying. We track this to ensure that the client - // can't send a CreateSession after having already sent a StateUpdate. - var stateUpdateOnlyMode bool - for { - select { - case <-s.quit: - return - default: - } + nextMsg, err := s.readMessage(peer) + if err != nil { + log.Errorf("Unable to read watchtower msg from %s: %v", + id, err) + return + } - nextMsg, err := s.readMessage(peer) + switch msg := nextMsg.(type) { + case *wtwire.CreateSession: + // Attempt to open a new session for this client. + err = s.handleCreateSession(peer, &id, msg) if err != nil { - log.Errorf("Unable to read watchtower msg from %x: %v", - id[:], err) - return + log.Errorf("Unable to handle CreateSession "+ + "from %s: %v", id, err) } - // Process the request according to the message's type. - switch msg := nextMsg.(type) { - - // A CreateSession indicates a request to establish a new session - // with our watchtower. - case *wtwire.CreateSession: - // Ensure CreateSession can only be sent as the first - // message. - if stateUpdateOnlyMode { - log.Errorf("client %x sent CreateSession after "+ - "StateUpdate", id) - return - } - - // Attempt to open a new session for this client. - err := s.handleCreateSession(peer, &id, msg) - if err != nil { - log.Errorf("Unable to handle CreateSession "+ - "from %s: %v", id, err) - } - - // Exit after replying to CreateSession. - return - - // A StateUpdate indicates an existing client attempting to - // back-up a revoked commitment state. - case *wtwire.StateUpdate: - // Try to accept the state update from the client. - err := s.handleStateUpdate(peer, &id, msg) - if err != nil { - log.Errorf("unable to handle StateUpdate "+ - "from %s: %v", id, err) - return - } - - // If the client signals that this is last StateUpdate - // message, we can disconnect the client. - if msg.IsComplete == 1 { - return - } - - // The client has signaled that more StateUpdates are - // yet to come. Enter state-update-only mode to disallow - // future sends of CreateSession messages. - stateUpdateOnlyMode = true - - default: - log.Errorf("received unsupported message type: %T "+ - "from %s", nextMsg, id) - return + case *wtwire.DeleteSession: + err = s.handleDeleteSession(peer, &id) + if err != nil { + log.Errorf("Unable to handle DeleteSession "+ + "from %s: %v", id, err) } - } -} -// handleCreateSession processes a CreateSession message from the peer, and returns -// a CreateSessionReply in response. This method will only succeed if no existing -// session info is known about the session id. If an existing session is found, -// the reward address is returned in case the client lost our reply. -func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, - req *wtwire.CreateSession) error { - - // TODO(conner): validate accept against policy - - // Query the db for session info belonging to the client's session id. - existingInfo, err := s.cfg.DB.GetSessionInfo(id) - switch { - - // We already have a session corresponding to this session id, return an - // error signaling that it already exists in our database. We return the - // reward address to the client in case they were not able to process - // our reply earlier. - case err == nil: - log.Debugf("Already have session for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CreateSessionCodeAlreadyExists, - existingInfo.RewardAddress, - ) - - // Some other database error occurred, return a temporary failure. - case err != wtdb.ErrSessionNotFound: - log.Errorf("unable to load session info for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - // Now that we've established that this session does not exist in the - // database, retrieve the sweep address that will be given to the - // client. This address is to be included by the client when signing - // sweep transactions destined for this tower, if its negotiated output - // is not dust. - rewardAddress, err := s.cfg.NewAddress() - if err != nil { - log.Errorf("unable to generate reward addr for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - // Construct the pkscript the client should pay to when signing justice - // transactions for this session. - rewardScript, err := txscript.PayToAddrScript(rewardAddress) - if err != nil { - log.Errorf("unable to generate reward script for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - // Ensure that the requested blob type is supported by our tower. - if !blob.IsSupportedType(req.BlobType) { - log.Debugf("Rejecting CreateSession from %s, unsupported blob "+ - "type %s", id, req.BlobType) - return s.replyCreateSession( - peer, id, wtwire.CreateSessionCodeRejectBlobType, nil, - ) - } - - // TODO(conner): create invoice for upfront payment - - // Assemble the session info using the agreed upon parameters, reward - // address, and session id. - info := wtdb.SessionInfo{ - ID: *id, - Policy: wtpolicy.Policy{ - BlobType: req.BlobType, - MaxUpdates: req.MaxUpdates, - RewardBase: req.RewardBase, - RewardRate: req.RewardRate, - SweepFeeRate: req.SweepFeeRate, - }, - RewardAddress: rewardScript, - } - - // Insert the session info into the watchtower's database. If - // successful, the session will now be ready for use. - err = s.cfg.DB.InsertSessionInfo(&info) - if err != nil { - log.Errorf("unable to create session for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - log.Infof("Accepted session for %s", id) - - return s.replyCreateSession( - peer, id, wtwire.CodeOK, rewardScript, - ) -} - -// handleStateUpdate processes a StateUpdate message request from a client. An -// attempt will be made to insert the update into the db, where it is validated -// against the client's session. The possible errors are then mapped back to -// StateUpdateCodes specified by the watchtower wire protocol, and sent back -// using a StateUpdateReply message. -func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID, - update *wtwire.StateUpdate) error { - - var ( - lastApplied uint16 - failCode wtwire.ErrorCode - err error - ) - - sessionUpdate := wtdb.SessionStateUpdate{ - ID: *id, - Hint: update.Hint, - SeqNum: update.SeqNum, - LastApplied: update.LastApplied, - EncryptedBlob: update.EncryptedBlob, - } - - lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate) - switch { - case err == nil: - log.Debugf("State update %d accepted for %s", - update.SeqNum, id) - - failCode = wtwire.CodeOK - - // Return a permanent failure if a client tries to send an update for - // which we have no session. - case err == wtdb.ErrSessionNotFound: - failCode = wtwire.CodePermanentFailure - - case err == wtdb.ErrSeqNumAlreadyApplied: - failCode = wtwire.CodePermanentFailure - - // TODO(conner): remove session state for protocol - // violation. Could also double as clean up method for - // session-related state. - - case err == wtdb.ErrLastAppliedReversion: - failCode = wtwire.StateUpdateCodeClientBehind - - case err == wtdb.ErrSessionConsumed: - failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded - - case err == wtdb.ErrUpdateOutOfOrder: - failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder + case *wtwire.StateUpdate: + err = s.handleStateUpdates(peer, &id, msg) + if err != nil { + log.Errorf("Unable to handle StateUpdate "+ + "from %s: %v", id, err) + } default: - failCode = wtwire.CodeTemporaryFailure + log.Errorf("Received unsupported message type: %T "+ + "from %s", nextMsg, id) } - - if s.cfg.NoAckUpdates { - return &connFailure{ - ID: *id, - Code: uint16(failCode), - } - } - - return s.replyStateUpdate( - peer, id, failCode, lastApplied, - ) } // connFailure is a default error used when a request failed with a non-zero @@ -493,66 +293,6 @@ func (f *connFailure) Error() string { ) } -// replyCreateSession sends a response to a CreateSession from a client. If the -// status code in the reply is OK, the error from the write will be bubbled up. -// Otherwise, this method returns a connection error to ensure we don't continue -// communication with the client. -func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID, - code wtwire.ErrorCode, data []byte) error { - - msg := &wtwire.CreateSessionReply{ - Code: code, - Data: data, - } - - err := s.sendMessage(peer, msg) - if err != nil { - log.Errorf("unable to send CreateSessionReply to %s", id) - } - - // Return the write error if the request succeeded. - if code == wtwire.CodeOK { - return err - } - - // Otherwise the request failed, return a connection failure to - // disconnect the client. - return &connFailure{ - ID: *id, - Code: uint16(code), - } -} - -// replyStateUpdate sends a response to a StateUpdate from a client. If the -// status code in the reply is OK, the error from the write will be bubbled up. -// Otherwise, this method returns a connection error to ensure we don't continue -// communication with the client. -func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID, - code wtwire.StateUpdateCode, lastApplied uint16) error { - - msg := &wtwire.StateUpdateReply{ - Code: code, - LastApplied: lastApplied, - } - - err := s.sendMessage(peer, msg) - if err != nil { - log.Errorf("unable to send StateUpdateReply to %s", id) - } - - // Return the write error if the request succeeded. - if code == wtwire.CodeOK { - return err - } - - // Otherwise the request failed, return a connection failure to - // disconnect the client. - return &connFailure{ - ID: *id, - Code: uint16(code), - } -} - // readMessage receives and parses the next message from the given Peer. An // error is returned if a message is not received before the server's read // timeout, the read off the wire failed, or the message could not be diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index cdbaa281c..d9e3bec50 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -226,13 +226,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { // Create a new client and connect to server. peerPub := randPubKey(t) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) - connect(t, i, s, peer, test.initMsg, timeoutDuration) + connect(t, s, peer, test.initMsg, timeoutDuration) // Send the CreateSession message, and wait for a reply. - sendMsg(t, i, test.createMsg, peer, timeoutDuration) + sendMsg(t, test.createMsg, peer, timeoutDuration) reply := recvReply( - t, i, "CreateSessionReply", peer, timeoutDuration, + t, "MsgCreateSessionReply", peer, timeoutDuration, ).(*wtwire.CreateSessionReply) // Verify that the server's response matches our expectation. @@ -254,13 +254,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { // Simulate a peer with the same session id connection to the server // again. peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) - connect(t, i, s, peer, test.initMsg, timeoutDuration) + connect(t, s, peer, test.initMsg, timeoutDuration) // Send the _same_ CreateSession message as the first attempt. - sendMsg(t, i, test.createMsg, peer, timeoutDuration) + sendMsg(t, test.createMsg, peer, timeoutDuration) reply = recvReply( - t, i, "CreateSessionReply", peer, timeoutDuration, + t, "MsgCreateSessionReply", peer, timeoutDuration, ).(*wtwire.CreateSessionReply) // Ensure that the server's reply matches our expected response for a @@ -550,14 +550,14 @@ var stateUpdateTests = []stateUpdateTestCase{ func TestServerStateUpdates(t *testing.T) { t.Parallel() - for i, test := range stateUpdateTests { + for _, test := range stateUpdateTests { t.Run(test.name, func(t *testing.T) { - testServerStateUpdates(t, i, test) + testServerStateUpdates(t, test) }) } } -func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { +func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) { const timeoutDuration = 100 * time.Millisecond s := initServer(t, nil, timeoutDuration) @@ -568,17 +568,17 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { // Create a new client and connect to the server. peerPub := randPubKey(t) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) - connect(t, i, s, peer, test.initMsg, timeoutDuration) + connect(t, s, peer, test.initMsg, timeoutDuration) // Register a session for this client to use in the subsequent tests. - sendMsg(t, i, test.createMsg, peer, timeoutDuration) + sendMsg(t, test.createMsg, peer, timeoutDuration) initReply := recvReply( - t, i, "CreateSessionReply", peer, timeoutDuration, + t, "MsgCreateSessionReply", peer, timeoutDuration, ).(*wtwire.CreateSessionReply) // Fail if the server rejected our proposed CreateSession message. if initReply.Code != wtwire.CodeOK { - t.Fatalf("[test %d] server rejected session init", i) + t.Fatalf("server rejected session init") } // Check that the server closed the connection used to register the @@ -588,7 +588,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { // Now that the original connection has been closed, connect a new // client with the same session id. peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) - connect(t, i, s, peer, test.initMsg, timeoutDuration) + connect(t, s, peer, test.initMsg, timeoutDuration) // Send the intended StateUpdate messages in series. for j, update := range test.updates { @@ -599,21 +599,21 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { assertConnClosed(t, peer, 2*timeoutDuration) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) - connect(t, i, s, peer, test.initMsg, timeoutDuration) + connect(t, s, peer, test.initMsg, timeoutDuration) continue } // Send the state update and verify it against our expected // response. - sendMsg(t, i, update, peer, timeoutDuration) + sendMsg(t, update, peer, timeoutDuration) reply := recvReply( - t, i, "StateUpdateReply", peer, timeoutDuration, + t, "MsgStateUpdateReply", peer, timeoutDuration, ).(*wtwire.StateUpdateReply) if !reflect.DeepEqual(reply, test.replies[j]) { - t.Fatalf("[test %d, update %d] expected reply "+ - "%v, got %d", i, j, + t.Fatalf("[update %d] expected reply "+ + "%v, got %d", j, test.replies[j], reply) } } @@ -622,16 +622,148 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { assertConnClosed(t, peer, 2*timeoutDuration) } -func connect(t *testing.T, i int, s wtserver.Interface, peer *wtmock.MockPeer, +// TestServerDeleteSession asserts the response to a DeleteSession request, and +// checking that the proper error is returned when the session doesn't exist and +// that a successful deletion does not disrupt other sessions. +func TestServerDeleteSession(t *testing.T) { + db := wtdb.NewMockDB() + + localPub := randPubKey(t) + + // Initialize two distinct peers with different session ids. + peerPub1 := randPubKey(t) + peerPub2 := randPubKey(t) + + id1 := wtdb.NewSessionIDFromPubKey(peerPub1) + id2 := wtdb.NewSessionIDFromPubKey(peerPub2) + + // Create closure to simplify assertions on session existence with the + // server's database. + hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) { + t.Helper() + + _, err := db.GetSessionInfo(id) + switch { + case shouldHave && err != nil: + t.Fatalf("expected server to have session %s, got: %v", + id, err) + case !shouldHave && err != wtdb.ErrSessionNotFound: + t.Fatalf("expected ErrSessionNotFound for session %s, "+ + "got: %v", id, err) + } + } + + initMsg := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ) + + createSession := &wtwire.CreateSession{ + BlobType: blob.TypeDefault, + MaxUpdates: 1000, + RewardBase: 0, + RewardRate: 0, + SweepFeeRate: 1, + } + + const timeoutDuration = 100 * time.Millisecond + + s := initServer(t, db, timeoutDuration) + defer s.Stop() + + // Create a session for peer2 so that the server's db isn't completely + // empty. + peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0) + connect(t, s, peer2, initMsg, timeoutDuration) + sendMsg(t, createSession, peer2, timeoutDuration) + assertConnClosed(t, peer2, 2*timeoutDuration) + + // Our initial assertions are that peer2 has a valid session, but peer1 + // has not created one. + hasSession(t, &id1, false) + hasSession(t, &id2, true) + + peer1Msgs := []struct { + send wtwire.Message + recv wtwire.Message + assert func(t *testing.T) + }{ + { + // Deleting unknown session should fail. + send: &wtwire.DeleteSession{}, + recv: &wtwire.DeleteSessionReply{ + Code: wtwire.DeleteSessionCodeNotFound, + }, + assert: func(t *testing.T) { + // Peer2 should still be only session. + hasSession(t, &id1, false) + hasSession(t, &id2, true) + }, + }, + { + // Create session for peer1. + send: createSession, + recv: &wtwire.CreateSessionReply{ + Code: wtwire.CodeOK, + Data: addrScript, + }, + assert: func(t *testing.T) { + // Both peers should have sessions. + hasSession(t, &id1, true) + hasSession(t, &id2, true) + }, + }, + + { + // Delete peer1's session. + send: &wtwire.DeleteSession{}, + recv: &wtwire.DeleteSessionReply{ + Code: wtwire.CodeOK, + }, + assert: func(t *testing.T) { + // Peer1's session should have been removed. + hasSession(t, &id1, false) + hasSession(t, &id2, true) + }, + }, + } + + // Now as peer1, process the canned messages defined above. This will: + // 1. Try to delete an unknown session and get a not found error code. + // 2. Create a new session using the same parameters as peer2. + // 3. Delete the newly created session and get an OK. + for _, msg := range peer1Msgs { + peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0) + connect(t, s, peer1, initMsg, timeoutDuration) + sendMsg(t, msg.send, peer1, timeoutDuration) + reply := recvReply( + t, msg.recv.MsgType().String(), peer1, timeoutDuration, + ) + + if !reflect.DeepEqual(reply, msg.recv) { + t.Fatalf("expected reply: %v, got: %v", msg.recv, reply) + } + + assertConnClosed(t, peer1, 2*timeoutDuration) + + // Invoke assertions after completing the request/response + // dance. + msg.assert(t) + } +} + +func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer, initMsg *wtwire.Init, timeout time.Duration) { + t.Helper() + s.InboundPeerConnected(peer) - sendMsg(t, i, initMsg, peer, timeout) - recvReply(t, i, "Init", peer, timeout) + sendMsg(t, initMsg, peer, timeout) + recvReply(t, "MsgInit", peer, timeout) } // sendMsg sends a wtwire.Message message via a wtmock.MockPeer. -func sendMsg(t *testing.T, i int, msg wtwire.Message, +func sendMsg(t *testing.T, msg wtwire.Message, peer *wtmock.MockPeer, timeout time.Duration) { t.Helper() @@ -639,22 +771,22 @@ func sendMsg(t *testing.T, i int, msg wtwire.Message, var b bytes.Buffer _, err := wtwire.WriteMessage(&b, msg, 0) if err != nil { - t.Fatalf("[test %d] unable to encode %T message: %v", - i, msg, err) + t.Fatalf("unable to encode %T message: %v", + msg, err) } select { case peer.IncomingMsgs <- b.Bytes(): case <-time.After(2 * timeout): - t.Fatalf("[test %d] unable to send %T message", i, msg) + t.Fatalf("unable to send %T message", msg) } } // recvReply receives a message from the server, and parses it according to // expected reply type. The supported replies are CreateSessionReply and // StateUpdateReply. -func recvReply(t *testing.T, i int, name string, - peer *wtmock.MockPeer, timeout time.Duration) wtwire.Message { +func recvReply(t *testing.T, name string, peer *wtmock.MockPeer, + timeout time.Duration) wtwire.Message { t.Helper() @@ -667,29 +799,34 @@ func recvReply(t *testing.T, i int, name string, case b := <-peer.OutgoingMsgs: msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0) if err != nil { - t.Fatalf("[test %d] unable to decode server "+ - "reply: %v", i, err) + t.Fatalf("unable to decode server "+ + "reply: %v", err) } case <-time.After(2 * timeout): - t.Fatalf("[test %d] server did not reply", i) + t.Fatalf("server did not reply") } switch name { - case "Init": + case "MsgInit": if _, ok := msg.(*wtwire.Init); !ok { - t.Fatalf("[test %d] expected %s reply "+ - "message, got %T", i, name, msg) + t.Fatalf("expected %s reply message, "+ + "got %T", name, msg) } - case "CreateSessionReply": + case "MsgCreateSessionReply": if _, ok := msg.(*wtwire.CreateSessionReply); !ok { - t.Fatalf("[test %d] expected %s reply "+ - "message, got %T", i, name, msg) + t.Fatalf("expected %s reply message, "+ + "got %T", name, msg) } - case "StateUpdateReply": + case "MsgStateUpdateReply": if _, ok := msg.(*wtwire.StateUpdateReply); !ok { - t.Fatalf("[test %d] expected %s reply "+ - "message, got %T", i, name, msg) + t.Fatalf("expected %s reply message, "+ + "got %T", name, msg) + } + case "MsgDeleteSessionReply": + if _, ok := msg.(*wtwire.DeleteSessionReply); !ok { + t.Fatalf("expected %s reply message, "+ + "got %T", name, msg) } } diff --git a/watchtower/wtserver/state_update.go b/watchtower/wtserver/state_update.go new file mode 100644 index 000000000..7b3e0941b --- /dev/null +++ b/watchtower/wtserver/state_update.go @@ -0,0 +1,157 @@ +package wtserver + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// handleStateUpdates processes a stream of StateUpdate requests from the +// client. The provided update should be the first such update read, subsequent +// updates will be consumed if the peer does not signal IsComplete on a +// particular update. +func (s *Server) handleStateUpdates(peer Peer, id *wtdb.SessionID, + update *wtwire.StateUpdate) error { + + // Set the current update to the first update read off the wire. + // Additional updates will be read if this value is set to nil after + // processing the first. + var curUpdate = update + for { + // If this is not the first update, read the next state update + // from the peer. + if curUpdate == nil { + nextMsg, err := s.readMessage(peer) + if err != nil { + return err + } + + var ok bool + curUpdate, ok = nextMsg.(*wtwire.StateUpdate) + if !ok { + return fmt.Errorf("client sent %T after "+ + "StateUpdate", nextMsg) + } + } + + // Try to accept the state update from the client. + err := s.handleStateUpdate(peer, id, curUpdate) + if err != nil { + return err + } + + // If the client signals that this is last StateUpdate + // message, we can disconnect the client. + if curUpdate.IsComplete == 1 { + return nil + } + + // Reset the current update to read subsequent updates in the + // stream. + curUpdate = nil + + select { + case <-s.quit: + return ErrServerExiting + default: + } + } +} + +// handleStateUpdate processes a StateUpdate message request from a client. An +// attempt will be made to insert the update into the db, where it is validated +// against the client's session. The possible errors are then mapped back to +// StateUpdateCodes specified by the watchtower wire protocol, and sent back +// using a StateUpdateReply message. +func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID, + update *wtwire.StateUpdate) error { + + var ( + lastApplied uint16 + failCode wtwire.ErrorCode + err error + ) + + sessionUpdate := wtdb.SessionStateUpdate{ + ID: *id, + Hint: update.Hint, + SeqNum: update.SeqNum, + LastApplied: update.LastApplied, + EncryptedBlob: update.EncryptedBlob, + } + + lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate) + switch { + case err == nil: + log.Debugf("State update %d accepted for %s", + update.SeqNum, id) + + failCode = wtwire.CodeOK + + // Return a permanent failure if a client tries to send an update for + // which we have no session. + case err == wtdb.ErrSessionNotFound: + failCode = wtwire.CodePermanentFailure + + case err == wtdb.ErrSeqNumAlreadyApplied: + failCode = wtwire.CodePermanentFailure + + // TODO(conner): remove session state for protocol + // violation. Could also double as clean up method for + // session-related state. + + case err == wtdb.ErrLastAppliedReversion: + failCode = wtwire.StateUpdateCodeClientBehind + + case err == wtdb.ErrSessionConsumed: + failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded + + case err == wtdb.ErrUpdateOutOfOrder: + failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder + + default: + failCode = wtwire.CodeTemporaryFailure + } + + if s.cfg.NoAckUpdates { + return &connFailure{ + ID: *id, + Code: uint16(failCode), + } + } + + return s.replyStateUpdate( + peer, id, failCode, lastApplied, + ) +} + +// replyStateUpdate sends a response to a StateUpdate from a client. If the +// status code in the reply is OK, the error from the write will be bubbled up. +// Otherwise, this method returns a connection error to ensure we don't continue +// communication with the client. +func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID, + code wtwire.StateUpdateCode, lastApplied uint16) error { + + msg := &wtwire.StateUpdateReply{ + Code: code, + LastApplied: lastApplied, + } + + err := s.sendMessage(peer, msg) + if err != nil { + log.Errorf("unable to send StateUpdateReply to %s", id) + } + + // Return the write error if the request succeeded. + if code == wtwire.CodeOK { + return err + } + + // Otherwise the request failed, return a connection failure to + // disconnect the client. + return &connFailure{ + ID: *id, + Code: uint16(code), + } +} diff --git a/watchtower/wtwire/delete_session.go b/watchtower/wtwire/delete_session.go new file mode 100644 index 000000000..fbe8cfb01 --- /dev/null +++ b/watchtower/wtwire/delete_session.go @@ -0,0 +1,45 @@ +package wtwire + +import "io" + +// DeleteSession is sent from the client to the tower to signal that the tower +// can delete all session state for the session key used to authenticate the +// brontide connection. This should be done by the client once all channels that +// have state updates in the session have been resolved on-chain. +type DeleteSession struct{} + +// Compile-time constraint to ensure DeleteSession implements the wtwire.Message +// interface. +var _ Message = (*DeleteSession)(nil) + +// Decode deserializes a serialized DeleteSession message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSession) Decode(r io.Reader, pver uint32) error { + return nil +} + +// Encode serializes the target DeleteSession message into the passed io.Writer +// observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSession) Encode(w io.Writer, pver uint32) error { + return nil +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSession) MsgType() MessageType { + return MsgDeleteSession +} + +// MaxPayloadLength returns the maximum allowed payload size for a DeleteSession +// message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSession) MaxPayloadLength(uint32) uint32 { + return 0 +} diff --git a/watchtower/wtwire/delete_session_reply.go b/watchtower/wtwire/delete_session_reply.go new file mode 100644 index 000000000..059d189e8 --- /dev/null +++ b/watchtower/wtwire/delete_session_reply.go @@ -0,0 +1,66 @@ +package wtwire + +import "io" + +// DeleteSessionCode is an error code returned by a watchtower in response to a +// DeleteSession message. +type DeleteSessionCode = ErrorCode + +const ( + // DeleteSessionCodeNotFound is returned when the watchtower does not + // know of the requested session. This may indicate an error on the + // client side, or that the tower had already deleted the session in a + // prior request that the client may not have received. + DeleteSessionCodeNotFound DeleteSessionCode = 80 + + // TODO(conner): add String method after wtclient is merged +) + +// DeleteSessionReply is a message sent in response to a client's DeleteSession +// request. The message indicates whether or not the deletion was a success or +// failure. +type DeleteSessionReply struct { + // Code will be non-zero if the watchtower was not able to delete the + // requested session. + Code DeleteSessionCode +} + +// A compile time check to ensure DeleteSessionReply implements the +// wtwire.Message interface. +var _ Message = (*DeleteSessionReply)(nil) + +// Decode deserializes a serialized DeleteSessionReply message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSessionReply) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &m.Code, + ) +} + +// Encode serializes the target DeleteSessionReply into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSessionReply) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + m.Code, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSessionReply) MsgType() MessageType { + return MsgDeleteSessionReply +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// DeleteSessionReply complete message observing the specified protocol version. +// +// This is part of the wtwire.Message interface. +func (m *DeleteSessionReply) MaxPayloadLength(uint32) uint32 { + return 2 +} diff --git a/watchtower/wtwire/message.go b/watchtower/wtwire/message.go index 364a2dab5..192a397c2 100644 --- a/watchtower/wtwire/message.go +++ b/watchtower/wtwire/message.go @@ -40,6 +40,13 @@ const ( // MsgStateUpdateReply identifies an encoded StateUpdateReply message. MsgStateUpdateReply MessageType = 305 + + // MsgDeleteSession identifies an encoded DeleteSession message. + MsgDeleteSession MessageType = 306 + + // MsgDeleteSessionReply identifies an encoded DeleteSessionReply + // message. + MsgDeleteSessionReply MessageType = 307 ) // String returns a human readable description of the message type. @@ -55,6 +62,10 @@ func (m MessageType) String() string { return "MsgStateUpdate" case MsgStateUpdateReply: return "MsgStateUpdateReply" + case MsgDeleteSession: + return "MsgDeleteSession" + case MsgDeleteSessionReply: + return "MsgDeleteSessionReply" case MsgError: return "Error" default: @@ -97,6 +108,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &StateUpdate{} case MsgStateUpdateReply: msg = &StateUpdateReply{} + case MsgDeleteSession: + msg = &DeleteSession{} + case MsgDeleteSessionReply: + msg = &DeleteSessionReply{} case MsgError: msg = &Error{} default: diff --git a/watchtower/wtwire/wtwire_test.go b/watchtower/wtwire/wtwire_test.go index 1dfef1a26..e9b37a559 100644 --- a/watchtower/wtwire/wtwire_test.go +++ b/watchtower/wtwire/wtwire_test.go @@ -126,6 +126,18 @@ func TestWatchtowerWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: wtwire.MsgDeleteSession, + scenario: func(m wtwire.DeleteSession) bool { + return mainScenario(&m) + }, + }, + { + msgType: wtwire.MsgDeleteSessionReply, + scenario: func(m wtwire.DeleteSessionReply) bool { + return mainScenario(&m) + }, + }, { msgType: wtwire.MsgError, scenario: func(m wtwire.Error) bool {