diff --git a/htlcswitch/held_htlc_set.go b/htlcswitch/held_htlc_set.go new file mode 100644 index 000000000..f7f0562cb --- /dev/null +++ b/htlcswitch/held_htlc_set.go @@ -0,0 +1,75 @@ +package htlcswitch + +import ( + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/channeldb" +) + +// heldHtlcSet keeps track of outstanding intercepted forwards. It exposes +// several methods to manipulate the underlying map structure in a consistent +// way. +type heldHtlcSet struct { + set map[channeldb.CircuitKey]InterceptedForward +} + +func newHeldHtlcSet() *heldHtlcSet { + return &heldHtlcSet{ + set: make(map[channeldb.CircuitKey]InterceptedForward), + } +} + +// forEach iterates over all held forwards and calls the given callback for each +// of them. +func (h *heldHtlcSet) forEach(cb func(InterceptedForward)) { + for _, fwd := range h.set { + cb(fwd) + } +} + +// popAll calls the callback for each forward and removes them from the set. +func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) { + for _, fwd := range h.set { + cb(fwd) + } + + h.set = make(map[channeldb.CircuitKey]InterceptedForward) +} + +// pop returns the specified forward and removes it from the set. +func (h *heldHtlcSet) pop(key channeldb.CircuitKey) (InterceptedForward, error) { + intercepted, ok := h.set[key] + if !ok { + return nil, fmt.Errorf("fwd %v not found", key) + } + + delete(h.set, key) + + return intercepted, nil +} + +// exists tests whether the specified forward is part of the set. +func (h *heldHtlcSet) exists(key channeldb.CircuitKey) bool { + _, ok := h.set[key] + + return ok +} + +// push adds the specified forward to the set. An error is returned if the +// forward exists already. +func (h *heldHtlcSet) push(key channeldb.CircuitKey, + fwd InterceptedForward) error { + + if fwd == nil { + return errors.New("nil fwd pushed") + } + + if h.exists(key) { + return errors.New("htlc already exists in set") + } + + h.set[key] = fwd + + return nil +} diff --git a/htlcswitch/held_htlc_set_test.go b/htlcswitch/held_htlc_set_test.go new file mode 100644 index 000000000..6a39fc5fc --- /dev/null +++ b/htlcswitch/held_htlc_set_test.go @@ -0,0 +1,80 @@ +package htlcswitch + +import ( + "testing" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +func TestHeldHtlcSetEmpty(t *testing.T) { + set := newHeldHtlcSet() + + // Test operations on an empty set. + require.False(t, set.exists(channeldb.CircuitKey{})) + + _, err := set.pop(channeldb.CircuitKey{}) + require.Error(t, err) + + set.popAll( + func(_ InterceptedForward) { + require.Fail(t, "unexpected fwd") + }, + ) +} + +func TestHeldHtlcSet(t *testing.T) { + set := newHeldHtlcSet() + + key := channeldb.CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt(1), + HtlcID: 2, + } + + // Test pushing a nil forward. + require.Error(t, set.push(key, nil)) + + // Test pushing a forward. + fwd := &interceptedForward{ + htlc: &lnwire.UpdateAddHTLC{}, + } + require.NoError(t, set.push(key, fwd)) + + // Re-pushing should fail. + require.Error(t, set.push(key, fwd)) + + // Test popping the fwd. + poppedFwd, err := set.pop(key) + require.NoError(t, err) + require.Equal(t, fwd, poppedFwd) + + _, err = set.pop(key) + require.Error(t, err) + + // Pushing the forward again. + require.NoError(t, set.push(key, fwd)) + + // Test for each. + var cbCalled bool + set.forEach(func(_ InterceptedForward) { + cbCalled = true + + require.Equal(t, fwd, poppedFwd) + }) + require.True(t, cbCalled) + + // Test popping all forwards. + cbCalled = false + set.popAll( + func(_ InterceptedForward) { + cbCalled = true + + require.Equal(t, fwd, poppedFwd) + }, + ) + require.True(t, cbCalled) + + _, err = set.pop(key) + require.Error(t, err) +} diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 15b6a8799..9675c3a4b 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -57,8 +57,8 @@ type InterceptableSwitch struct { // interceptor is the handler for intercepted packets. interceptor ForwardInterceptor - // holdForwards keeps track of outstanding intercepted forwards. - holdForwards map[channeldb.CircuitKey]InterceptedForward + // heldHtlcSet keeps track of outstanding intercepted forwards. + heldHtlcSet *heldHtlcSet // cltvRejectDelta defines the number of blocks before the expiry of the // htlc where we no longer intercept it and instead cancel it back. @@ -152,7 +152,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch intercepted: make(chan *interceptedPackets), onchainIntercepted: make(chan InterceptedForward), interceptorRegistration: make(chan ForwardInterceptor), - holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), + heldHtlcSet: newHeldHtlcSet(), resolutionChan: make(chan *fwdResolution), requireInterceptor: cfg.RequireInterceptor, cltvRejectDelta: cfg.CltvRejectDelta, @@ -231,7 +231,14 @@ func (s *InterceptableSwitch) run() error { case packets := <-s.intercepted: var notIntercepted []*htlcPacket for _, p := range packets.packets { - if !s.interceptForward(p, packets.isReplay) { + intercepted, err := s.interceptForward( + p, packets.isReplay, + ) + if err != nil { + return err + } + + if !intercepted { notIntercepted = append( notIntercepted, p, ) @@ -252,7 +259,9 @@ func (s *InterceptableSwitch) run() error { // already intercepted in the off-chain flow. And even // if not, it is safe to signal replay so that we won't // unexpectedly skip over this htlc. - s.forward(fwd, true) + if _, err := s.forward(fwd, true); err != nil { + return err + } case res := <-s.resolutionChan: res.errChan <- s.resolve(res.resolution) @@ -287,9 +296,7 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { if interceptor != nil { log.Debugf("Interceptor connected") - for _, fwd := range s.holdForwards { - s.sendForward(fwd) - } + s.heldHtlcSet.forEach(s.sendForward) return } @@ -305,20 +312,19 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { // Interceptor is not required. Release held forwards. log.Infof("Interceptor disconnected, resolving held packets") - for _, fwd := range s.holdForwards { - if err := fwd.Resume(); err != nil { + s.heldHtlcSet.popAll(func(fwd InterceptedForward) { + err := fwd.Resume() + if err != nil { log.Errorf("Failed to resume hold forward %v", err) } - } - s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward) + }) } func (s *InterceptableSwitch) resolve(res *FwdResolution) error { - intercepted, ok := s.holdForwards[res.Key] - if !ok { - return fmt.Errorf("fwd %v not found", res.Key) + intercepted, err := s.heldHtlcSet.pop(res.Key) + if err != nil { + return err } - delete(s.holdForwards, res.Key) switch res.Action { case FwdActionResume: @@ -405,13 +411,13 @@ func (s *InterceptableSwitch) ForwardPacket( // interceptForward forwards the packet to the external interceptor after // checking the interception criteria. func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, - isReplay bool) bool { + isReplay bool) (bool, error) { switch htlc := packet.htlc.(type) { case *lnwire.UpdateAddHTLC: // We are not interested in intercepting initiated payments. if packet.incomingChanID == hop.Source { - return false + return false, nil } intercepted := &interceptedForward{ @@ -435,28 +441,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, // will remain stuck and potentially force-close the // channel. But in the end, we should never get here, so // the actual return value doesn't matter that much. - return false + return false, nil } if handled { - return true + return true, nil } return s.forward(intercepted, isReplay) default: - return false + return false, nil } } // forward records the intercepted htlc and forwards it to the interceptor. func (s *InterceptableSwitch) forward( - fwd InterceptedForward, isReplay bool) bool { + fwd InterceptedForward, isReplay bool) (bool, error) { inKey := fwd.Packet().IncomingCircuit // Ignore already held htlcs. - if _, ok := s.holdForwards[inKey]; ok { - return true + if s.heldHtlcSet.exists(inKey) { + return true, nil } // If there is no interceptor currently registered, configuration and packet @@ -464,7 +470,7 @@ func (s *InterceptableSwitch) forward( if s.interceptor == nil { // Process normally if an interceptor is not required. if !s.requireInterceptor { - return false + return false, nil } // We are in interceptor-required mode. If this is a new packet, it is @@ -478,23 +484,28 @@ func (s *InterceptableSwitch) forward( log.Errorf("Cannot fail packet: %v", err) } - return true + return true, nil } // This packet is a replay. It is not safe to fail back, because the // interceptor may still signal otherwise upon reconnect. Keep the // packet in the queue until then. - s.holdForwards[inKey] = fwd + if err := s.heldHtlcSet.push(inKey, fwd); err != nil { + return false, err + } - return true + return true, nil } // There is an interceptor registered. We can forward the packet right now. // Hold it in the queue too to track what is outstanding. - s.holdForwards[inKey] = fwd + if err := s.heldHtlcSet.push(inKey, fwd); err != nil { + return false, err + } + s.sendForward(fwd) - return true + return true, nil } // handleExpired checks that the htlc isn't too close to the channel