mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-13 11:09:23 +01:00
htlcswitch: Assign each pending payment a unique ID.
This simplifies the pending payment handling code because it allows it be handled in nearly the same way as forwarded HTLCs by treating an empty channel ID as local dispatch.
This commit is contained in:
parent
4a29fbdab2
commit
40fb0ddcfc
3 changed files with 66 additions and 101 deletions
|
@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
|
||||||
"local_log_index=%v, batch_size=%v",
|
"local_log_index=%v, batch_size=%v",
|
||||||
htlc.PaymentHash[:], index, l.batchCounter+1)
|
htlc.PaymentHash[:], index, l.batchCounter+1)
|
||||||
|
|
||||||
// If packet was forwarded from another channel link then we should
|
// Create circuit (remember the path) in order to forward settle/fail
|
||||||
// create circuit (remember the path) in order to forward settle/fail
|
|
||||||
// packet back.
|
// packet back.
|
||||||
if pkt.incomingChanID != (lnwire.ShortChannelID{}) {
|
l.cfg.Switch.addCircuit(&PaymentCircuit{
|
||||||
l.cfg.Switch.addCircuit(&PaymentCircuit{
|
PaymentHash: htlc.PaymentHash,
|
||||||
PaymentHash: htlc.PaymentHash,
|
IncomingChanID: pkt.incomingChanID,
|
||||||
IncomingChanID: pkt.incomingChanID,
|
IncomingHTLCID: pkt.incomingHTLCID,
|
||||||
IncomingHTLCID: pkt.incomingHTLCID,
|
OutgoingChanID: l.ShortChanID(),
|
||||||
OutgoingChanID: l.ShortChanID(),
|
OutgoingHTLCID: index,
|
||||||
OutgoingHTLCID: index,
|
ErrorEncrypter: pkt.obfuscator,
|
||||||
ErrorEncrypter: pkt.obfuscator,
|
})
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
htlc.ID = index
|
htlc.ID = index
|
||||||
l.cfg.Peer.SendMessage(htlc)
|
l.cfg.Peer.SendMessage(htlc)
|
||||||
|
|
|
@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro
|
||||||
aliceCfg := ChannelLinkConfig{
|
aliceCfg := ChannelLinkConfig{
|
||||||
FwrdingPolicy: globalPolicy,
|
FwrdingPolicy: globalPolicy,
|
||||||
Peer: &alicePeer,
|
Peer: &alicePeer,
|
||||||
Switch: nil,
|
Switch: New(Config{}),
|
||||||
DecodeHopIterator: decoder.DecodeHopIterator,
|
DecodeHopIterator: decoder.DecodeHopIterator,
|
||||||
DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) {
|
DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) {
|
||||||
return obfuscator, lnwire.CodeNone
|
return obfuscator, lnwire.CodeNone
|
||||||
|
|
|
@ -121,11 +121,13 @@ type Switch struct {
|
||||||
// service was initialized with.
|
// service was initialized with.
|
||||||
cfg *Config
|
cfg *Config
|
||||||
|
|
||||||
// pendingPayments is correspondence of user payments and its hashes,
|
// pendingPayments stores payments initiated by the user that are not yet
|
||||||
// which is used to save the payments which made by user and notify
|
// settled. The map is used to later look up the payments and notify the
|
||||||
// them about result later.
|
// user of the result when they are complete. Each payment is given a unique
|
||||||
pendingPayments map[lnwallet.PaymentHash][]*pendingPayment
|
// integer ID when it is created.
|
||||||
|
pendingPayments map[uint64]*pendingPayment
|
||||||
pendingMutex sync.RWMutex
|
pendingMutex sync.RWMutex
|
||||||
|
nextPendingID uint64
|
||||||
|
|
||||||
// circuits is storage for payment circuits which are used to
|
// circuits is storage for payment circuits which are used to
|
||||||
// forward the settle/fail htlc updates back to the add htlc initiator.
|
// forward the settle/fail htlc updates back to the add htlc initiator.
|
||||||
|
@ -171,7 +173,7 @@ func New(cfg Config) *Switch {
|
||||||
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
||||||
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
||||||
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
||||||
pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment),
|
pendingPayments: make(map[uint64]*pendingPayment),
|
||||||
htlcPlex: make(chan *plexPacket),
|
htlcPlex: make(chan *plexPacket),
|
||||||
chanCloseRequests: make(chan *ChanClose),
|
chanCloseRequests: make(chan *ChanClose),
|
||||||
linkControl: make(chan interface{}),
|
linkControl: make(chan interface{}),
|
||||||
|
@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.pendingMutex.Lock()
|
s.pendingMutex.Lock()
|
||||||
s.pendingPayments[htlc.PaymentHash] = append(
|
paymentID := s.nextPendingID
|
||||||
s.pendingPayments[htlc.PaymentHash], payment)
|
s.nextPendingID++
|
||||||
|
s.pendingPayments[paymentID] = payment
|
||||||
s.pendingMutex.Unlock()
|
s.pendingMutex.Unlock()
|
||||||
|
|
||||||
// Generate and send new update packet, if error will be received on
|
// Generate and send new update packet, if error will be received on
|
||||||
// this stage it means that packet haven't left boundaries of our
|
// this stage it means that packet haven't left boundaries of our
|
||||||
// system and something wrong happened.
|
// system and something wrong happened.
|
||||||
packet := &htlcPacket{
|
packet := &htlcPacket{
|
||||||
destNode: nextNode,
|
incomingHTLCID: paymentID,
|
||||||
htlc: htlc,
|
destNode: nextNode,
|
||||||
|
htlc: htlc,
|
||||||
}
|
}
|
||||||
if err := s.forward(packet); err != nil {
|
if err := s.forward(packet); err != nil {
|
||||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
s.removePendingPayment(paymentID)
|
||||||
return zeroPreimage, err
|
return zeroPreimage, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error {
|
||||||
// o <-settle-- o <--settle-- o
|
// o <-settle-- o <--settle-- o
|
||||||
// Alice Bob Carol
|
// Alice Bob Carol
|
||||||
//
|
//
|
||||||
func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error {
|
func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
|
||||||
|
// Pending payments use a special interpretation of the incomingChanID and
|
||||||
|
// incomingHTLCID fields on packet where the channel ID is blank and the
|
||||||
|
// HTLC ID is the payment ID. The switch basically views the users of the
|
||||||
|
// node as a special channel that also offers a sequence of HTLCs.
|
||||||
|
payment, err := s.findPayment(packet.incomingHTLCID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
switch htlc := packet.htlc.(type) {
|
switch htlc := packet.htlc.(type) {
|
||||||
|
|
||||||
// User have created the htlc update therefore we should find the
|
// User have created the htlc update therefore we should find the
|
||||||
|
@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||||
// manages then channel.
|
// manages then channel.
|
||||||
//
|
//
|
||||||
// TODO(roasbeef): should return with an error
|
// TODO(roasbeef): should return with an error
|
||||||
|
packet.outgoingChanID = destination.ShortChanID()
|
||||||
destination.HandleSwitchPacket(packet)
|
destination.HandleSwitchPacket(packet)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||||
// Notify the user that his payment was successfully proceed.
|
// Notify the user that his payment was successfully proceed.
|
||||||
payment.err <- nil
|
payment.err <- nil
|
||||||
payment.preimage <- htlc.PaymentPreimage
|
payment.preimage <- htlc.PaymentPreimage
|
||||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
s.removePendingPayment(packet.incomingHTLCID)
|
||||||
|
|
||||||
// We've just received a fail update which means we can finalize the
|
// We've just received a fail update which means we can finalize the
|
||||||
// user payment and return fail response.
|
// user payment and return fail response.
|
||||||
|
@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
payment.preimage <- zeroPreimage
|
payment.preimage <- zeroPreimage
|
||||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
s.removePendingPayment(packet.incomingHTLCID)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return errors.New("wrong update type")
|
return errors.New("wrong update type")
|
||||||
|
@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||||
// payment circuit within our internal state so we can properly forward
|
// payment circuit within our internal state so we can properly forward
|
||||||
// the ultimate settle message back latter.
|
// the ultimate settle message back latter.
|
||||||
case *lnwire.UpdateAddHTLC:
|
case *lnwire.UpdateAddHTLC:
|
||||||
|
if packet.incomingChanID == (lnwire.ShortChannelID{}) {
|
||||||
|
// A blank incomingChanID indicates that this is a pending
|
||||||
|
// user-initiated payment.
|
||||||
|
return s.handleLocalDispatch(packet)
|
||||||
|
}
|
||||||
|
|
||||||
source, err := s.getLinkByShortID(packet.incomingChanID)
|
source, err := s.getLinkByShortID(packet.incomingChanID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := errors.Errorf("unable to find channel link "+
|
err := errors.Errorf("unable to find channel link "+
|
||||||
|
@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||||
circuit.OutgoingChanID)
|
circuit.OutgoingChanID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
packet.incomingChanID = circuit.IncomingChanID
|
||||||
|
packet.incomingHTLCID = circuit.IncomingHTLCID
|
||||||
|
|
||||||
|
// A blank IncomingChanID in a circuit indicates that it is a
|
||||||
|
// pending user-initiated payment.
|
||||||
|
if circuit.IncomingChanID == (lnwire.ShortChannelID{}) {
|
||||||
|
return s.handleLocalDispatch(packet)
|
||||||
|
}
|
||||||
|
|
||||||
// Obfuscate the error message for fail updates before sending back
|
// Obfuscate the error message for fail updates before sending back
|
||||||
// through the circuit.
|
// through the circuit.
|
||||||
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
|
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
|
||||||
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
|
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
|
||||||
htlc.Reason)
|
htlc.Reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
packet.incomingChanID = circuit.IncomingChanID
|
|
||||||
packet.incomingHTLCID = circuit.IncomingHTLCID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
source, err := s.getLinkByShortID(packet.incomingChanID)
|
source, err := s.getLinkByShortID(packet.incomingChanID)
|
||||||
|
@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() {
|
||||||
// packet concretely, then either forward it along, or
|
// packet concretely, then either forward it along, or
|
||||||
// interpret a return packet to a locally initialized one.
|
// interpret a return packet to a locally initialized one.
|
||||||
case cmd := <-s.htlcPlex:
|
case cmd := <-s.htlcPlex:
|
||||||
var (
|
cmd.err <- s.handlePacketForward(cmd.pkt)
|
||||||
paymentHash lnwallet.PaymentHash
|
|
||||||
amount lnwire.MilliSatoshi
|
|
||||||
)
|
|
||||||
|
|
||||||
// Only three types of message should be forwarded:
|
|
||||||
// add, fails, and settles. Anything else is an error.
|
|
||||||
switch m := cmd.pkt.htlc.(type) {
|
|
||||||
case *lnwire.UpdateAddHTLC:
|
|
||||||
paymentHash = m.PaymentHash
|
|
||||||
amount = m.Amount
|
|
||||||
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
|
|
||||||
paymentHash = cmd.pkt.payHash
|
|
||||||
amount = cmd.pkt.amount
|
|
||||||
default:
|
|
||||||
cmd.err <- errors.New("wrong type of update")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we can locate this packet in our local records,
|
|
||||||
// then this means a local sub-system initiated it.
|
|
||||||
// Otherwise, this is just a packet to be forwarded, so
|
|
||||||
// we'll treat it as so.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): can fast path this
|
|
||||||
payment, err := s.findPayment(amount, paymentHash)
|
|
||||||
if err != nil {
|
|
||||||
cmd.err <- s.handlePacketForward(cmd.pkt)
|
|
||||||
} else {
|
|
||||||
cmd.err <- s.handleLocalDispatch(payment, cmd.pkt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The log ticker has fired, so we'll calculate some forwarding
|
// The log ticker has fired, so we'll calculate some forwarding
|
||||||
// stats for the last 10 seconds to display within the logs to
|
// stats for the last 10 seconds to display within the logs to
|
||||||
|
@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
|
||||||
|
|
||||||
// removePendingPayment is the helper function which removes the pending user
|
// removePendingPayment is the helper function which removes the pending user
|
||||||
// payment.
|
// payment.
|
||||||
func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi,
|
func (s *Switch) removePendingPayment(paymentID uint64) error {
|
||||||
hash lnwallet.PaymentHash) error {
|
|
||||||
|
|
||||||
s.pendingMutex.Lock()
|
s.pendingMutex.Lock()
|
||||||
defer s.pendingMutex.Unlock()
|
defer s.pendingMutex.Unlock()
|
||||||
|
|
||||||
payments, ok := s.pendingPayments[hash]
|
if _, ok := s.pendingPayments[paymentID]; !ok {
|
||||||
if ok {
|
return errors.Errorf("Cannot find pending payment with ID %d",
|
||||||
for i, payment := range payments {
|
paymentID)
|
||||||
if payment.amount == amount {
|
|
||||||
// Delete without preserving order
|
|
||||||
// Google: Golang slice tricks
|
|
||||||
payments[i] = payments[len(payments)-1]
|
|
||||||
payments[len(payments)-1] = nil
|
|
||||||
s.pendingPayments[hash] = payments[:len(payments)-1]
|
|
||||||
|
|
||||||
if len(s.pendingPayments[hash]) == 0 {
|
|
||||||
delete(s.pendingPayments, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.Errorf("unable to remove pending payment with "+
|
delete(s.pendingPayments, paymentID)
|
||||||
"hash(%v) and amount(%v)", hash, amount)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findPayment is the helper function which find the payment.
|
// findPayment is the helper function which find the payment.
|
||||||
func (s *Switch) findPayment(amount lnwire.MilliSatoshi,
|
func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
|
||||||
hash lnwallet.PaymentHash) (*pendingPayment, error) {
|
|
||||||
|
|
||||||
s.pendingMutex.RLock()
|
s.pendingMutex.RLock()
|
||||||
defer s.pendingMutex.RUnlock()
|
defer s.pendingMutex.RUnlock()
|
||||||
|
|
||||||
payments, ok := s.pendingPayments[hash]
|
payment, ok := s.pendingPayments[paymentID]
|
||||||
if ok {
|
if !ok {
|
||||||
for _, payment := range payments {
|
return nil, errors.Errorf("Cannot find pending payment with ID %d",
|
||||||
if payment.amount == amount {
|
paymentID)
|
||||||
return payment, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return payment, nil
|
||||||
return nil, errors.Errorf("unable to remove pending payment with "+
|
|
||||||
"hash(%v) and amount(%v)", hash, amount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// numPendingPayments is helper function which returns the overall number of
|
// numPendingPayments is helper function which returns the overall number of
|
||||||
// pending user payments.
|
// pending user payments.
|
||||||
func (s *Switch) numPendingPayments() int {
|
func (s *Switch) numPendingPayments() int {
|
||||||
var l int
|
return len(s.pendingPayments)
|
||||||
for _, payments := range s.pendingPayments {
|
|
||||||
l += len(payments)
|
|
||||||
}
|
|
||||||
|
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addCircuit adds a circuit to the switch's in-memory mapping.
|
// addCircuit adds a circuit to the switch's in-memory mapping.
|
||||||
|
|
Loading…
Add table
Reference in a new issue