mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-23 14:40:30 +01:00
Merge pull request #2779 from cfromknecht/wtserver-delete-session
watchtower/wtserver: add DeleteSession request
This commit is contained in:
commit
9143067014
11 changed files with 725 additions and 329 deletions
|
@ -70,6 +70,33 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *MockDB) DeleteSession(target SessionID) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
// Fail if the session doesn't exit.
|
||||||
|
if _, ok := db.sessions[target]; !ok {
|
||||||
|
return ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the target session.
|
||||||
|
delete(db.sessions, target)
|
||||||
|
|
||||||
|
// Remove the state updates for any blobs stored under the target
|
||||||
|
// session identifier.
|
||||||
|
for hint, sessionUpdates := range db.blobs {
|
||||||
|
delete(sessionUpdates, target)
|
||||||
|
|
||||||
|
//If this was the last state update, we can also remove the hint
|
||||||
|
//that would map to an empty set.
|
||||||
|
if len(sessionUpdates) == 0 {
|
||||||
|
delete(db.blobs, hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
|
||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
defer db.mu.Unlock()
|
defer db.mu.Unlock()
|
||||||
|
|
136
watchtower/wtserver/create_session.go
Normal file
136
watchtower/wtserver/create_session.go
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
package wtserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/btcsuite/btcd/txscript"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/blob"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleCreateSession processes a CreateSession message from the peer, and returns
|
||||||
|
// a CreateSessionReply in response. This method will only succeed if no existing
|
||||||
|
// session info is known about the session id. If an existing session is found,
|
||||||
|
// the reward address is returned in case the client lost our reply.
|
||||||
|
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
||||||
|
req *wtwire.CreateSession) error {
|
||||||
|
|
||||||
|
// TODO(conner): validate accept against policy
|
||||||
|
|
||||||
|
// Query the db for session info belonging to the client's session id.
|
||||||
|
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// We already have a session corresponding to this session id, return an
|
||||||
|
// error signaling that it already exists in our database. We return the
|
||||||
|
// reward address to the client in case they were not able to process
|
||||||
|
// our reply earlier.
|
||||||
|
case err == nil:
|
||||||
|
log.Debugf("Already have session for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
||||||
|
existingInfo.RewardAddress,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Some other database error occurred, return a temporary failure.
|
||||||
|
case err != wtdb.ErrSessionNotFound:
|
||||||
|
log.Errorf("unable to load session info for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we've established that this session does not exist in the
|
||||||
|
// database, retrieve the sweep address that will be given to the
|
||||||
|
// client. This address is to be included by the client when signing
|
||||||
|
// sweep transactions destined for this tower, if its negotiated output
|
||||||
|
// is not dust.
|
||||||
|
rewardAddress, err := s.cfg.NewAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to generate reward addr for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the pkscript the client should pay to when signing justice
|
||||||
|
// transactions for this session.
|
||||||
|
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to generate reward script for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the requested blob type is supported by our tower.
|
||||||
|
if !blob.IsSupportedType(req.BlobType) {
|
||||||
|
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
|
||||||
|
"type %s", id, req.BlobType)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(conner): create invoice for upfront payment
|
||||||
|
|
||||||
|
// Assemble the session info using the agreed upon parameters, reward
|
||||||
|
// address, and session id.
|
||||||
|
info := wtdb.SessionInfo{
|
||||||
|
ID: *id,
|
||||||
|
Policy: wtpolicy.Policy{
|
||||||
|
BlobType: req.BlobType,
|
||||||
|
MaxUpdates: req.MaxUpdates,
|
||||||
|
RewardBase: req.RewardBase,
|
||||||
|
RewardRate: req.RewardRate,
|
||||||
|
SweepFeeRate: req.SweepFeeRate,
|
||||||
|
},
|
||||||
|
RewardAddress: rewardScript,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the session info into the watchtower's database. If
|
||||||
|
// successful, the session will now be ready for use.
|
||||||
|
err = s.cfg.DB.InsertSessionInfo(&info)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to create session for %s", id)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Accepted session for %s", id)
|
||||||
|
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeOK, rewardScript,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// replyCreateSession sends a response to a CreateSession from a client. If the
|
||||||
|
// status code in the reply is OK, the error from the write will be bubbled up.
|
||||||
|
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||||
|
// communication with the client.
|
||||||
|
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
||||||
|
code wtwire.ErrorCode, data []byte) error {
|
||||||
|
|
||||||
|
msg := &wtwire.CreateSessionReply{
|
||||||
|
Code: code,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendMessage(peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to send CreateSessionReply to %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the write error if the request succeeded.
|
||||||
|
if code == wtwire.CodeOK {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the request failed, return a connection failure to
|
||||||
|
// disconnect the client.
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(code),
|
||||||
|
}
|
||||||
|
}
|
57
watchtower/wtserver/delete_session.go
Normal file
57
watchtower/wtserver/delete_session.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package wtserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleDeleteSession processes a DeleteSession request for a client with given
|
||||||
|
// SessionID. The id is assumed to have been previously authenticated by the
|
||||||
|
// brontide connection.
|
||||||
|
func (s *Server) handleDeleteSession(peer Peer, id *wtdb.SessionID) error {
|
||||||
|
var failCode wtwire.DeleteSessionCode
|
||||||
|
|
||||||
|
// Delete all session data associated with id.
|
||||||
|
err := s.cfg.DB.DeleteSession(*id)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
failCode = wtwire.CodeOK
|
||||||
|
|
||||||
|
log.Debugf("Session %s deleted", id)
|
||||||
|
|
||||||
|
case err == wtdb.ErrSessionNotFound:
|
||||||
|
failCode = wtwire.DeleteSessionCodeNotFound
|
||||||
|
|
||||||
|
default:
|
||||||
|
failCode = wtwire.CodeTemporaryFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.replyDeleteSession(peer, id, failCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// replyDeleteSession sends a DeleteSessionReply back to the peer containing the
|
||||||
|
// error code resulting from processes a DeleteSession request.
|
||||||
|
func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID,
|
||||||
|
code wtwire.DeleteSessionCode) error {
|
||||||
|
|
||||||
|
msg := &wtwire.DeleteSessionReply{
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendMessage(peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to send DeleteSessionReply to %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the write error if the request succeeded.
|
||||||
|
if code == wtwire.CodeOK {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the request failed, return a connection failure to
|
||||||
|
// disconnect the client.
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(code),
|
||||||
|
}
|
||||||
|
}
|
|
@ -63,4 +63,8 @@ type DB interface {
|
||||||
// validates the update against the current SessionInfo stored under the
|
// validates the update against the current SessionInfo stored under the
|
||||||
// update's session id..
|
// update's session id..
|
||||||
InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error)
|
InsertStateUpdate(*wtdb.SessionStateUpdate) (uint16, error)
|
||||||
|
|
||||||
|
// DeleteSession removes all data associated with a particular session
|
||||||
|
// id from the tower's database.
|
||||||
|
DeleteSession(wtdb.SessionID) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,9 @@ import (
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/connmgr"
|
"github.com/btcsuite/btcd/connmgr"
|
||||||
"github.com/btcsuite/btcd/txscript"
|
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/blob"
|
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,6 +21,10 @@ var (
|
||||||
// ErrPeerAlreadyConnected signals that a peer with the same session id
|
// ErrPeerAlreadyConnected signals that a peer with the same session id
|
||||||
// is already active within the server.
|
// is already active within the server.
|
||||||
ErrPeerAlreadyConnected = errors.New("peer already connected")
|
ErrPeerAlreadyConnected = errors.New("peer already connected")
|
||||||
|
|
||||||
|
// ErrServerExiting signals that a request could not be processed
|
||||||
|
// because the server has been requested to shut down.
|
||||||
|
ErrServerExiting = errors.New("server shutting down")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config abstracts the primary components and dependencies of the server.
|
// Config abstracts the primary components and dependencies of the server.
|
||||||
|
@ -242,242 +243,41 @@ func (s *Server) handleClient(peer Peer) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stateUpdateOnlyMode will become true if the client's first message is
|
|
||||||
// a StateUpdate. If instead, it is a CreateSession, this method will exit
|
|
||||||
// immediately after replying. We track this to ensure that the client
|
|
||||||
// can't send a CreateSession after having already sent a StateUpdate.
|
|
||||||
var stateUpdateOnlyMode bool
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-s.quit:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
nextMsg, err := s.readMessage(peer)
|
nextMsg, err := s.readMessage(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to read watchtower msg from %x: %v",
|
log.Errorf("Unable to read watchtower msg from %s: %v",
|
||||||
id[:], err)
|
id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the request according to the message's type.
|
|
||||||
switch msg := nextMsg.(type) {
|
switch msg := nextMsg.(type) {
|
||||||
|
|
||||||
// A CreateSession indicates a request to establish a new session
|
|
||||||
// with our watchtower.
|
|
||||||
case *wtwire.CreateSession:
|
case *wtwire.CreateSession:
|
||||||
// Ensure CreateSession can only be sent as the first
|
|
||||||
// message.
|
|
||||||
if stateUpdateOnlyMode {
|
|
||||||
log.Errorf("client %x sent CreateSession after "+
|
|
||||||
"StateUpdate", id)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to open a new session for this client.
|
// Attempt to open a new session for this client.
|
||||||
err := s.handleCreateSession(peer, &id, msg)
|
err = s.handleCreateSession(peer, &id, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to handle CreateSession "+
|
log.Errorf("Unable to handle CreateSession "+
|
||||||
"from %s: %v", id, err)
|
"from %s: %v", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exit after replying to CreateSession.
|
case *wtwire.DeleteSession:
|
||||||
return
|
err = s.handleDeleteSession(peer, &id)
|
||||||
|
|
||||||
// A StateUpdate indicates an existing client attempting to
|
|
||||||
// back-up a revoked commitment state.
|
|
||||||
case *wtwire.StateUpdate:
|
|
||||||
// Try to accept the state update from the client.
|
|
||||||
err := s.handleStateUpdate(peer, &id, msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to handle StateUpdate "+
|
log.Errorf("Unable to handle DeleteSession "+
|
||||||
"from %s: %v", id, err)
|
"from %s: %v", id, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the client signals that this is last StateUpdate
|
case *wtwire.StateUpdate:
|
||||||
// message, we can disconnect the client.
|
err = s.handleStateUpdates(peer, &id, msg)
|
||||||
if msg.IsComplete == 1 {
|
if err != nil {
|
||||||
return
|
log.Errorf("Unable to handle StateUpdate "+
|
||||||
|
"from %s: %v", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The client has signaled that more StateUpdates are
|
|
||||||
// yet to come. Enter state-update-only mode to disallow
|
|
||||||
// future sends of CreateSession messages.
|
|
||||||
stateUpdateOnlyMode = true
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
log.Errorf("received unsupported message type: %T "+
|
log.Errorf("Received unsupported message type: %T "+
|
||||||
"from %s", nextMsg, id)
|
"from %s", nextMsg, id)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// handleCreateSession processes a CreateSession message from the peer, and returns
|
|
||||||
// a CreateSessionReply in response. This method will only succeed if no existing
|
|
||||||
// session info is known about the session id. If an existing session is found,
|
|
||||||
// the reward address is returned in case the client lost our reply.
|
|
||||||
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|
||||||
req *wtwire.CreateSession) error {
|
|
||||||
|
|
||||||
// TODO(conner): validate accept against policy
|
|
||||||
|
|
||||||
// Query the db for session info belonging to the client's session id.
|
|
||||||
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
|
||||||
switch {
|
|
||||||
|
|
||||||
// We already have a session corresponding to this session id, return an
|
|
||||||
// error signaling that it already exists in our database. We return the
|
|
||||||
// reward address to the client in case they were not able to process
|
|
||||||
// our reply earlier.
|
|
||||||
case err == nil:
|
|
||||||
log.Debugf("Already have session for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
|
||||||
existingInfo.RewardAddress,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Some other database error occurred, return a temporary failure.
|
|
||||||
case err != wtdb.ErrSessionNotFound:
|
|
||||||
log.Errorf("unable to load session info for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we've established that this session does not exist in the
|
|
||||||
// database, retrieve the sweep address that will be given to the
|
|
||||||
// client. This address is to be included by the client when signing
|
|
||||||
// sweep transactions destined for this tower, if its negotiated output
|
|
||||||
// is not dust.
|
|
||||||
rewardAddress, err := s.cfg.NewAddress()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to generate reward addr for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct the pkscript the client should pay to when signing justice
|
|
||||||
// transactions for this session.
|
|
||||||
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to generate reward script for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that the requested blob type is supported by our tower.
|
|
||||||
if !blob.IsSupportedType(req.BlobType) {
|
|
||||||
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
|
|
||||||
"type %s", id, req.BlobType)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(conner): create invoice for upfront payment
|
|
||||||
|
|
||||||
// Assemble the session info using the agreed upon parameters, reward
|
|
||||||
// address, and session id.
|
|
||||||
info := wtdb.SessionInfo{
|
|
||||||
ID: *id,
|
|
||||||
Policy: wtpolicy.Policy{
|
|
||||||
BlobType: req.BlobType,
|
|
||||||
MaxUpdates: req.MaxUpdates,
|
|
||||||
RewardBase: req.RewardBase,
|
|
||||||
RewardRate: req.RewardRate,
|
|
||||||
SweepFeeRate: req.SweepFeeRate,
|
|
||||||
},
|
|
||||||
RewardAddress: rewardScript,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert the session info into the watchtower's database. If
|
|
||||||
// successful, the session will now be ready for use.
|
|
||||||
err = s.cfg.DB.InsertSessionInfo(&info)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to create session for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Accepted session for %s", id)
|
|
||||||
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeOK, rewardScript,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleStateUpdate processes a StateUpdate message request from a client. An
|
|
||||||
// attempt will be made to insert the update into the db, where it is validated
|
|
||||||
// against the client's session. The possible errors are then mapped back to
|
|
||||||
// StateUpdateCodes specified by the watchtower wire protocol, and sent back
|
|
||||||
// using a StateUpdateReply message.
|
|
||||||
func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
|
||||||
update *wtwire.StateUpdate) error {
|
|
||||||
|
|
||||||
var (
|
|
||||||
lastApplied uint16
|
|
||||||
failCode wtwire.ErrorCode
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
sessionUpdate := wtdb.SessionStateUpdate{
|
|
||||||
ID: *id,
|
|
||||||
Hint: update.Hint,
|
|
||||||
SeqNum: update.SeqNum,
|
|
||||||
LastApplied: update.LastApplied,
|
|
||||||
EncryptedBlob: update.EncryptedBlob,
|
|
||||||
}
|
|
||||||
|
|
||||||
lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate)
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
log.Debugf("State update %d accepted for %s",
|
|
||||||
update.SeqNum, id)
|
|
||||||
|
|
||||||
failCode = wtwire.CodeOK
|
|
||||||
|
|
||||||
// Return a permanent failure if a client tries to send an update for
|
|
||||||
// which we have no session.
|
|
||||||
case err == wtdb.ErrSessionNotFound:
|
|
||||||
failCode = wtwire.CodePermanentFailure
|
|
||||||
|
|
||||||
case err == wtdb.ErrSeqNumAlreadyApplied:
|
|
||||||
failCode = wtwire.CodePermanentFailure
|
|
||||||
|
|
||||||
// TODO(conner): remove session state for protocol
|
|
||||||
// violation. Could also double as clean up method for
|
|
||||||
// session-related state.
|
|
||||||
|
|
||||||
case err == wtdb.ErrLastAppliedReversion:
|
|
||||||
failCode = wtwire.StateUpdateCodeClientBehind
|
|
||||||
|
|
||||||
case err == wtdb.ErrSessionConsumed:
|
|
||||||
failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded
|
|
||||||
|
|
||||||
case err == wtdb.ErrUpdateOutOfOrder:
|
|
||||||
failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder
|
|
||||||
|
|
||||||
default:
|
|
||||||
failCode = wtwire.CodeTemporaryFailure
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.cfg.NoAckUpdates {
|
|
||||||
return &connFailure{
|
|
||||||
ID: *id,
|
|
||||||
Code: uint16(failCode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.replyStateUpdate(
|
|
||||||
peer, id, failCode, lastApplied,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// connFailure is a default error used when a request failed with a non-zero
|
// connFailure is a default error used when a request failed with a non-zero
|
||||||
// error code.
|
// error code.
|
||||||
|
@ -493,66 +293,6 @@ func (f *connFailure) Error() string {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// replyCreateSession sends a response to a CreateSession from a client. If the
|
|
||||||
// status code in the reply is OK, the error from the write will be bubbled up.
|
|
||||||
// Otherwise, this method returns a connection error to ensure we don't continue
|
|
||||||
// communication with the client.
|
|
||||||
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
|
||||||
code wtwire.ErrorCode, data []byte) error {
|
|
||||||
|
|
||||||
msg := &wtwire.CreateSessionReply{
|
|
||||||
Code: code,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.sendMessage(peer, msg)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to send CreateSessionReply to %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the write error if the request succeeded.
|
|
||||||
if code == wtwire.CodeOK {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise the request failed, return a connection failure to
|
|
||||||
// disconnect the client.
|
|
||||||
return &connFailure{
|
|
||||||
ID: *id,
|
|
||||||
Code: uint16(code),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// replyStateUpdate sends a response to a StateUpdate from a client. If the
|
|
||||||
// status code in the reply is OK, the error from the write will be bubbled up.
|
|
||||||
// Otherwise, this method returns a connection error to ensure we don't continue
|
|
||||||
// communication with the client.
|
|
||||||
func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
|
|
||||||
code wtwire.StateUpdateCode, lastApplied uint16) error {
|
|
||||||
|
|
||||||
msg := &wtwire.StateUpdateReply{
|
|
||||||
Code: code,
|
|
||||||
LastApplied: lastApplied,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.sendMessage(peer, msg)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to send StateUpdateReply to %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the write error if the request succeeded.
|
|
||||||
if code == wtwire.CodeOK {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise the request failed, return a connection failure to
|
|
||||||
// disconnect the client.
|
|
||||||
return &connFailure{
|
|
||||||
ID: *id,
|
|
||||||
Code: uint16(code),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// readMessage receives and parses the next message from the given Peer. An
|
// readMessage receives and parses the next message from the given Peer. An
|
||||||
// error is returned if a message is not received before the server's read
|
// error is returned if a message is not received before the server's read
|
||||||
// timeout, the read off the wire failed, or the message could not be
|
// timeout, the read off the wire failed, or the message could not be
|
||||||
|
|
|
@ -226,13 +226,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||||
// Create a new client and connect to server.
|
// Create a new client and connect to server.
|
||||||
peerPub := randPubKey(t)
|
peerPub := randPubKey(t)
|
||||||
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
// Send the CreateSession message, and wait for a reply.
|
// Send the CreateSession message, and wait for a reply.
|
||||||
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
sendMsg(t, test.createMsg, peer, timeoutDuration)
|
||||||
|
|
||||||
reply := recvReply(
|
reply := recvReply(
|
||||||
t, i, "CreateSessionReply", peer, timeoutDuration,
|
t, "MsgCreateSessionReply", peer, timeoutDuration,
|
||||||
).(*wtwire.CreateSessionReply)
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
// Verify that the server's response matches our expectation.
|
// Verify that the server's response matches our expectation.
|
||||||
|
@ -254,13 +254,13 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
||||||
// Simulate a peer with the same session id connection to the server
|
// Simulate a peer with the same session id connection to the server
|
||||||
// again.
|
// again.
|
||||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
// Send the _same_ CreateSession message as the first attempt.
|
// Send the _same_ CreateSession message as the first attempt.
|
||||||
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
sendMsg(t, test.createMsg, peer, timeoutDuration)
|
||||||
|
|
||||||
reply = recvReply(
|
reply = recvReply(
|
||||||
t, i, "CreateSessionReply", peer, timeoutDuration,
|
t, "MsgCreateSessionReply", peer, timeoutDuration,
|
||||||
).(*wtwire.CreateSessionReply)
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
// Ensure that the server's reply matches our expected response for a
|
// Ensure that the server's reply matches our expected response for a
|
||||||
|
@ -550,14 +550,14 @@ var stateUpdateTests = []stateUpdateTestCase{
|
||||||
func TestServerStateUpdates(t *testing.T) {
|
func TestServerStateUpdates(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
for i, test := range stateUpdateTests {
|
for _, test := range stateUpdateTests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
testServerStateUpdates(t, i, test)
|
testServerStateUpdates(t, test)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
|
||||||
const timeoutDuration = 100 * time.Millisecond
|
const timeoutDuration = 100 * time.Millisecond
|
||||||
|
|
||||||
s := initServer(t, nil, timeoutDuration)
|
s := initServer(t, nil, timeoutDuration)
|
||||||
|
@ -568,17 +568,17 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||||
// Create a new client and connect to the server.
|
// Create a new client and connect to the server.
|
||||||
peerPub := randPubKey(t)
|
peerPub := randPubKey(t)
|
||||||
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
// Register a session for this client to use in the subsequent tests.
|
// Register a session for this client to use in the subsequent tests.
|
||||||
sendMsg(t, i, test.createMsg, peer, timeoutDuration)
|
sendMsg(t, test.createMsg, peer, timeoutDuration)
|
||||||
initReply := recvReply(
|
initReply := recvReply(
|
||||||
t, i, "CreateSessionReply", peer, timeoutDuration,
|
t, "MsgCreateSessionReply", peer, timeoutDuration,
|
||||||
).(*wtwire.CreateSessionReply)
|
).(*wtwire.CreateSessionReply)
|
||||||
|
|
||||||
// Fail if the server rejected our proposed CreateSession message.
|
// Fail if the server rejected our proposed CreateSession message.
|
||||||
if initReply.Code != wtwire.CodeOK {
|
if initReply.Code != wtwire.CodeOK {
|
||||||
t.Fatalf("[test %d] server rejected session init", i)
|
t.Fatalf("server rejected session init")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the server closed the connection used to register the
|
// Check that the server closed the connection used to register the
|
||||||
|
@ -588,7 +588,7 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||||
// Now that the original connection has been closed, connect a new
|
// Now that the original connection has been closed, connect a new
|
||||||
// client with the same session id.
|
// client with the same session id.
|
||||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
// Send the intended StateUpdate messages in series.
|
// Send the intended StateUpdate messages in series.
|
||||||
for j, update := range test.updates {
|
for j, update := range test.updates {
|
||||||
|
@ -599,21 +599,21 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||||
assertConnClosed(t, peer, 2*timeoutDuration)
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
|
||||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
connect(t, i, s, peer, test.initMsg, timeoutDuration)
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the state update and verify it against our expected
|
// Send the state update and verify it against our expected
|
||||||
// response.
|
// response.
|
||||||
sendMsg(t, i, update, peer, timeoutDuration)
|
sendMsg(t, update, peer, timeoutDuration)
|
||||||
reply := recvReply(
|
reply := recvReply(
|
||||||
t, i, "StateUpdateReply", peer, timeoutDuration,
|
t, "MsgStateUpdateReply", peer, timeoutDuration,
|
||||||
).(*wtwire.StateUpdateReply)
|
).(*wtwire.StateUpdateReply)
|
||||||
|
|
||||||
if !reflect.DeepEqual(reply, test.replies[j]) {
|
if !reflect.DeepEqual(reply, test.replies[j]) {
|
||||||
t.Fatalf("[test %d, update %d] expected reply "+
|
t.Fatalf("[update %d] expected reply "+
|
||||||
"%v, got %d", i, j,
|
"%v, got %d", j,
|
||||||
test.replies[j], reply)
|
test.replies[j], reply)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -622,16 +622,148 @@ func testServerStateUpdates(t *testing.T, i int, test stateUpdateTestCase) {
|
||||||
assertConnClosed(t, peer, 2*timeoutDuration)
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(t *testing.T, i int, s wtserver.Interface, peer *wtmock.MockPeer,
|
// TestServerDeleteSession asserts the response to a DeleteSession request, and
|
||||||
|
// checking that the proper error is returned when the session doesn't exist and
|
||||||
|
// that a successful deletion does not disrupt other sessions.
|
||||||
|
func TestServerDeleteSession(t *testing.T) {
|
||||||
|
db := wtdb.NewMockDB()
|
||||||
|
|
||||||
|
localPub := randPubKey(t)
|
||||||
|
|
||||||
|
// Initialize two distinct peers with different session ids.
|
||||||
|
peerPub1 := randPubKey(t)
|
||||||
|
peerPub2 := randPubKey(t)
|
||||||
|
|
||||||
|
id1 := wtdb.NewSessionIDFromPubKey(peerPub1)
|
||||||
|
id2 := wtdb.NewSessionIDFromPubKey(peerPub2)
|
||||||
|
|
||||||
|
// Create closure to simplify assertions on session existence with the
|
||||||
|
// server's database.
|
||||||
|
hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, err := db.GetSessionInfo(id)
|
||||||
|
switch {
|
||||||
|
case shouldHave && err != nil:
|
||||||
|
t.Fatalf("expected server to have session %s, got: %v",
|
||||||
|
id, err)
|
||||||
|
case !shouldHave && err != wtdb.ErrSessionNotFound:
|
||||||
|
t.Fatalf("expected ErrSessionNotFound for session %s, "+
|
||||||
|
"got: %v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
initMsg := wtwire.NewInitMessage(
|
||||||
|
lnwire.NewRawFeatureVector(),
|
||||||
|
testnetChainHash,
|
||||||
|
)
|
||||||
|
|
||||||
|
createSession := &wtwire.CreateSession{
|
||||||
|
BlobType: blob.TypeDefault,
|
||||||
|
MaxUpdates: 1000,
|
||||||
|
RewardBase: 0,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
const timeoutDuration = 100 * time.Millisecond
|
||||||
|
|
||||||
|
s := initServer(t, db, timeoutDuration)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
// Create a session for peer2 so that the server's db isn't completely
|
||||||
|
// empty.
|
||||||
|
peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0)
|
||||||
|
connect(t, s, peer2, initMsg, timeoutDuration)
|
||||||
|
sendMsg(t, createSession, peer2, timeoutDuration)
|
||||||
|
assertConnClosed(t, peer2, 2*timeoutDuration)
|
||||||
|
|
||||||
|
// Our initial assertions are that peer2 has a valid session, but peer1
|
||||||
|
// has not created one.
|
||||||
|
hasSession(t, &id1, false)
|
||||||
|
hasSession(t, &id2, true)
|
||||||
|
|
||||||
|
peer1Msgs := []struct {
|
||||||
|
send wtwire.Message
|
||||||
|
recv wtwire.Message
|
||||||
|
assert func(t *testing.T)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Deleting unknown session should fail.
|
||||||
|
send: &wtwire.DeleteSession{},
|
||||||
|
recv: &wtwire.DeleteSessionReply{
|
||||||
|
Code: wtwire.DeleteSessionCodeNotFound,
|
||||||
|
},
|
||||||
|
assert: func(t *testing.T) {
|
||||||
|
// Peer2 should still be only session.
|
||||||
|
hasSession(t, &id1, false)
|
||||||
|
hasSession(t, &id2, true)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Create session for peer1.
|
||||||
|
send: createSession,
|
||||||
|
recv: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
Data: addrScript,
|
||||||
|
},
|
||||||
|
assert: func(t *testing.T) {
|
||||||
|
// Both peers should have sessions.
|
||||||
|
hasSession(t, &id1, true)
|
||||||
|
hasSession(t, &id2, true)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
// Delete peer1's session.
|
||||||
|
send: &wtwire.DeleteSession{},
|
||||||
|
recv: &wtwire.DeleteSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
},
|
||||||
|
assert: func(t *testing.T) {
|
||||||
|
// Peer1's session should have been removed.
|
||||||
|
hasSession(t, &id1, false)
|
||||||
|
hasSession(t, &id2, true)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now as peer1, process the canned messages defined above. This will:
|
||||||
|
// 1. Try to delete an unknown session and get a not found error code.
|
||||||
|
// 2. Create a new session using the same parameters as peer2.
|
||||||
|
// 3. Delete the newly created session and get an OK.
|
||||||
|
for _, msg := range peer1Msgs {
|
||||||
|
peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0)
|
||||||
|
connect(t, s, peer1, initMsg, timeoutDuration)
|
||||||
|
sendMsg(t, msg.send, peer1, timeoutDuration)
|
||||||
|
reply := recvReply(
|
||||||
|
t, msg.recv.MsgType().String(), peer1, timeoutDuration,
|
||||||
|
)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(reply, msg.recv) {
|
||||||
|
t.Fatalf("expected reply: %v, got: %v", msg.recv, reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConnClosed(t, peer1, 2*timeoutDuration)
|
||||||
|
|
||||||
|
// Invoke assertions after completing the request/response
|
||||||
|
// dance.
|
||||||
|
msg.assert(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer,
|
||||||
initMsg *wtwire.Init, timeout time.Duration) {
|
initMsg *wtwire.Init, timeout time.Duration) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
s.InboundPeerConnected(peer)
|
s.InboundPeerConnected(peer)
|
||||||
sendMsg(t, i, initMsg, peer, timeout)
|
sendMsg(t, initMsg, peer, timeout)
|
||||||
recvReply(t, i, "Init", peer, timeout)
|
recvReply(t, "MsgInit", peer, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendMsg sends a wtwire.Message message via a wtmock.MockPeer.
|
// sendMsg sends a wtwire.Message message via a wtmock.MockPeer.
|
||||||
func sendMsg(t *testing.T, i int, msg wtwire.Message,
|
func sendMsg(t *testing.T, msg wtwire.Message,
|
||||||
peer *wtmock.MockPeer, timeout time.Duration) {
|
peer *wtmock.MockPeer, timeout time.Duration) {
|
||||||
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
@ -639,22 +771,22 @@ func sendMsg(t *testing.T, i int, msg wtwire.Message,
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
_, err := wtwire.WriteMessage(&b, msg, 0)
|
_, err := wtwire.WriteMessage(&b, msg, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("[test %d] unable to encode %T message: %v",
|
t.Fatalf("unable to encode %T message: %v",
|
||||||
i, msg, err)
|
msg, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case peer.IncomingMsgs <- b.Bytes():
|
case peer.IncomingMsgs <- b.Bytes():
|
||||||
case <-time.After(2 * timeout):
|
case <-time.After(2 * timeout):
|
||||||
t.Fatalf("[test %d] unable to send %T message", i, msg)
|
t.Fatalf("unable to send %T message", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvReply receives a message from the server, and parses it according to
|
// recvReply receives a message from the server, and parses it according to
|
||||||
// expected reply type. The supported replies are CreateSessionReply and
|
// expected reply type. The supported replies are CreateSessionReply and
|
||||||
// StateUpdateReply.
|
// StateUpdateReply.
|
||||||
func recvReply(t *testing.T, i int, name string,
|
func recvReply(t *testing.T, name string, peer *wtmock.MockPeer,
|
||||||
peer *wtmock.MockPeer, timeout time.Duration) wtwire.Message {
|
timeout time.Duration) wtwire.Message {
|
||||||
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -667,29 +799,34 @@ func recvReply(t *testing.T, i int, name string,
|
||||||
case b := <-peer.OutgoingMsgs:
|
case b := <-peer.OutgoingMsgs:
|
||||||
msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
|
msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("[test %d] unable to decode server "+
|
t.Fatalf("unable to decode server "+
|
||||||
"reply: %v", i, err)
|
"reply: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-time.After(2 * timeout):
|
case <-time.After(2 * timeout):
|
||||||
t.Fatalf("[test %d] server did not reply", i)
|
t.Fatalf("server did not reply")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch name {
|
switch name {
|
||||||
case "Init":
|
case "MsgInit":
|
||||||
if _, ok := msg.(*wtwire.Init); !ok {
|
if _, ok := msg.(*wtwire.Init); !ok {
|
||||||
t.Fatalf("[test %d] expected %s reply "+
|
t.Fatalf("expected %s reply message, "+
|
||||||
"message, got %T", i, name, msg)
|
"got %T", name, msg)
|
||||||
}
|
}
|
||||||
case "CreateSessionReply":
|
case "MsgCreateSessionReply":
|
||||||
if _, ok := msg.(*wtwire.CreateSessionReply); !ok {
|
if _, ok := msg.(*wtwire.CreateSessionReply); !ok {
|
||||||
t.Fatalf("[test %d] expected %s reply "+
|
t.Fatalf("expected %s reply message, "+
|
||||||
"message, got %T", i, name, msg)
|
"got %T", name, msg)
|
||||||
}
|
}
|
||||||
case "StateUpdateReply":
|
case "MsgStateUpdateReply":
|
||||||
if _, ok := msg.(*wtwire.StateUpdateReply); !ok {
|
if _, ok := msg.(*wtwire.StateUpdateReply); !ok {
|
||||||
t.Fatalf("[test %d] expected %s reply "+
|
t.Fatalf("expected %s reply message, "+
|
||||||
"message, got %T", i, name, msg)
|
"got %T", name, msg)
|
||||||
|
}
|
||||||
|
case "MsgDeleteSessionReply":
|
||||||
|
if _, ok := msg.(*wtwire.DeleteSessionReply); !ok {
|
||||||
|
t.Fatalf("expected %s reply message, "+
|
||||||
|
"got %T", name, msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
157
watchtower/wtserver/state_update.go
Normal file
157
watchtower/wtserver/state_update.go
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
package wtserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
|
"github.com/lightningnetwork/lnd/watchtower/wtwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleStateUpdates processes a stream of StateUpdate requests from the
|
||||||
|
// client. The provided update should be the first such update read, subsequent
|
||||||
|
// updates will be consumed if the peer does not signal IsComplete on a
|
||||||
|
// particular update.
|
||||||
|
func (s *Server) handleStateUpdates(peer Peer, id *wtdb.SessionID,
|
||||||
|
update *wtwire.StateUpdate) error {
|
||||||
|
|
||||||
|
// Set the current update to the first update read off the wire.
|
||||||
|
// Additional updates will be read if this value is set to nil after
|
||||||
|
// processing the first.
|
||||||
|
var curUpdate = update
|
||||||
|
for {
|
||||||
|
// If this is not the first update, read the next state update
|
||||||
|
// from the peer.
|
||||||
|
if curUpdate == nil {
|
||||||
|
nextMsg, err := s.readMessage(peer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
curUpdate, ok = nextMsg.(*wtwire.StateUpdate)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("client sent %T after "+
|
||||||
|
"StateUpdate", nextMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to accept the state update from the client.
|
||||||
|
err := s.handleStateUpdate(peer, id, curUpdate)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client signals that this is last StateUpdate
|
||||||
|
// message, we can disconnect the client.
|
||||||
|
if curUpdate.IsComplete == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the current update to read subsequent updates in the
|
||||||
|
// stream.
|
||||||
|
curUpdate = nil
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-s.quit:
|
||||||
|
return ErrServerExiting
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStateUpdate processes a StateUpdate message request from a client. An
|
||||||
|
// attempt will be made to insert the update into the db, where it is validated
|
||||||
|
// against the client's session. The possible errors are then mapped back to
|
||||||
|
// StateUpdateCodes specified by the watchtower wire protocol, and sent back
|
||||||
|
// using a StateUpdateReply message.
|
||||||
|
func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||||
|
update *wtwire.StateUpdate) error {
|
||||||
|
|
||||||
|
var (
|
||||||
|
lastApplied uint16
|
||||||
|
failCode wtwire.ErrorCode
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
sessionUpdate := wtdb.SessionStateUpdate{
|
||||||
|
ID: *id,
|
||||||
|
Hint: update.Hint,
|
||||||
|
SeqNum: update.SeqNum,
|
||||||
|
LastApplied: update.LastApplied,
|
||||||
|
EncryptedBlob: update.EncryptedBlob,
|
||||||
|
}
|
||||||
|
|
||||||
|
lastApplied, err = s.cfg.DB.InsertStateUpdate(&sessionUpdate)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
log.Debugf("State update %d accepted for %s",
|
||||||
|
update.SeqNum, id)
|
||||||
|
|
||||||
|
failCode = wtwire.CodeOK
|
||||||
|
|
||||||
|
// Return a permanent failure if a client tries to send an update for
|
||||||
|
// which we have no session.
|
||||||
|
case err == wtdb.ErrSessionNotFound:
|
||||||
|
failCode = wtwire.CodePermanentFailure
|
||||||
|
|
||||||
|
case err == wtdb.ErrSeqNumAlreadyApplied:
|
||||||
|
failCode = wtwire.CodePermanentFailure
|
||||||
|
|
||||||
|
// TODO(conner): remove session state for protocol
|
||||||
|
// violation. Could also double as clean up method for
|
||||||
|
// session-related state.
|
||||||
|
|
||||||
|
case err == wtdb.ErrLastAppliedReversion:
|
||||||
|
failCode = wtwire.StateUpdateCodeClientBehind
|
||||||
|
|
||||||
|
case err == wtdb.ErrSessionConsumed:
|
||||||
|
failCode = wtwire.StateUpdateCodeMaxUpdatesExceeded
|
||||||
|
|
||||||
|
case err == wtdb.ErrUpdateOutOfOrder:
|
||||||
|
failCode = wtwire.StateUpdateCodeSeqNumOutOfOrder
|
||||||
|
|
||||||
|
default:
|
||||||
|
failCode = wtwire.CodeTemporaryFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.NoAckUpdates {
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(failCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.replyStateUpdate(
|
||||||
|
peer, id, failCode, lastApplied,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// replyStateUpdate sends a response to a StateUpdate from a client. If the
|
||||||
|
// status code in the reply is OK, the error from the write will be bubbled up.
|
||||||
|
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||||
|
// communication with the client.
|
||||||
|
func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
|
||||||
|
code wtwire.StateUpdateCode, lastApplied uint16) error {
|
||||||
|
|
||||||
|
msg := &wtwire.StateUpdateReply{
|
||||||
|
Code: code,
|
||||||
|
LastApplied: lastApplied,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendMessage(peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to send StateUpdateReply to %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the write error if the request succeeded.
|
||||||
|
if code == wtwire.CodeOK {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the request failed, return a connection failure to
|
||||||
|
// disconnect the client.
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: uint16(code),
|
||||||
|
}
|
||||||
|
}
|
45
watchtower/wtwire/delete_session.go
Normal file
45
watchtower/wtwire/delete_session.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package wtwire
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// DeleteSession is sent from the client to the tower to signal that the tower
|
||||||
|
// can delete all session state for the session key used to authenticate the
|
||||||
|
// brontide connection. This should be done by the client once all channels that
|
||||||
|
// have state updates in the session have been resolved on-chain.
|
||||||
|
type DeleteSession struct{}
|
||||||
|
|
||||||
|
// Compile-time constraint to ensure DeleteSession implements the wtwire.Message
|
||||||
|
// interface.
|
||||||
|
var _ Message = (*DeleteSession)(nil)
|
||||||
|
|
||||||
|
// Decode deserializes a serialized DeleteSession message stored in the passed
|
||||||
|
// io.Reader observing the specified protocol version.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSession) Decode(r io.Reader, pver uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode serializes the target DeleteSession message into the passed io.Writer
|
||||||
|
// observing the specified protocol version.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSession) Encode(w io.Writer, pver uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MsgType returns the integer uniquely identifying this message type on the
|
||||||
|
// wire.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSession) MsgType() MessageType {
|
||||||
|
return MsgDeleteSession
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxPayloadLength returns the maximum allowed payload size for a DeleteSession
|
||||||
|
// message observing the specified protocol version.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSession) MaxPayloadLength(uint32) uint32 {
|
||||||
|
return 0
|
||||||
|
}
|
66
watchtower/wtwire/delete_session_reply.go
Normal file
66
watchtower/wtwire/delete_session_reply.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package wtwire
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// DeleteSessionCode is an error code returned by a watchtower in response to a
|
||||||
|
// DeleteSession message.
|
||||||
|
type DeleteSessionCode = ErrorCode
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DeleteSessionCodeNotFound is returned when the watchtower does not
|
||||||
|
// know of the requested session. This may indicate an error on the
|
||||||
|
// client side, or that the tower had already deleted the session in a
|
||||||
|
// prior request that the client may not have received.
|
||||||
|
DeleteSessionCodeNotFound DeleteSessionCode = 80
|
||||||
|
|
||||||
|
// TODO(conner): add String method after wtclient is merged
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteSessionReply is a message sent in response to a client's DeleteSession
|
||||||
|
// request. The message indicates whether or not the deletion was a success or
|
||||||
|
// failure.
|
||||||
|
type DeleteSessionReply struct {
|
||||||
|
// Code will be non-zero if the watchtower was not able to delete the
|
||||||
|
// requested session.
|
||||||
|
Code DeleteSessionCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// A compile time check to ensure DeleteSessionReply implements the
|
||||||
|
// wtwire.Message interface.
|
||||||
|
var _ Message = (*DeleteSessionReply)(nil)
|
||||||
|
|
||||||
|
// Decode deserializes a serialized DeleteSessionReply message stored in the
|
||||||
|
// passed io.Reader observing the specified protocol version.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSessionReply) Decode(r io.Reader, pver uint32) error {
|
||||||
|
return ReadElements(r,
|
||||||
|
&m.Code,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode serializes the target DeleteSessionReply into the passed io.Writer
|
||||||
|
// observing the protocol version specified.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSessionReply) Encode(w io.Writer, pver uint32) error {
|
||||||
|
return WriteElements(w,
|
||||||
|
m.Code,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MsgType returns the integer uniquely identifying this message type on the
|
||||||
|
// wire.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSessionReply) MsgType() MessageType {
|
||||||
|
return MsgDeleteSessionReply
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxPayloadLength returns the maximum allowed payload size for a
|
||||||
|
// DeleteSessionReply complete message observing the specified protocol version.
|
||||||
|
//
|
||||||
|
// This is part of the wtwire.Message interface.
|
||||||
|
func (m *DeleteSessionReply) MaxPayloadLength(uint32) uint32 {
|
||||||
|
return 2
|
||||||
|
}
|
|
@ -40,6 +40,13 @@ const (
|
||||||
|
|
||||||
// MsgStateUpdateReply identifies an encoded StateUpdateReply message.
|
// MsgStateUpdateReply identifies an encoded StateUpdateReply message.
|
||||||
MsgStateUpdateReply MessageType = 305
|
MsgStateUpdateReply MessageType = 305
|
||||||
|
|
||||||
|
// MsgDeleteSession identifies an encoded DeleteSession message.
|
||||||
|
MsgDeleteSession MessageType = 306
|
||||||
|
|
||||||
|
// MsgDeleteSessionReply identifies an encoded DeleteSessionReply
|
||||||
|
// message.
|
||||||
|
MsgDeleteSessionReply MessageType = 307
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns a human readable description of the message type.
|
// String returns a human readable description of the message type.
|
||||||
|
@ -55,6 +62,10 @@ func (m MessageType) String() string {
|
||||||
return "MsgStateUpdate"
|
return "MsgStateUpdate"
|
||||||
case MsgStateUpdateReply:
|
case MsgStateUpdateReply:
|
||||||
return "MsgStateUpdateReply"
|
return "MsgStateUpdateReply"
|
||||||
|
case MsgDeleteSession:
|
||||||
|
return "MsgDeleteSession"
|
||||||
|
case MsgDeleteSessionReply:
|
||||||
|
return "MsgDeleteSessionReply"
|
||||||
case MsgError:
|
case MsgError:
|
||||||
return "Error"
|
return "Error"
|
||||||
default:
|
default:
|
||||||
|
@ -97,6 +108,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
||||||
msg = &StateUpdate{}
|
msg = &StateUpdate{}
|
||||||
case MsgStateUpdateReply:
|
case MsgStateUpdateReply:
|
||||||
msg = &StateUpdateReply{}
|
msg = &StateUpdateReply{}
|
||||||
|
case MsgDeleteSession:
|
||||||
|
msg = &DeleteSession{}
|
||||||
|
case MsgDeleteSessionReply:
|
||||||
|
msg = &DeleteSessionReply{}
|
||||||
case MsgError:
|
case MsgError:
|
||||||
msg = &Error{}
|
msg = &Error{}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -126,6 +126,18 @@ func TestWatchtowerWireProtocol(t *testing.T) {
|
||||||
return mainScenario(&m)
|
return mainScenario(&m)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
msgType: wtwire.MsgDeleteSession,
|
||||||
|
scenario: func(m wtwire.DeleteSession) bool {
|
||||||
|
return mainScenario(&m)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: wtwire.MsgDeleteSessionReply,
|
||||||
|
scenario: func(m wtwire.DeleteSessionReply) bool {
|
||||||
|
return mainScenario(&m)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
msgType: wtwire.MsgError,
|
msgType: wtwire.MsgError,
|
||||||
scenario: func(m wtwire.Error) bool {
|
scenario: func(m wtwire.Error) bool {
|
||||||
|
|
Loading…
Add table
Reference in a new issue