htlcswitch: add heldHtlcSet

Isolation of the set logic so that it will be easier to add watchdog functionality later.
This commit is contained in:
Joost Jager 2022-08-15 16:24:38 +02:00
parent a6df9567ba
commit 9c063db698
No known key found for this signature in database
GPG Key ID: B9A26449A5528325
3 changed files with 196 additions and 30 deletions

View File

@ -0,0 +1,75 @@
package htlcswitch
import (
"errors"
"fmt"
"github.com/lightningnetwork/lnd/channeldb"
)
// heldHtlcSet keeps track of outstanding intercepted forwards. It exposes
// several methods to manipulate the underlying map structure in a consistent
// way.
type heldHtlcSet struct {
set map[channeldb.CircuitKey]InterceptedForward
}
func newHeldHtlcSet() *heldHtlcSet {
return &heldHtlcSet{
set: make(map[channeldb.CircuitKey]InterceptedForward),
}
}
// forEach iterates over all held forwards and calls the given callback for each
// of them.
func (h *heldHtlcSet) forEach(cb func(InterceptedForward)) {
for _, fwd := range h.set {
cb(fwd)
}
}
// popAll calls the callback for each forward and removes them from the set.
func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) {
for _, fwd := range h.set {
cb(fwd)
}
h.set = make(map[channeldb.CircuitKey]InterceptedForward)
}
// pop returns the specified forward and removes it from the set.
func (h *heldHtlcSet) pop(key channeldb.CircuitKey) (InterceptedForward, error) {
intercepted, ok := h.set[key]
if !ok {
return nil, fmt.Errorf("fwd %v not found", key)
}
delete(h.set, key)
return intercepted, nil
}
// exists tests whether the specified forward is part of the set.
func (h *heldHtlcSet) exists(key channeldb.CircuitKey) bool {
_, ok := h.set[key]
return ok
}
// push adds the specified forward to the set. An error is returned if the
// forward exists already.
func (h *heldHtlcSet) push(key channeldb.CircuitKey,
fwd InterceptedForward) error {
if fwd == nil {
return errors.New("nil fwd pushed")
}
if h.exists(key) {
return errors.New("htlc already exists in set")
}
h.set[key] = fwd
return nil
}

View File

@ -0,0 +1,80 @@
package htlcswitch
import (
"testing"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
func TestHeldHtlcSetEmpty(t *testing.T) {
set := newHeldHtlcSet()
// Test operations on an empty set.
require.False(t, set.exists(channeldb.CircuitKey{}))
_, err := set.pop(channeldb.CircuitKey{})
require.Error(t, err)
set.popAll(
func(_ InterceptedForward) {
require.Fail(t, "unexpected fwd")
},
)
}
func TestHeldHtlcSet(t *testing.T) {
set := newHeldHtlcSet()
key := channeldb.CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(1),
HtlcID: 2,
}
// Test pushing a nil forward.
require.Error(t, set.push(key, nil))
// Test pushing a forward.
fwd := &interceptedForward{
htlc: &lnwire.UpdateAddHTLC{},
}
require.NoError(t, set.push(key, fwd))
// Re-pushing should fail.
require.Error(t, set.push(key, fwd))
// Test popping the fwd.
poppedFwd, err := set.pop(key)
require.NoError(t, err)
require.Equal(t, fwd, poppedFwd)
_, err = set.pop(key)
require.Error(t, err)
// Pushing the forward again.
require.NoError(t, set.push(key, fwd))
// Test for each.
var cbCalled bool
set.forEach(func(_ InterceptedForward) {
cbCalled = true
require.Equal(t, fwd, poppedFwd)
})
require.True(t, cbCalled)
// Test popping all forwards.
cbCalled = false
set.popAll(
func(_ InterceptedForward) {
cbCalled = true
require.Equal(t, fwd, poppedFwd)
},
)
require.True(t, cbCalled)
_, err = set.pop(key)
require.Error(t, err)
}

View File

