Merge pull request #2779 from cfromknecht/wtserver-delete-session

watchtower/wtserver: add DeleteSession request
This commit is contained in:
Olaoluwa Osuntokun 2019-03-19 22:01:17 -07:00 committed by GitHub
commit 9143067014
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 725 additions and 329 deletions

View file

@ -70,6 +70,33 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
return nil return nil
} }
func (db *MockDB) DeleteSession(target SessionID) error {
db.mu.Lock()
defer db.mu.Unlock()
// Fail if the session doesn't exit.
if _, ok := db.sessions[target]; !ok {
return ErrSessionNotFound
}
// Remove the target session.
delete(db.sessions, target)
// Remove the state updates for any blobs stored under the target
// session identifier.
for hint, sessionUpdates := range db.blobs {
delete(sessionUpdates, target)
//If this was the last state update, we can also remove the hint
//that would map to an empty set.
if len(sessionUpdates) == 0 {
delete(db.blobs, hint)
}
}
return nil
}
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()

View file

@ -0,0 +1,136 @@
package wtserver
import (
"github.com/btcsuite/btcd/txscript"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// handleCreateSession processes a CreateSession message from the peer, and returns
// a CreateSessionReply in response. This method will only succeed if no existing
// session info is known about the session id. If an existing session is found,
// the reward address is returned in case the client lost our reply.
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
req *wtwire.CreateSession) error {
// TODO(conner): validate accept against policy
// Query the db for session info belonging to the client's session id.
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
switch {
// We already have a session corresponding to this session id, return an
// error signaling that it already exists in our database. We return the
// reward address to the client in case they were not able to process
// our reply earlier.
case err == nil:
log.Debugf("Already have session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeAlreadyExists,
existingInfo.RewardAddress,
)
// Some other database error occurred, return a temporary failure.
case err != wtdb.ErrSessionNotFound:
log.Errorf("unable to load session info for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Now that we've established that this session does not exist in the
// database, retrieve the sweep address that will be given to the
// client. This address is to be included by the client when signing
// sweep transactions destined for this tower, if its negotiated output
// is not dust.
rewardAddress, err := s.cfg.NewAddress()
if err != nil {
log.Errorf("unable to generate reward addr for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Construct the pkscript the client should pay to when signing justice
// transactions for this session.
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
if err != nil {
log.Errorf("unable to generate reward script for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Ensure that the requested blob type is supported by our tower.
if !blob.IsSupportedType(req.BlobType) {
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
"type %s", id, req.BlobType)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
)
}
// TODO(conner): create invoice for upfront payment
// Assemble the session info using the agreed upon parameters, reward
// address, and session id.
info := wtdb.SessionInfo{
ID: *id,
Policy: wtpolicy.Policy{
BlobType: req.BlobType,
MaxUpdates: req.MaxUpdates,
RewardBase: req.RewardBase,
RewardRate: req.RewardRate,
SweepFeeRate: req.SweepFeeRate,
},
RewardAddress: rewardScript,
}
// Insert the session info into the watchtower's database. If
// successful, the session will now be ready for use.
err = s.cfg.DB.InsertSessionInfo(&info)
if err != nil {
log.Errorf("unable to create session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
log.Infof("Accepted session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeOK, rewardScript,
)
}
// replyCreateSession sends a response to a CreateSession from a client. If the
// status code in the reply is OK, the error from the write will be bubbled up.
// Otherwise, this method returns a connection error to ensure we don't continue
// communication with the client.
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
code wtwire.ErrorCode, data []byte) error {
msg := &wtwire.CreateSessionReply{
Code: code,
Data: data,
}
err := s.sendMessage(peer, msg)
if err != nil {
log.Errorf("unable to send CreateSessionReply to %s", id)
}
// Return the write error if the request succeeded.
if code == wtwire.CodeOK {
return err
}
// Otherwise the request failed, return a connection failure to
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
}
}

View file

@ -0,0 +1,57 @@
package wtserver
import (
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// handleDeleteSession processes a DeleteSession request for a client with given
// SessionID. The id is assumed to have been previously authenticated by the
// brontide connection.
func (s *Server) handleDeleteSession(peer Peer, id *wtdb.SessionID) error {
var failCode wtwire.DeleteSessionCode
// Delete all session data associated with id.
err := s.cfg.DB.DeleteSession(*id)
switch {
case err == nil:
failCode = wtwire.CodeOK
log.Debugf("Session %s deleted", id)
case err == wtdb.ErrSessionNotFound:
failCode = wtwire.DeleteSessionCodeNotFound
default:
failCode = wtwire.CodeTemporaryFailure
}
return s.replyDeleteSession(peer, id, failCode)
}
// replyDeleteSession sends a DeleteSessionReply back to the peer containing the
// error code resulting from processes a DeleteSession request.
func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID,
code wtwire.DeleteSessionCode) error {
msg := &wtwire.DeleteSessionReply{
Code: code,
}
err := s.sendMessage(peer, msg)
if err != nil {
log.Errorf("Unable to send DeleteSessionReply to %s", id)
}
// Return the write error if the request succeeded.
if code == wtwire.CodeOK {
return err
}
// Otherwise the request failed, return a connection failure to
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
}
}

View file

@ -63,4 +63,8 @@ type DB interface {
// validates the update against the current SessionInfo stored under the // validates the update against the current SessionInfo stored under the
// update's session id.. // update's session id..
InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error) InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error)
// DeleteSession removes all data associated with a particular session
// id from the tower's database.
DeleteSession(wtdb.SessionID) error
} }

