htlcswitch: Assign each pending payment a unique ID.

This simplifies the pending payment handling code because it allows it
be handled in nearly the same way as forwarded HTLCs by treating an
empty channel ID as local dispatch.
This commit is contained in:
Jim Posen 2017-10-30 12:57:32 -07:00 committed by Olaoluwa Osuntokun
parent 4a29fbdab2
commit 40fb0ddcfc
3 changed files with 66 additions and 101 deletions

View file

@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
"local_log_index=%v, batch_size=%v", "local_log_index=%v, batch_size=%v",
htlc.PaymentHash[:], index, l.batchCounter+1) htlc.PaymentHash[:], index, l.batchCounter+1)
// If packet was forwarded from another channel link then we should // Create circuit (remember the path) in order to forward settle/fail
// create circuit (remember the path) in order to forward settle/fail
// packet back. // packet back.
if pkt.incomingChanID != (lnwire.ShortChannelID{}) { l.cfg.Switch.addCircuit(&PaymentCircuit{
l.cfg.Switch.addCircuit(&PaymentCircuit{ PaymentHash: htlc.PaymentHash,
PaymentHash: htlc.PaymentHash, IncomingChanID: pkt.incomingChanID,
IncomingChanID: pkt.incomingChanID, IncomingHTLCID: pkt.incomingHTLCID,
IncomingHTLCID: pkt.incomingHTLCID, OutgoingChanID: l.ShortChanID(),
OutgoingChanID: l.ShortChanID(), OutgoingHTLCID: index,
OutgoingHTLCID: index, ErrorEncrypter: pkt.obfuscator,
ErrorEncrypter: pkt.obfuscator, })
})
}
htlc.ID = index htlc.ID = index
l.cfg.Peer.SendMessage(htlc) l.cfg.Peer.SendMessage(htlc)

View file

