From 40fb0ddcfce2824288244a4d8123ca5b08f0d473 Mon Sep 17 00:00:00 2001 From: Jim Posen Date: Mon, 30 Oct 2017 12:57:32 -0700 Subject: [PATCH] htlcswitch: Assign each pending payment a unique ID. This simplifies the pending payment handling code because it allows it be handled in nearly the same way as forwarded HTLCs by treating an empty channel ID as local dispatch. --- htlcswitch/link.go | 21 +++--- htlcswitch/link_test.go | 2 +- htlcswitch/switch.go | 144 ++++++++++++++++------------------------ 3 files changed, 66 insertions(+), 101 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 82600b4e8..7b9f74866 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { "local_log_index=%v, batch_size=%v", htlc.PaymentHash[:], index, l.batchCounter+1) - // If packet was forwarded from another channel link then we should - // create circuit (remember the path) in order to forward settle/fail + // Create circuit (remember the path) in order to forward settle/fail // packet back. - if pkt.incomingChanID != (lnwire.ShortChannelID{}) { - l.cfg.Switch.addCircuit(&PaymentCircuit{ - PaymentHash: htlc.PaymentHash, - IncomingChanID: pkt.incomingChanID, - IncomingHTLCID: pkt.incomingHTLCID, - OutgoingChanID: l.ShortChanID(), - OutgoingHTLCID: index, - ErrorEncrypter: pkt.obfuscator, - }) - } + l.cfg.Switch.addCircuit(&PaymentCircuit{ + PaymentHash: htlc.PaymentHash, + IncomingChanID: pkt.incomingChanID, + IncomingHTLCID: pkt.incomingHTLCID, + OutgoingChanID: l.ShortChanID(), + OutgoingHTLCID: index, + ErrorEncrypter: pkt.obfuscator, + }) htlc.ID = index l.cfg.Peer.SendMessage(htlc) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 636518d7e..635b940c8 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro aliceCfg := ChannelLinkConfig{ FwrdingPolicy: globalPolicy, Peer: &alicePeer, - Switch: nil, + Switch: New(Config{}), DecodeHopIterator: decoder.DecodeHopIterator, DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) { return obfuscator, lnwire.CodeNone diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1219b24bd..6524af9a7 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -121,11 +121,13 @@ type Switch struct { // service was initialized with. cfg *Config - // pendingPayments is correspondence of user payments and its hashes, - // which is used to save the payments which made by user and notify - // them about result later. - pendingPayments map[lnwallet.PaymentHash][]*pendingPayment + // pendingPayments stores payments initiated by the user that are not yet + // settled. The map is used to later look up the payments and notify the + // user of the result when they are complete. Each payment is given a unique + // integer ID when it is created. + pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex + nextPendingID uint64 // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. @@ -171,7 +173,7 @@ func New(cfg Config) *Switch { linkIndex: make(map[lnwire.ChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}), - pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment), + pendingPayments: make(map[uint64]*pendingPayment), htlcPlex: make(chan *plexPacket), chanCloseRequests: make(chan *ChanClose), linkControl: make(chan interface{}), @@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, } s.pendingMutex.Lock() - s.pendingPayments[htlc.PaymentHash] = append( - s.pendingPayments[htlc.PaymentHash], payment) + paymentID := s.nextPendingID + s.nextPendingID++ + s.pendingPayments[paymentID] = payment s.pendingMutex.Unlock() // Generate and send new update packet, if error will be received on // this stage it means that packet haven't left boundaries of our // system and something wrong happened. packet := &htlcPacket{ - destNode: nextNode, - htlc: htlc, + incomingHTLCID: paymentID, + destNode: nextNode, + htlc: htlc, } if err := s.forward(packet); err != nil { - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(paymentID) return zeroPreimage, err } @@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error { // o <-settle-- o <--settle-- o // Alice Bob Carol // -func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error { +func (s *Switch) handleLocalDispatch(packet *htlcPacket) error { + // Pending payments use a special interpretation of the incomingChanID and + // incomingHTLCID fields on packet where the channel ID is blank and the + // HTLC ID is the payment ID. The switch basically views the users of the + // node as a special channel that also offers a sequence of HTLCs. + payment, err := s.findPayment(packet.incomingHTLCID) + if err != nil { + return err + } + switch htlc := packet.htlc.(type) { // User have created the htlc update therefore we should find the @@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket // manages then channel. // // TODO(roasbeef): should return with an error + packet.outgoingChanID = destination.ShortChanID() destination.HandleSwitchPacket(packet) return nil @@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket // Notify the user that his payment was successfully proceed. payment.err <- nil payment.preimage <- htlc.PaymentPreimage - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(packet.incomingHTLCID) // We've just received a fail update which means we can finalize the // user payment and return fail response. @@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket } payment.preimage <- zeroPreimage - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(packet.incomingHTLCID) default: return errors.New("wrong update type") @@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // payment circuit within our internal state so we can properly forward // the ultimate settle message back latter. case *lnwire.UpdateAddHTLC: + if packet.incomingChanID == (lnwire.ShortChannelID{}) { + // A blank incomingChanID indicates that this is a pending + // user-initiated payment. + return s.handleLocalDispatch(packet) + } + source, err := s.getLinkByShortID(packet.incomingChanID) if err != nil { err := errors.Errorf("unable to find channel link "+ @@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { circuit.OutgoingChanID) } + packet.incomingChanID = circuit.IncomingChanID + packet.incomingHTLCID = circuit.IncomingHTLCID + + // A blank IncomingChanID in a circuit indicates that it is a + // pending user-initiated payment. + if circuit.IncomingChanID == (lnwire.ShortChannelID{}) { + return s.handleLocalDispatch(packet) + } + // Obfuscate the error message for fail updates before sending back // through the circuit. if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated { htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( htlc.Reason) } - - packet.incomingChanID = circuit.IncomingChanID - packet.incomingHTLCID = circuit.IncomingHTLCID } source, err := s.getLinkByShortID(packet.incomingChanID) @@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() { // packet concretely, then either forward it along, or // interpret a return packet to a locally initialized one. case cmd := <-s.htlcPlex: - var ( - paymentHash lnwallet.PaymentHash - amount lnwire.MilliSatoshi - ) - - // Only three types of message should be forwarded: - // add, fails, and settles. Anything else is an error. - switch m := cmd.pkt.htlc.(type) { - case *lnwire.UpdateAddHTLC: - paymentHash = m.PaymentHash - amount = m.Amount - case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC: - paymentHash = cmd.pkt.payHash - amount = cmd.pkt.amount - default: - cmd.err <- errors.New("wrong type of update") - return - } - - // If we can locate this packet in our local records, - // then this means a local sub-system initiated it. - // Otherwise, this is just a packet to be forwarded, so - // we'll treat it as so. - // - // TODO(roasbeef): can fast path this - payment, err := s.findPayment(amount, paymentHash) - if err != nil { - cmd.err <- s.handlePacketForward(cmd.pkt) - } else { - cmd.err <- s.handleLocalDispatch(payment, cmd.pkt) - } + cmd.err <- s.handlePacketForward(cmd.pkt) // The log ticker has fired, so we'll calculate some forwarding // stats for the last 10 seconds to display within the logs to @@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { // removePendingPayment is the helper function which removes the pending user // payment. -func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi, - hash lnwallet.PaymentHash) error { - +func (s *Switch) removePendingPayment(paymentID uint64) error { s.pendingMutex.Lock() defer s.pendingMutex.Unlock() - payments, ok := s.pendingPayments[hash] - if ok { - for i, payment := range payments { - if payment.amount == amount { - // Delete without preserving order - // Google: Golang slice tricks - payments[i] = payments[len(payments)-1] - payments[len(payments)-1] = nil - s.pendingPayments[hash] = payments[:len(payments)-1] - - if len(s.pendingPayments[hash]) == 0 { - delete(s.pendingPayments, hash) - } - - return nil - } - } + if _, ok := s.pendingPayments[paymentID]; !ok { + return errors.Errorf("Cannot find pending payment with ID %d", + paymentID) } - return errors.Errorf("unable to remove pending payment with "+ - "hash(%v) and amount(%v)", hash, amount) + delete(s.pendingPayments, paymentID) + return nil } // findPayment is the helper function which find the payment. -func (s *Switch) findPayment(amount lnwire.MilliSatoshi, - hash lnwallet.PaymentHash) (*pendingPayment, error) { - +func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) { s.pendingMutex.RLock() defer s.pendingMutex.RUnlock() - payments, ok := s.pendingPayments[hash] - if ok { - for _, payment := range payments { - if payment.amount == amount { - return payment, nil - } - } + payment, ok := s.pendingPayments[paymentID] + if !ok { + return nil, errors.Errorf("Cannot find pending payment with ID %d", + paymentID) } - - return nil, errors.Errorf("unable to remove pending payment with "+ - "hash(%v) and amount(%v)", hash, amount) + return payment, nil } // numPendingPayments is helper function which returns the overall number of // pending user payments. func (s *Switch) numPendingPayments() int { - var l int - for _, payments := range s.pendingPayments { - l += len(payments) - } - - return l + return len(s.pendingPayments) } // addCircuit adds a circuit to the switch's in-memory mapping.