View file

@ -11,12 +11,9 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtwire" "github.com/lightningnetwork/lnd/watchtower/wtwire"
) )
@ -24,6 +21,10 @@ var (
// ErrPeerAlreadyConnected signals that a peer with the same session id // ErrPeerAlreadyConnected signals that a peer with the same session id
// is already active within the server. // is already active within the server.
ErrPeerAlreadyConnected = errors.New("peer already connected") ErrPeerAlreadyConnected = errors.New("peer already connected")
// ErrServerExiting signals that a request could not be processed
// because the server has been requested to shut down.
ErrServerExiting = errors.New("server shutting down")
) )
// Config abstracts the primary components and dependencies of the server. // Config abstracts the primary components and dependencies of the server.
@ -242,242 +243,41 @@ func (s *Server) handleClient(peer Peer) {
return return
} }
// stateUpdateOnlyMode will become true if the client's first message is
// a StateUpdate. If instead, it is a CreateSession, this method will exit
// immediately after replying. We track this to ensure that the client
// can't send a CreateSession after having already sent a StateUpdate.
var stateUpdateOnlyMode bool
for {
select {
case <-s.quit:
return
default:
}
nextMsg, err := s.readMessage(peer) nextMsg, err := s.readMessage(peer)
if err != nil { if err != nil {
log.Errorf("Unable to read watchtower msg from %x: %v", log.Errorf("Unable to read watchtower msg from %s: %v",
id[:], err) id, err)
return return
} }
// Process the request according to the message's type.
switch msg := nextMsg.(type) { switch msg := nextMsg.(type) {
// A CreateSession indicates a request to establish a new session
// with our watchtower.
case *wtwire.CreateSession: case *wtwire.CreateSession:
// Ensure CreateSession can only be sent as the first
// message.
if stateUpdateOnlyMode {
log.Errorf("client %x sent CreateSession after "+
"StateUpdate", id)
return
}
// Attempt to open a new session for this client. // Attempt to open a new session for this client.
err := s.handleCreateSession(peer, &id, msg) err = s.handleCreateSession(peer, &id, msg)
if err != nil { if err != nil {
log.Errorf("Unable to handle CreateSession "+ log.Errorf("Unable to handle CreateSession "+
"from %s: %v", id, err) "from %s: %v", id, err)
} }
// Exit after replying to CreateSession. case *wtwire.DeleteSession:
return err = s.handleDeleteSession(peer, &id)
// A StateUpdate indicates an existing client attempting to
// back-up a revoked commitment state.
case *wtwire.StateUpdate:
// Try to accept the state update from the client.
err := s.handleStateUpdate(peer, &id, msg)
if err != nil { if err != nil {
log.Errorf("unable to handle StateUpdate "+ log.Errorf("Unable to handle DeleteSession "+
"from %s: %v", id, err) "from %s: %v", id, err)
return
} }
// If the client signals that this is last StateUpdate case *wtwire.StateUpdate:
// message, we can disconnect the client. err = s.handleStateUpdates(peer, &id, msg)
if msg.IsComplete == 1 { if err != nil {
return log.Errorf("Unable to handle StateUpdate "+
"from %s: %v", id, err)
} }
// The client has signaled that more StateUpdates are
// yet to come. Enter state-update-only mode to disallow
// future sends of CreateSession messages.
stateUpdateOnlyMode = true
default: default:
log.Errorf("received unsupported message type: %T "+ log.Errorf("Received unsupported message type: %T "+
"from %s", nextMsg, id) "from %s", nextMsg, id)
return
} }
} }
}
// handleCreateSession processes a CreateSession message from the peer, and returns
// a CreateSessionReply in response. This method will only succeed if no existing
// session info is known about the session id. If an existing session is found,
// the reward address is returned in case the client lost our reply.
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
req *wtwire.CreateSession) error {
// TODO(conner): validate accept against policy
// Query the db for session info belonging to the client's session id.
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
switch {
// We already have a session corresponding to this session id, return an
// error signaling that it already exists in our database. We return the
// reward address to the client in case they were not able to process
// our reply earlier.
case err == nil:
log.Debugf("Already have session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeAlreadyExists,
existingInfo.RewardAddress,
)
// Some other database error occurred, return a temporary failure.
case err != wtdb.ErrSessionNotFound:
log.Errorf("unable to load session info for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Now that we've established that this session does not exist in the
// database, retrieve the sweep address that will be given to the
// client. This address is to be included by the client when signing
// sweep transactions destined for this tower, if its negotiated output
// is not dust.
rewardAddress, err := s.cfg.NewAddress()
if err != nil {
log.Errorf("unable to generate reward addr for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Construct the pkscript the client should pay to when signing justice
// transactions for this session.
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
if err != nil {
log.Errorf("unable to generate reward script for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
// Ensure that the requested blob type is supported by our tower.
if !blob.IsSupportedType(req.BlobType) {
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
"type %s", id, req.BlobType)
return s.replyCreateSession(
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
)
}
// TODO(conner): create invoice for upfront payment
// Assemble the session info using the agreed upon parameters, reward
// address, and session id.
info := wtdb.SessionInfo{
ID: *id,
Policy: wtpolicy.Policy{
BlobType: req.BlobType,
MaxUpdates: req.MaxUpdates,
RewardBase: req.RewardBase,
RewardRate: req.RewardRate,
SweepFeeRate: req.SweepFeeRate,
},
RewardAddress: rewardScript,
}
// Insert the session info into the watchtower's database. If
// successful, the session will now be ready for use.
err = s.cfg.DB.InsertSessionInfo(&info)
if err != nil {
log.Errorf("unable to create session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeTemporaryFailure, nil,
)
}
log.Infof("Accepted session for %s", id)
return s.replyCreateSession(
peer, id, wtwire.CodeOK, rewardScript,
)
}
// handleStateUpdate processes a StateUpdate message request from a client. An
// attempt will be made to insert the update into the db, where it is validated
// against the client's session. The possible errors are then mapped back to
// StateUpdateCodes specified by the watchtower wire protocol, and sent back
// using a StateUpdateReply message.
func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
update *wtwire.StateUpdate) error {
var (
lastApplied uint16
failCode wtwire.ErrorCode
err error
)
sessionUpdate := wtdb.SessionStateUpdate{
ID: *id,
Hint: update.Hint,
SeqNum: update.SeqNum,
LastApplied: update.LastApplied,
EncryptedBlob: update.EncryptedBlob,
}
lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate)
switch {
case err == nil:
log.Debugf("State update %d accepted for %s",
update.SeqNum, id)
failCode = wtwire.CodeOK
// Return a permanent failure if a client tries to send an update for
// which we have no session.
case err == wtdb.ErrSessionNotFound:
failCode = wtwire.CodePermanentFailure
case err == wtdb.ErrSeqNumAlreadyApplied:
failCode = wtwire.CodePermanentFailure
// TODO(conner): remove session state for protocol
// violation. Could also double as clean up method for
// session-related state.
case err == wtdb.ErrLastAppliedReversion:
failCode = wtwire.StateUpdateCodeClientBehind
case err == wtdb.ErrSessionConsumed:
failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded
case err == wtdb.ErrUpdateOutOfOrder:
failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder
default:
failCode = wtwire.CodeTemporaryFailure
}
if s.cfg.NoAckUpdates {
return &connFailure{
ID: *id,
Code: uint16(failCode),
}
}
return s.replyStateUpdate(
peer, id, failCode, lastApplied,
)
}
// connFailure is a default error used when a request failed with a non-zero // connFailure is a default error used when a request failed with a non-zero
// error code. // error code.
@ -493,66 +293,6 @@ func (f *connFailure) Error() string {
) )
} }
// replyCreateSession sends a response to a CreateSession from a client. If the
// status code in the reply is OK, the error from the write will be bubbled up.
// Otherwise, this method returns a connection error to ensure we don't continue
// communication with the client.
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
code wtwire.ErrorCode, data []byte) error {
msg := &wtwire.CreateSessionReply{
Code: code,
Data: data,
}
err := s.sendMessage(peer, msg)
if err != nil {
log.Errorf("unable to send CreateSessionReply to %s", id)
}
// Return the write error if the request succeeded.
if code == wtwire.CodeOK {
return err
}
// Otherwise the request failed, return a connection failure to
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
}
}
// replyStateUpdate sends a response to a StateUpdate from a client. If the
// status code in the reply is OK, the error from the write will be bubbled up.
// Otherwise, this method returns a connection error to ensure we don't continue
// communication with the client.
func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
code wtwire.StateUpdateCode, lastApplied uint16) error {
msg := &wtwire.StateUpdateReply{
Code: code,
LastApplied: lastApplied,
}
err := s.sendMessage(peer, msg)
if err != nil {
log.Errorf("unable to send StateUpdateReply to %s", id)
}
// Return the write error if the request succeeded.
if code == wtwire.CodeOK {
return err
}
// Otherwise the request failed, return a connection failure to
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
}
}
// readMessage receives and parses the next message from the given Peer. An // readMessage receives and parses the next message from the given Peer. An
// error is returned if a message is not received before the server's read // error is returned if a message is not received before the server's read
// timeout, the read off the wire failed, or the message could not be // timeout, the read off the wire failed, or the message could not be