@ -57,8 +57,8 @@ type InterceptableSwitch struct {
// interceptor is the handler for intercepted packets.
interceptor ForwardInterceptor
// holdForwards keeps track of outstanding intercepted forwards.
holdForwards map[channeldb.CircuitKey]InterceptedForward
// heldHtlcSet keeps track of outstanding intercepted forwards.
heldHtlcSet *heldHtlcSet
// cltvRejectDelta defines the number of blocks before the expiry of the
// htlc where we no longer intercept it and instead cancel it back.
@ -152,7 +152,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch
intercepted: make(chan *interceptedPackets),
onchainIntercepted: make(chan InterceptedForward),
interceptorRegistration: make(chan ForwardInterceptor),
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
heldHtlcSet: newHeldHtlcSet(),
resolutionChan: make(chan *fwdResolution),
requireInterceptor: cfg.RequireInterceptor,
cltvRejectDelta: cfg.CltvRejectDelta,
@ -231,7 +231,14 @@ func (s *InterceptableSwitch) run() error {
case packets := <-s.intercepted:
var notIntercepted []*htlcPacket
for _, p := range packets.packets {
if !s.interceptForward(p, packets.isReplay) {
intercepted, err := s.interceptForward(
p, packets.isReplay,
)
if err != nil {
return err
}
if !intercepted {
notIntercepted = append(
notIntercepted, p,
)
@ -252,7 +259,9 @@ func (s *InterceptableSwitch) run() error {
// already intercepted in the off-chain flow. And even
// if not, it is safe to signal replay so that we won't
// unexpectedly skip over this htlc.
s.forward(fwd, true)
if _, err := s.forward(fwd, true); err != nil {
return err
}
case res := <-s.resolutionChan:
res.errChan <- s.resolve(res.resolution)
@ -287,9 +296,7 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
if interceptor != nil {
log.Debugf("Interceptor connected")
for _, fwd := range s.holdForwards {
s.sendForward(fwd)
}
s.heldHtlcSet.forEach(s.sendForward)
return
}
@ -305,20 +312,19 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
// Interceptor is not required. Release held forwards.
log.Infof("Interceptor disconnected, resolving held packets")
for _, fwd := range s.holdForwards {
if err := fwd.Resume(); err != nil {
s.heldHtlcSet.popAll(func(fwd InterceptedForward) {
err := fwd.Resume()
if err != nil {
log.Errorf("Failed to resume hold forward %v", err)
}
}
s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward)
})
}
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
intercepted, ok := s.holdForwards[res.Key]
if !ok {
return fmt.Errorf("fwd %v not found", res.Key)
intercepted, err := s.heldHtlcSet.pop(res.Key)
if err != nil {
return err
}
delete(s.holdForwards, res.Key)
switch res.Action {
case FwdActionResume:
@ -405,13 +411,13 @@ func (s *InterceptableSwitch) ForwardPacket(
// interceptForward forwards the packet to the external interceptor after
// checking the interception criteria.
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
isReplay bool) bool {
isReplay bool) (bool, error) {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// We are not interested in intercepting initiated payments.
if packet.incomingChanID == hop.Source {
return false
return false, nil
}
intercepted := &interceptedForward{
@ -435,28 +441,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
// will remain stuck and potentially force-close the
// channel. But in the end, we should never get here, so
// the actual return value doesn't matter that much.
return false
return false, nil
}
if handled {
return true
return true, nil
}
return s.forward(intercepted, isReplay)
default:
return false
return false, nil
}
}
// forward records the intercepted htlc and forwards it to the interceptor.
func (s *InterceptableSwitch) forward(
fwd InterceptedForward, isReplay bool) bool {
fwd InterceptedForward, isReplay bool) (bool, error) {
inKey := fwd.Packet().IncomingCircuit
// Ignore already held htlcs.
if _, ok := s.holdForwards[inKey]; ok {
return true
if s.heldHtlcSet.exists(inKey) {
return true, nil
}
// If there is no interceptor currently registered, configuration and packet
@ -464,7 +470,7 @@ func (s *InterceptableSwitch) forward(
if s.interceptor == nil {
// Process normally if an interceptor is not required.
if !s.requireInterceptor {
return false
return false, nil
}
// We are in interceptor-required mode. If this is a new packet, it is
@ -478,23 +484,28 @@ func (s *InterceptableSwitch) forward(
log.Errorf("Cannot fail packet: %v", err)
}
return true
return true, nil
}
// This packet is a replay. It is not safe to fail back, because the
// interceptor may still signal otherwise upon reconnect. Keep the
// packet in the queue until then.
s.holdForwards[inKey] = fwd
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
return false, err
}
return true
return true, nil
}
// There is an interceptor registered. We can forward the packet right now.
// Hold it in the queue too to track what is outstanding.
s.holdForwards[inKey] = fwd
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
return false, err
}
s.sendForward(fwd)
return true
return true, nil
}
// handleExpired checks that the htlc isn't too close to the channel