diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d4e9518c8..425f99786 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1104,185 +1104,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // payment circuit within our internal state so we can properly forward // the ultimate settle message back latter. case *lnwire.UpdateAddHTLC: - // Check if the node is set to reject all onward HTLCs and also make - // sure that HTLC is not from the source node. - if s.cfg.RejectHTLC { - failure := NewDetailedLinkError( - &lnwire.FailChannelDisabled{}, - OutgoingFailureForwardsDisabled, - ) - - return s.failAddPacket(packet, failure) - } - - // Before we attempt to find a non-strict forwarding path for - // this htlc, check whether the htlc is being routed over the - // same incoming and outgoing channel. If our node does not - // allow forwards of this nature, we fail the htlc early. This - // check is in place to disallow inefficiently routed htlcs from - // locking up our balance. With channels where the - // option-scid-alias feature was negotiated, we also have to be - // sure that the IDs aren't the same since one or both could be - // an alias. - linkErr := s.checkCircularForward( - packet.incomingChanID, packet.outgoingChanID, - s.cfg.AllowCircularRoute, htlc.PaymentHash, - ) - if linkErr != nil { - return s.failAddPacket(packet, linkErr) - } - - s.indexMtx.RLock() - targetLink, err := s.getLinkByMapping(packet) - if err != nil { - s.indexMtx.RUnlock() - - log.Debugf("unable to find link with "+ - "destination %v", packet.outgoingChanID) - - // If packet was forwarded from another channel link - // than we should notify this link that some error - // occurred. - linkError := NewLinkError( - &lnwire.FailUnknownNextPeer{}, - ) - - return s.failAddPacket(packet, linkError) - } - targetPeerKey := targetLink.PeerPubKey() - interfaceLinks, _ := s.getLinks(targetPeerKey) - s.indexMtx.RUnlock() - - // We'll keep track of any HTLC failures during the link - // selection process. This way we can return the error for - // precise link that the sender selected, while optimistically - // trying all links to utilize our available bandwidth. - linkErrs := make(map[lnwire.ShortChannelID]*LinkError) - - // Find all destination channel links with appropriate - // bandwidth. - var destinations []ChannelLink - for _, link := range interfaceLinks { - var failure *LinkError - - // We'll skip any links that aren't yet eligible for - // forwarding. - if !link.EligibleToForward() { - failure = NewDetailedLinkError( - &lnwire.FailUnknownNextPeer{}, - OutgoingFailureLinkNotEligible, - ) - } else { - // We'll ensure that the HTLC satisfies the - // current forwarding conditions of this target - // link. - currentHeight := atomic.LoadUint32(&s.bestHeight) - failure = link.CheckHtlcForward( - htlc.PaymentHash, packet.incomingAmount, - packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, - packet.inboundFee, - currentHeight, - packet.originalOutgoingChanID, - ) - } - - // If this link can forward the htlc, add it to the set - // of destinations. - if failure == nil { - destinations = append(destinations, link) - continue - } - - linkErrs[link.ShortChanID()] = failure - } - - // If we had a forwarding failure due to the HTLC not - // satisfying the current policy, then we'll send back an - // error, but ensure we send back the error sourced at the - // *target* link. - if len(destinations) == 0 { - // At this point, some or all of the links rejected the - // HTLC so we couldn't forward it. So we'll try to look - // up the error that came from the source. - linkErr, ok := linkErrs[packet.outgoingChanID] - if !ok { - // If we can't find the error of the source, - // then we'll return an unknown next peer, - // though this should never happen. - linkErr = NewLinkError( - &lnwire.FailUnknownNextPeer{}, - ) - log.Warnf("unable to find err source for "+ - "outgoing_link=%v, errors=%v", - packet.outgoingChanID, - lnutils.SpewLogClosure(linkErrs)) - } - - log.Tracef("incoming HTLC(%x) violated "+ - "target outgoing link (id=%v) policy: %v", - htlc.PaymentHash[:], packet.outgoingChanID, - linkErr) - - return s.failAddPacket(packet, linkErr) - } - - // Choose a random link out of the set of links that can forward - // this htlc. The reason for randomization is to evenly - // distribute the htlc load without making assumptions about - // what the best channel is. - destination := destinations[rand.Intn(len(destinations))] // nolint:gosec - - // Retrieve the incoming link by its ShortChannelID. Note that - // the incomingChanID is never set to hop.Source here. - s.indexMtx.RLock() - incomingLink, err := s.getLinkByShortID(packet.incomingChanID) - s.indexMtx.RUnlock() - if err != nil { - // If we couldn't find the incoming link, we can't - // evaluate the incoming's exposure to dust, so we just - // fail the HTLC back. - linkErr := NewLinkError( - &lnwire.FailTemporaryChannelFailure{}, - ) - - return s.failAddPacket(packet, linkErr) - } - - // Evaluate whether this HTLC would increase our fee exposure - // over the threshold on the incoming link. If it does, fail it - // backwards. - if s.dustExceedsFeeThreshold( - incomingLink, packet.incomingAmount, true, - ) { - // The incoming dust exceeds the threshold, so we fail - // the add back. - linkErr := NewLinkError( - &lnwire.FailTemporaryChannelFailure{}, - ) - - return s.failAddPacket(packet, linkErr) - } - - // Also evaluate whether this HTLC would increase our fee - // exposure over the threshold on the destination link. If it - // does, fail it back. - if s.dustExceedsFeeThreshold( - destination, packet.amount, false, - ) { - // The outgoing dust exceeds the threshold, so we fail - // the add back. - linkErr := NewLinkError( - &lnwire.FailTemporaryChannelFailure{}, - ) - - return s.failAddPacket(packet, linkErr) - } - - // Send the packet to the destination channel link which - // manages the channel. - packet.outgoingChanID = destination.ShortChanID() - return destination.handleSwitchPacket(packet) + return s.handlePacketAdd(packet, htlc) case *lnwire.UpdateFailHTLC, *lnwire.UpdateFulfillHTLC: // If the source of this packet has not been set, use the @@ -3052,3 +2874,180 @@ func (s *Switch) AddAliasForLink(chanID lnwire.ChannelID, return nil } + +// handlePacketAdd handles forwarding an Add packet. +func (s *Switch) handlePacketAdd(packet *htlcPacket, + htlc *lnwire.UpdateAddHTLC) error { + + // Check if the node is set to reject all onward HTLCs and also make + // sure that HTLC is not from the source node. + if s.cfg.RejectHTLC { + failure := NewDetailedLinkError( + &lnwire.FailChannelDisabled{}, + OutgoingFailureForwardsDisabled, + ) + + return s.failAddPacket(packet, failure) + } + + // Before we attempt to find a non-strict forwarding path for this + // htlc, check whether the htlc is being routed over the same incoming + // and outgoing channel. If our node does not allow forwards of this + // nature, we fail the htlc early. This check is in place to disallow + // inefficiently routed htlcs from locking up our balance. With + // channels where the option-scid-alias feature was negotiated, we also + // have to be sure that the IDs aren't the same since one or both could + // be an alias. + linkErr := s.checkCircularForward( + packet.incomingChanID, packet.outgoingChanID, + s.cfg.AllowCircularRoute, htlc.PaymentHash, + ) + if linkErr != nil { + return s.failAddPacket(packet, linkErr) + } + + s.indexMtx.RLock() + targetLink, err := s.getLinkByMapping(packet) + if err != nil { + s.indexMtx.RUnlock() + + log.Debugf("unable to find link with "+ + "destination %v", packet.outgoingChanID) + + // If packet was forwarded from another channel link than we + // should notify this link that some error occurred. + linkError := NewLinkError( + &lnwire.FailUnknownNextPeer{}, + ) + + return s.failAddPacket(packet, linkError) + } + targetPeerKey := targetLink.PeerPubKey() + interfaceLinks, _ := s.getLinks(targetPeerKey) + s.indexMtx.RUnlock() + + // We'll keep track of any HTLC failures during the link selection + // process. This way we can return the error for precise link that the + // sender selected, while optimistically trying all links to utilize + // our available bandwidth. + linkErrs := make(map[lnwire.ShortChannelID]*LinkError) + + // Find all destination channel links with appropriate bandwidth. + var destinations []ChannelLink + for _, link := range interfaceLinks { + var failure *LinkError + + // We'll skip any links that aren't yet eligible for + // forwarding. + if !link.EligibleToForward() { + failure = NewDetailedLinkError( + &lnwire.FailUnknownNextPeer{}, + OutgoingFailureLinkNotEligible, + ) + } else { + // We'll ensure that the HTLC satisfies the current + // forwarding conditions of this target link. + currentHeight := atomic.LoadUint32(&s.bestHeight) + failure = link.CheckHtlcForward( + htlc.PaymentHash, packet.incomingAmount, + packet.amount, packet.incomingTimeout, + packet.outgoingTimeout, + packet.inboundFee, + currentHeight, + packet.originalOutgoingChanID, + ) + } + + // If this link can forward the htlc, add it to the set of + // destinations. + if failure == nil { + destinations = append(destinations, link) + continue + } + + linkErrs[link.ShortChanID()] = failure + } + + // If we had a forwarding failure due to the HTLC not satisfying the + // current policy, then we'll send back an error, but ensure we send + // back the error sourced at the *target* link. + if len(destinations) == 0 { + // At this point, some or all of the links rejected the HTLC so + // we couldn't forward it. So we'll try to look up the error + // that came from the source. + linkErr, ok := linkErrs[packet.outgoingChanID] + if !ok { + // If we can't find the error of the source, then we'll + // return an unknown next peer, though this should + // never happen. + linkErr = NewLinkError( + &lnwire.FailUnknownNextPeer{}, + ) + log.Warnf("unable to find err source for "+ + "outgoing_link=%v, errors=%v", + packet.outgoingChanID, + lnutils.SpewLogClosure(linkErrs)) + } + + log.Tracef("incoming HTLC(%x) violated "+ + "target outgoing link (id=%v) policy: %v", + htlc.PaymentHash[:], packet.outgoingChanID, + linkErr) + + return s.failAddPacket(packet, linkErr) + } + + // Choose a random link out of the set of links that can forward this + // htlc. The reason for randomization is to evenly distribute the htlc + // load without making assumptions about what the best channel is. + destination := destinations[rand.Intn(len(destinations))] // nolint:gosec + + // Retrieve the incoming link by its ShortChannelID. Note that the + // incomingChanID is never set to hop.Source here. + s.indexMtx.RLock() + incomingLink, err := s.getLinkByShortID(packet.incomingChanID) + s.indexMtx.RUnlock() + if err != nil { + // If we couldn't find the incoming link, we can't evaluate the + // incoming's exposure to dust, so we just fail the HTLC back. + linkErr := NewLinkError( + &lnwire.FailTemporaryChannelFailure{}, + ) + + return s.failAddPacket(packet, linkErr) + } + + // Evaluate whether this HTLC would increase our fee exposure over the + // threshold on the incoming link. If it does, fail it backwards. + if s.dustExceedsFeeThreshold( + incomingLink, packet.incomingAmount, true, + ) { + // The incoming dust exceeds the threshold, so we fail the add + // back. + linkErr := NewLinkError( + &lnwire.FailTemporaryChannelFailure{}, + ) + + return s.failAddPacket(packet, linkErr) + } + + // Also evaluate whether this HTLC would increase our fee exposure over + // the threshold on the destination link. If it does, fail it back. + if s.dustExceedsFeeThreshold( + destination, packet.amount, false, + ) { + // The outgoing dust exceeds the threshold, so we fail the add + // back. + linkErr := NewLinkError( + &lnwire.FailTemporaryChannelFailure{}, + ) + + return s.failAddPacket(packet, linkErr) + } + + // Send the packet to the destination channel link which manages the + // channel. + packet.outgoingChanID = destination.ShortChanID() + + return destination.handleSwitchPacket(packet) +}