View file

@ -226,13 +226,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
// Create a new client and connect to server. // Create a new client and connect to server.
peerPub := randPubKey(t) peerPub := randPubKey(t)
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
// Send the CreateSession message, and wait for a reply. // Send the CreateSession message, and wait for a reply.
sendMsg(t, i, test.createMsg, peer, timeoutDuration) sendMsg(t, test.createMsg, peer, timeoutDuration)
reply := recvReply( reply := recvReply(
t, i, "CreateSessionReply", peer, timeoutDuration, t, "MsgCreateSessionReply", peer, timeoutDuration,
).(*wtwire.CreateSessionReply) ).(*wtwire.CreateSessionReply)
// Verify that the server's response matches our expectation. // Verify that the server's response matches our expectation.
@ -254,13 +254,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
// Simulate a peer with the same session id connection to the server // Simulate a peer with the same session id connection to the server
// again. // again.
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
// Send the _same_ CreateSession message as the first attempt. // Send the _same_ CreateSession message as the first attempt.
sendMsg(t, i, test.createMsg, peer, timeoutDuration) sendMsg(t, test.createMsg, peer, timeoutDuration)
reply = recvReply( reply = recvReply(
t, i, "CreateSessionReply", peer, timeoutDuration, t, "MsgCreateSessionReply", peer, timeoutDuration,
).(*wtwire.CreateSessionReply) ).(*wtwire.CreateSessionReply)
// Ensure that the server's reply matches our expected response for a // Ensure that the server's reply matches our expected response for a
@ -550,14 +550,14 @@ var stateUpdateTests = []stateUpdateTestCase{
func TestServerStateUpdates(t *testing.T) { func TestServerStateUpdates(t *testing.T) {
t.Parallel() t.Parallel()
for i, test := range stateUpdateTests { for _, test := range stateUpdateTests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
testServerStateUpdates(t, i, test) testServerStateUpdates(t, test)
}) })
} }
} }
func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) { func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
const timeoutDuration = 100 * time.Millisecond const timeoutDuration = 100 * time.Millisecond
s := initServer(t, nil, timeoutDuration) s := initServer(t, nil, timeoutDuration)
@ -568,17 +568,17 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
// Create a new client and connect to the server. // Create a new client and connect to the server.
peerPub := randPubKey(t) peerPub := randPubKey(t)
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
// Register a session for this client to use in the subsequent tests. // Register a session for this client to use in the subsequent tests.
sendMsg(t, i, test.createMsg, peer, timeoutDuration) sendMsg(t, test.createMsg, peer, timeoutDuration)
initReply := recvReply( initReply := recvReply(
t, i, "CreateSessionReply", peer, timeoutDuration, t, "MsgCreateSessionReply", peer, timeoutDuration,
).(*wtwire.CreateSessionReply) ).(*wtwire.CreateSessionReply)
// Fail if the server rejected our proposed CreateSession message. // Fail if the server rejected our proposed CreateSession message.
if initReply.Code != wtwire.CodeOK { if initReply.Code != wtwire.CodeOK {
t.Fatalf("[test %d] server rejected session init", i) t.Fatalf("server rejected session init")
} }
// Check that the server closed the connection used to register the // Check that the server closed the connection used to register the
@ -588,7 +588,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
// Now that the original connection has been closed, connect a new // Now that the original connection has been closed, connect a new
// client with the same session id. // client with the same session id.
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
// Send the intended StateUpdate messages in series. // Send the intended StateUpdate messages in series.
for j, update := range test.updates { for j, update := range test.updates {
@ -599,21 +599,21 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
assertConnClosed(t, peer, 2*timeoutDuration) assertConnClosed(t, peer, 2*timeoutDuration)
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
connect(t, i, s, peer, test.initMsg, timeoutDuration) connect(t, s, peer, test.initMsg, timeoutDuration)
continue continue
} }
// Send the state update and verify it against our expected // Send the state update and verify it against our expected
// response. // response.
sendMsg(t, i, update, peer, timeoutDuration) sendMsg(t, update, peer, timeoutDuration)
reply := recvReply( reply := recvReply(
t, i, "StateUpdateReply", peer, timeoutDuration, t, "MsgStateUpdateReply", peer, timeoutDuration,
).(*wtwire.StateUpdateReply) ).(*wtwire.StateUpdateReply)
if !reflect.DeepEqual(reply, test.replies[j]) { if !reflect.DeepEqual(reply, test.replies[j]) {
t.Fatalf("[test %d, update %d] expected reply "+ t.Fatalf("[update %d] expected reply "+
"%v, got %d", i, j, "%v, got %d", j,
test.replies[j], reply) test.replies[j], reply)
} }
} }
@ -622,16 +622,148 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
assertConnClosed(t, peer, 2*timeoutDuration) assertConnClosed(t, peer, 2*timeoutDuration)
} }
func connect(t *testing.T, i int, s wtserver.Interface, peer *wtmock.MockPeer, // TestServerDeleteSession asserts the response to a DeleteSession request, and
// checking that the proper error is returned when the session doesn't exist and
// that a successful deletion does not disrupt other sessions.
func TestServerDeleteSession(t *testing.T) {
db := wtdb.NewMockDB()
localPub := randPubKey(t)
// Initialize two distinct peers with different session ids.
peerPub1 := randPubKey(t)
peerPub2 := randPubKey(t)
id1 := wtdb.NewSessionIDFromPubKey(peerPub1)
id2 := wtdb.NewSessionIDFromPubKey(peerPub2)
// Create closure to simplify assertions on session existence with the
// server's database.
hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) {
t.Helper()
_, err := db.GetSessionInfo(id)
switch {
case shouldHave && err != nil:
t.Fatalf("expected server to have session %s, got: %v",
id, err)
case !shouldHave && err != wtdb.ErrSessionNotFound:
t.Fatalf("expected ErrSessionNotFound for session %s, "+
"got: %v", id, err)
}
}
initMsg := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
)
createSession := &wtwire.CreateSession{
BlobType: blob.TypeDefault,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
}
const timeoutDuration = 100 * time.Millisecond
s := initServer(t, db, timeoutDuration)
defer s.Stop()
// Create a session for peer2 so that the server's db isn't completely
// empty.
peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0)
connect(t, s, peer2, initMsg, timeoutDuration)
sendMsg(t, createSession, peer2, timeoutDuration)
assertConnClosed(t, peer2, 2*timeoutDuration)
// Our initial assertions are that peer2 has a valid session, but peer1
// has not created one.
hasSession(t, &id1, false)
hasSession(t, &id2, true)
peer1Msgs := []struct {
send wtwire.Message
recv wtwire.Message
assert func(t *testing.T)
}{
{
// Deleting unknown session should fail.
send: &wtwire.DeleteSession{},
recv: &wtwire.DeleteSessionReply{
Code: wtwire.DeleteSessionCodeNotFound,
},
assert: func(t *testing.T) {
// Peer2 should still be only session.
hasSession(t, &id1, false)
hasSession(t, &id2, true)
},
},
{
// Create session for peer1.
send: createSession,
recv: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: addrScript,
},
assert: func(t *testing.T) {
// Both peers should have sessions.
hasSession(t, &id1, true)
hasSession(t, &id2, true)
},
},
{
// Delete peer1's session.
send: &wtwire.DeleteSession{},
recv: &wtwire.DeleteSessionReply{
Code: wtwire.CodeOK,
},
assert: func(t *testing.T) {
// Peer1's session should have been removed.
hasSession(t, &id1, false)
hasSession(t, &id2, true)
},
},
}
// Now as peer1, process the canned messages defined above. This will:
// 1. Try to delete an unknown session and get a not found error code.
// 2. Create a new session using the same parameters as peer2.
// 3. Delete the newly created session and get an OK.
for _, msg := range peer1Msgs {
peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0)
connect(t, s, peer1, initMsg, timeoutDuration)
sendMsg(t, msg.send, peer1, timeoutDuration)
reply := recvReply(
t, msg.recv.MsgType().String(), peer1, timeoutDuration,
)
if !reflect.DeepEqual(reply, msg.recv) {
t.Fatalf("expected reply: %v, got: %v", msg.recv, reply)
}
assertConnClosed(t, peer1, 2*timeoutDuration)
// Invoke assertions after completing the request/response
// dance.
msg.assert(t)
}
}
func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer,
initMsg *wtwire.Init, timeout time.Duration) { initMsg *wtwire.Init, timeout time.Duration) {
t.Helper()
s.InboundPeerConnected(peer) s.InboundPeerConnected(peer)
sendMsg(t, i, initMsg, peer, timeout) sendMsg(t, initMsg, peer, timeout)
recvReply(t, i, "Init", peer, timeout) recvReply(t, "MsgInit", peer, timeout)
} }
// sendMsg sends a wtwire.Message message via a wtmock.MockPeer. // sendMsg sends a wtwire.Message message via a wtmock.MockPeer.
func sendMsg(t *testing.T, i int, msg wtwire.Message, func sendMsg(t *testing.T, msg wtwire.Message,
peer *wtmock.MockPeer, timeout time.Duration) { peer *wtmock.MockPeer, timeout time.Duration) {
t.Helper() t.Helper()
@ -639,22 +771,22 @@ func sendMsg(t *testing.T, i int, msg wtwire.Message,
var b bytes.Buffer var b bytes.Buffer
_, err := wtwire.WriteMessage(&b, msg, 0) _, err := wtwire.WriteMessage(&b, msg, 0)
if err != nil { if err != nil {
t.Fatalf("[test %d] unable to encode %T message: %v", t.Fatalf("unable to encode %T message: %v",
i, msg, err) msg, err)
} }
select { select {
case peer.IncomingMsgs <- b.Bytes(): case peer.IncomingMsgs <- b.Bytes():
case <-time.After(2 * timeout): case <-time.After(2 * timeout):
t.Fatalf("[test %d] unable to send %T message", i, msg) t.Fatalf("unable to send %T message", msg)
} }
} }
// recvReply receives a message from the server, and parses it according to // recvReply receives a message from the server, and parses it according to
// expected reply type. The supported replies are CreateSessionReply and // expected reply type. The supported replies are CreateSessionReply and
// StateUpdateReply. // StateUpdateReply.
func recvReply(t *testing.T, i int, name string, func recvReply(t *testing.T, name string, peer *wtmock.MockPeer,
peer *wtmock.MockPeer, timeout time.Duration) wtwire.Message { timeout time.Duration) wtwire.Message {
t.Helper() t.Helper()
@ -667,29 +799,34 @@ func recvReply(t *testing.T, i int, name string,
case b := <-peer.OutgoingMsgs: case b := <-peer.OutgoingMsgs:
msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0) msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
if err != nil { if err != nil {
t.Fatalf("[test %d] unable to decode server "+ t.Fatalf("unable to decode server "+
"reply: %v", i, err) "reply: %v", err)
} }
case <-time.After(2 * timeout): case <-time.After(2 * timeout):
t.Fatalf("[test %d] server did not reply", i) t.Fatalf("server did not reply")
} }
switch name { switch name {
case "Init": case "MsgInit":
if _, ok := msg.(*wtwire.Init); !ok { if _, ok := msg.(*wtwire.Init); !ok {
t.Fatalf("[test %d] expected %s reply "+ t.Fatalf("expected %s reply message, "+
"message, got %T", i, name, msg) "got %T", name, msg)
} }
case "CreateSessionReply": case "MsgCreateSessionReply":
if _, ok := msg.(*wtwire.CreateSessionReply); !ok { if _, ok := msg.(*wtwire.CreateSessionReply); !ok {
t.Fatalf("[test %d] expected %s reply "+ t.Fatalf("expected %s reply message, "+
"message, got %T", i, name, msg) "got %T", name, msg)
} }
case "StateUpdateReply": case "MsgStateUpdateReply":
if _, ok := msg.(*wtwire.StateUpdateReply); !ok { if _, ok := msg.(*wtwire.StateUpdateReply); !ok {
t.Fatalf("[test %d] expected %s reply "+ t.Fatalf("expected %s reply message, "+
"message, got %T", i, name, msg) "got %T", name, msg)
}
case "MsgDeleteSessionReply":
if _, ok := msg.(*wtwire.DeleteSessionReply); !ok {
t.Fatalf("expected %s reply message, "+
"got %T", name, msg)
} }
} }

View file

@ -0,0 +1,157 @@
package wtserver
import (
"fmt"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// handleStateUpdates processes a stream of StateUpdate requests from the
// client. The provided update should be the first such update read, subsequent
// updates will be consumed if the peer does not signal IsComplete on a
// particular update.
func (s *Server) handleStateUpdates(peer Peer, id *wtdb.SessionID,
update *wtwire.StateUpdate) error {
// Set the current update to the first update read off the wire.
// Additional updates will be read if this value is set to nil after
// processing the first.
var curUpdate = update
for {
// If this is not the first update, read the next state update
// from the peer.
if curUpdate == nil {
nextMsg, err := s.readMessage(peer)
if err != nil {
return err
}
var ok bool
curUpdate, ok = nextMsg.(*wtwire.StateUpdate)
if !ok {
return fmt.Errorf("client sent %T after "+
"StateUpdate", nextMsg)
}
}
// Try to accept the state update from the client.
err := s.handleStateUpdate(peer, id, curUpdate)
if err != nil {
return err
}
// If the client signals that this is last StateUpdate
// message, we can disconnect the client.
if curUpdate.IsComplete == 1 {
return nil
}
// Reset the current update to read subsequent updates in the
// stream.
curUpdate = nil
select {
case <-s.quit:
return ErrServerExiting
default:
}
}
}
// handleStateUpdate processes a StateUpdate message request from a client. An
// attempt will be made to insert the update into the db, where it is validated
// against the client's session. The possible errors are then mapped back to
// StateUpdateCodes specified by the watchtower wire protocol, and sent back
// using a StateUpdateReply message.
func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
update *wtwire.StateUpdate) error {
var (
lastApplied uint16
failCode wtwire.ErrorCode
err error
)
sessionUpdate := wtdb.SessionStateUpdate{
ID: *id,
Hint: update.Hint,
SeqNum: update.SeqNum,
LastApplied: update.LastApplied,
EncryptedBlob: update.EncryptedBlob,
}
lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate)
switch {
case err == nil:
log.Debugf("State update %d accepted for %s",
update.SeqNum, id)
failCode = wtwire.CodeOK
// Return a permanent failure if a client tries to send an update for
// which we have no session.
case err == wtdb.ErrSessionNotFound:
failCode = wtwire.CodePermanentFailure
case err == wtdb.ErrSeqNumAlreadyApplied:
failCode = wtwire.CodePermanentFailure
// TODO(conner): remove session state for protocol
// violation. Could also double as clean up method for
// session-related state.
case err == wtdb.ErrLastAppliedReversion:
failCode = wtwire.StateUpdateCodeClientBehind
case err == wtdb.ErrSessionConsumed:
failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded
case err == wtdb.ErrUpdateOutOfOrder:
failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder
default:
failCode = wtwire.CodeTemporaryFailure
}
if s.cfg.NoAckUpdates {
return &connFailure{
ID: *id,
Code: uint16(failCode),
}
}
return s.replyStateUpdate(
peer, id, failCode, lastApplied,
)
}
// replyStateUpdate sends a response to a StateUpdate from a client. If the
// status code in the reply is OK, the error from the write will be bubbled up.
// Otherwise, this method returns a connection error to ensure we don't continue
// communication with the client.
func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
code wtwire.StateUpdateCode, lastApplied uint16) error {
msg := &wtwire.StateUpdateReply{
Code: code,
LastApplied: lastApplied,
}
err := s.sendMessage(peer, msg)
if err != nil {
log.Errorf("unable to send StateUpdateReply to %s", id)
}
// Return the write error if the request succeeded.
if code == wtwire.CodeOK {
return err
}
// Otherwise the request failed, return a connection failure to
// disconnect the client.
return &connFailure{
ID: *id,
Code: uint16(code),
}
}