@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro
aliceCfg := ChannelLinkConfig{ aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy, FwrdingPolicy: globalPolicy,
Peer: &alicePeer, Peer: &alicePeer,
Switch: nil, Switch: New(Config{}),
DecodeHopIterator: decoder.DecodeHopIterator, DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) { DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone

View file

@ -121,11 +121,13 @@ type Switch struct {
// service was initialized with. // service was initialized with.
cfg *Config cfg *Config
// pendingPayments is correspondence of user payments and its hashes, // pendingPayments stores payments initiated by the user that are not yet
// which is used to save the payments which made by user and notify // settled. The map is used to later look up the payments and notify the
// them about result later. // user of the result when they are complete. Each payment is given a unique
pendingPayments map[lnwallet.PaymentHash][]*pendingPayment // integer ID when it is created.
pendingPayments map[uint64]*pendingPayment
pendingMutex sync.RWMutex pendingMutex sync.RWMutex
nextPendingID uint64
// circuits is storage for payment circuits which are used to // circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator. // forward the settle/fail htlc updates back to the add htlc initiator.
@ -171,7 +173,7 @@ func New(cfg Config) *Switch {
linkIndex: make(map[lnwire.ChannelID]ChannelLink), linkIndex: make(map[lnwire.ChannelID]ChannelLink),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}), interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment), pendingPayments: make(map[uint64]*pendingPayment),
htlcPlex: make(chan *plexPacket), htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose), chanCloseRequests: make(chan *ChanClose),
linkControl: make(chan interface{}), linkControl: make(chan interface{}),
@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
} }
s.pendingMutex.Lock() s.pendingMutex.Lock()
s.pendingPayments[htlc.PaymentHash] = append( paymentID := s.nextPendingID
s.pendingPayments[htlc.PaymentHash], payment) s.nextPendingID++
s.pendingPayments[paymentID] = payment
s.pendingMutex.Unlock() s.pendingMutex.Unlock()
// Generate and send new update packet, if error will be received on // Generate and send new update packet, if error will be received on
// this stage it means that packet haven't left boundaries of our // this stage it means that packet haven't left boundaries of our
// system and something wrong happened. // system and something wrong happened.
packet := &htlcPacket{ packet := &htlcPacket{
destNode: nextNode, incomingHTLCID: paymentID,
htlc: htlc, destNode: nextNode,
htlc: htlc,
} }
if err := s.forward(packet); err != nil { if err := s.forward(packet); err != nil {
s.removePendingPayment(payment.amount, payment.paymentHash) s.removePendingPayment(paymentID)
return zeroPreimage, err return zeroPreimage, err
} }
@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error {
// o <-settle-- o <--settle-- o // o <-settle-- o <--settle-- o
// Alice Bob Carol // Alice Bob Carol
// //
func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error { func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
// Pending payments use a special interpretation of the incomingChanID and
// incomingHTLCID fields on packet where the channel ID is blank and the
// HTLC ID is the payment ID. The switch basically views the users of the
// node as a special channel that also offers a sequence of HTLCs.
payment, err := s.findPayment(packet.incomingHTLCID)
if err != nil {
return err
}
switch htlc := packet.htlc.(type) { switch htlc := packet.htlc.(type) {
// User have created the htlc update therefore we should find the // User have created the htlc update therefore we should find the
@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
// manages then channel. // manages then channel.
// //
// TODO(roasbeef): should return with an error // TODO(roasbeef): should return with an error
packet.outgoingChanID = destination.ShortChanID()
destination.HandleSwitchPacket(packet) destination.HandleSwitchPacket(packet)
return nil return nil
@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
// Notify the user that his payment was successfully proceed. // Notify the user that his payment was successfully proceed.
payment.err <- nil payment.err <- nil
payment.preimage <- htlc.PaymentPreimage payment.preimage <- htlc.PaymentPreimage
s.removePendingPayment(payment.amount, payment.paymentHash) s.removePendingPayment(packet.incomingHTLCID)
// We've just received a fail update which means we can finalize the // We've just received a fail update which means we can finalize the
// user payment and return fail response. // user payment and return fail response.
@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
} }
payment.preimage <- zeroPreimage payment.preimage <- zeroPreimage
s.removePendingPayment(payment.amount, payment.paymentHash) s.removePendingPayment(packet.incomingHTLCID)
default: default:
return errors.New("wrong update type") return errors.New("wrong update type")
@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// payment circuit within our internal state so we can properly forward // payment circuit within our internal state so we can properly forward
// the ultimate settle message back latter. // the ultimate settle message back latter.
case *lnwire.UpdateAddHTLC: case *lnwire.UpdateAddHTLC:
if packet.incomingChanID == (lnwire.ShortChannelID{}) {
// A blank incomingChanID indicates that this is a pending
// user-initiated payment.
return s.handleLocalDispatch(packet)
}
source, err := s.getLinkByShortID(packet.incomingChanID) source, err := s.getLinkByShortID(packet.incomingChanID)
if err != nil { if err != nil {
err := errors.Errorf("unable to find channel link "+ err := errors.Errorf("unable to find channel link "+
@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
circuit.OutgoingChanID) circuit.OutgoingChanID)
} }
packet.incomingChanID = circuit.IncomingChanID
packet.incomingHTLCID = circuit.IncomingHTLCID
// A blank IncomingChanID in a circuit indicates that it is a
// pending user-initiated payment.
if circuit.IncomingChanID == (lnwire.ShortChannelID{}) {
return s.handleLocalDispatch(packet)
}
// Obfuscate the error message for fail updates before sending back // Obfuscate the error message for fail updates before sending back
// through the circuit. // through the circuit.
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated { if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
htlc.Reason) htlc.Reason)
} }
packet.incomingChanID = circuit.IncomingChanID
packet.incomingHTLCID = circuit.IncomingHTLCID
} }
source, err := s.getLinkByShortID(packet.incomingChanID) source, err := s.getLinkByShortID(packet.incomingChanID)
@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() {
// packet concretely, then either forward it along, or // packet concretely, then either forward it along, or
// interpret a return packet to a locally initialized one. // interpret a return packet to a locally initialized one.
case cmd := <-s.htlcPlex: case cmd := <-s.htlcPlex:
var ( cmd.err <- s.handlePacketForward(cmd.pkt)
paymentHash lnwallet.PaymentHash
amount lnwire.MilliSatoshi
)
// Only three types of message should be forwarded:
// add, fails, and settles. Anything else is an error.
switch m := cmd.pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
paymentHash = m.PaymentHash
amount = m.Amount
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
paymentHash = cmd.pkt.payHash
amount = cmd.pkt.amount
default:
cmd.err <- errors.New("wrong type of update")
return
}
// If we can locate this packet in our local records,
// then this means a local sub-system initiated it.
// Otherwise, this is just a packet to be forwarded, so
// we'll treat it as so.
//
// TODO(roasbeef): can fast path this
payment, err := s.findPayment(amount, paymentHash)
if err != nil {
cmd.err <- s.handlePacketForward(cmd.pkt)
} else {
cmd.err <- s.handleLocalDispatch(payment, cmd.pkt)
}
// The log ticker has fired, so we'll calculate some forwarding // The log ticker has fired, so we'll calculate some forwarding
// stats for the last 10 seconds to display within the logs to // stats for the last 10 seconds to display within the logs to
@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
// removePendingPayment is the helper function which removes the pending user // removePendingPayment is the helper function which removes the pending user
// payment. // payment.
func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi, func (s *Switch) removePendingPayment(paymentID uint64) error {
hash lnwallet.PaymentHash) error {
s.pendingMutex.Lock() s.pendingMutex.Lock()
defer s.pendingMutex.Unlock() defer s.pendingMutex.Unlock()
payments, ok := s.pendingPayments[hash] if _, ok := s.pendingPayments[paymentID]; !ok {
if ok { return errors.Errorf("Cannot find pending payment with ID %d",
for i, payment := range payments { paymentID)
if payment.amount == amount {
// Delete without preserving order
// Google: Golang slice tricks
payments[i] = payments[len(payments)-1]
payments[len(payments)-1] = nil
s.pendingPayments[hash] = payments[:len(payments)-1]
if len(s.pendingPayments[hash]) == 0 {
delete(s.pendingPayments, hash)
}
return nil
}
}
} }
return errors.Errorf("unable to remove pending payment with "+ delete(s.pendingPayments, paymentID)
"hash(%v) and amount(%v)", hash, amount) return nil
} }
// findPayment is the helper function which find the payment. // findPayment is the helper function which find the payment.
func (s *Switch) findPayment(amount lnwire.MilliSatoshi, func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
hash lnwallet.PaymentHash) (*pendingPayment, error) {
s.pendingMutex.RLock() s.pendingMutex.RLock()
defer s.pendingMutex.RUnlock() defer s.pendingMutex.RUnlock()
payments, ok := s.pendingPayments[hash] payment, ok := s.pendingPayments[paymentID]
if ok { if !ok {
for _, payment := range payments { return nil, errors.Errorf("Cannot find pending payment with ID %d",
if payment.amount == amount { paymentID)
return payment, nil
}
}
} }
return payment, nil
return nil, errors.Errorf("unable to remove pending payment with "+
"hash(%v) and amount(%v)", hash, amount)
} }
// numPendingPayments is helper function which returns the overall number of // numPendingPayments is helper function which returns the overall number of
// pending user payments. // pending user payments.
func (s *Switch) numPendingPayments() int { func (s *Switch) numPendingPayments() int {
var l int return len(s.pendingPayments)
for _, payments := range s.pendingPayments {
l += len(payments)
}
return l
} }
// addCircuit adds a circuit to the switch's in-memory mapping. // addCircuit adds a circuit to the switch's in-memory mapping.