View file

@ -0,0 +1,45 @@
package wtwire
import "io"
// DeleteSession is sent from the client to the tower to signal that the tower
// can delete all session state for the session key used to authenticate the
// brontide connection. This should be done by the client once all channels that
// have state updates in the session have been resolved on-chain.
type DeleteSession struct{}
// Compile-time constraint to ensure DeleteSession implements the wtwire.Message
// interface.
var _ Message = (*DeleteSession)(nil)
// Decode deserializes a serialized DeleteSession message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSession) Decode(r io.Reader, pver uint32) error {
return nil
}
// Encode serializes the target DeleteSession message into the passed io.Writer
// observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSession) Encode(w io.Writer, pver uint32) error {
return nil
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSession) MsgType() MessageType {
return MsgDeleteSession
}
// MaxPayloadLength returns the maximum allowed payload size for a DeleteSession
// message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSession) MaxPayloadLength(uint32) uint32 {
return 0
}

View file

@ -0,0 +1,66 @@
package wtwire
import "io"
// DeleteSessionCode is an error code returned by a watchtower in response to a
// DeleteSession message.
type DeleteSessionCode = ErrorCode
const (
// DeleteSessionCodeNotFound is returned when the watchtower does not
// know of the requested session. This may indicate an error on the
// client side, or that the tower had already deleted the session in a
// prior request that the client may not have received.
DeleteSessionCodeNotFound DeleteSessionCode = 80
// TODO(conner): add String method after wtclient is merged
)
// DeleteSessionReply is a message sent in response to a client's DeleteSession
// request. The message indicates whether or not the deletion was a success or
// failure.
type DeleteSessionReply struct {
// Code will be non-zero if the watchtower was not able to delete the
// requested session.
Code DeleteSessionCode
}
// A compile time check to ensure DeleteSessionReply implements the
// wtwire.Message interface.
var _ Message = (*DeleteSessionReply)(nil)
// Decode deserializes a serialized DeleteSessionReply message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSessionReply) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&m.Code,
)
}
// Encode serializes the target DeleteSessionReply into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSessionReply) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
m.Code,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSessionReply) MsgType() MessageType {
return MsgDeleteSessionReply
}
// MaxPayloadLength returns the maximum allowed payload size for a
// DeleteSessionReply complete message observing the specified protocol version.
//
// This is part of the wtwire.Message interface.
func (m *DeleteSessionReply) MaxPayloadLength(uint32) uint32 {
return 2
}

View file

@ -40,6 +40,13 @@ const (
// MsgStateUpdateReply identifies an encoded StateUpdateReply message. // MsgStateUpdateReply identifies an encoded StateUpdateReply message.
MsgStateUpdateReply MessageType = 305 MsgStateUpdateReply MessageType = 305
// MsgDeleteSession identifies an encoded DeleteSession message.
MsgDeleteSession MessageType = 306
// MsgDeleteSessionReply identifies an encoded DeleteSessionReply
// message.
MsgDeleteSessionReply MessageType = 307
) )
// String returns a human readable description of the message type. // String returns a human readable description of the message type.
@ -55,6 +62,10 @@ func (m MessageType) String() string {
return "MsgStateUpdate" return "MsgStateUpdate"
case MsgStateUpdateReply: case MsgStateUpdateReply:
return "MsgStateUpdateReply" return "MsgStateUpdateReply"
case MsgDeleteSession:
return "MsgDeleteSession"
case MsgDeleteSessionReply:
return "MsgDeleteSessionReply"
case MsgError: case MsgError:
return "Error" return "Error"
default: default:
@ -97,6 +108,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
msg = &StateUpdate{} msg = &StateUpdate{}
case MsgStateUpdateReply: case MsgStateUpdateReply:
msg = &StateUpdateReply{} msg = &StateUpdateReply{}
case MsgDeleteSession:
msg = &DeleteSession{}
case MsgDeleteSessionReply:
msg = &DeleteSessionReply{}
case MsgError: case MsgError:
msg = &Error{} msg = &Error{}
default: default:

View file

@ -126,6 +126,18 @@ func TestWatchtowerWireProtocol(t *testing.T) {
return mainScenario(&m) return mainScenario(&m)
}, },
}, },
{
msgType: wtwire.MsgDeleteSession,
scenario: func(m wtwire.DeleteSession) bool {
return mainScenario(&m)
},
},
{
msgType: wtwire.MsgDeleteSessionReply,
scenario: func(m wtwire.DeleteSessionReply) bool {
return mainScenario(&m)
},
},
{ {
msgType: wtwire.MsgError, msgType: wtwire.MsgError,
scenario: func(m wtwire.Error) bool { scenario: func(m wtwire.Error) bool {