mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 09:53:54 +01:00
htlcswitch: add heldHtlcSet
Isolation of the set logic so that it will be easier to add watchdog functionality later.
This commit is contained in:
parent
a6df9567ba
commit
9c063db698
75
htlcswitch/held_htlc_set.go
Normal file
75
htlcswitch/held_htlc_set.go
Normal file
@ -0,0 +1,75 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
// heldHtlcSet keeps track of outstanding intercepted forwards. It exposes
|
||||
// several methods to manipulate the underlying map structure in a consistent
|
||||
// way.
|
||||
type heldHtlcSet struct {
|
||||
set map[channeldb.CircuitKey]InterceptedForward
|
||||
}
|
||||
|
||||
func newHeldHtlcSet() *heldHtlcSet {
|
||||
return &heldHtlcSet{
|
||||
set: make(map[channeldb.CircuitKey]InterceptedForward),
|
||||
}
|
||||
}
|
||||
|
||||
// forEach iterates over all held forwards and calls the given callback for each
|
||||
// of them.
|
||||
func (h *heldHtlcSet) forEach(cb func(InterceptedForward)) {
|
||||
for _, fwd := range h.set {
|
||||
cb(fwd)
|
||||
}
|
||||
}
|
||||
|
||||
// popAll calls the callback for each forward and removes them from the set.
|
||||
func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) {
|
||||
for _, fwd := range h.set {
|
||||
cb(fwd)
|
||||
}
|
||||
|
||||
h.set = make(map[channeldb.CircuitKey]InterceptedForward)
|
||||
}
|
||||
|
||||
// pop returns the specified forward and removes it from the set.
|
||||
func (h *heldHtlcSet) pop(key channeldb.CircuitKey) (InterceptedForward, error) {
|
||||
intercepted, ok := h.set[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("fwd %v not found", key)
|
||||
}
|
||||
|
||||
delete(h.set, key)
|
||||
|
||||
return intercepted, nil
|
||||
}
|
||||
|
||||
// exists tests whether the specified forward is part of the set.
|
||||
func (h *heldHtlcSet) exists(key channeldb.CircuitKey) bool {
|
||||
_, ok := h.set[key]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// push adds the specified forward to the set. An error is returned if the
|
||||
// forward exists already.
|
||||
func (h *heldHtlcSet) push(key channeldb.CircuitKey,
|
||||
fwd InterceptedForward) error {
|
||||
|
||||
if fwd == nil {
|
||||
return errors.New("nil fwd pushed")
|
||||
}
|
||||
|
||||
if h.exists(key) {
|
||||
return errors.New("htlc already exists in set")
|
||||
}
|
||||
|
||||
h.set[key] = fwd
|
||||
|
||||
return nil
|
||||
}
|
80
htlcswitch/held_htlc_set_test.go
Normal file
80
htlcswitch/held_htlc_set_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHeldHtlcSetEmpty(t *testing.T) {
|
||||
set := newHeldHtlcSet()
|
||||
|
||||
// Test operations on an empty set.
|
||||
require.False(t, set.exists(channeldb.CircuitKey{}))
|
||||
|
||||
_, err := set.pop(channeldb.CircuitKey{})
|
||||
require.Error(t, err)
|
||||
|
||||
set.popAll(
|
||||
func(_ InterceptedForward) {
|
||||
require.Fail(t, "unexpected fwd")
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestHeldHtlcSet(t *testing.T) {
|
||||
set := newHeldHtlcSet()
|
||||
|
||||
key := channeldb.CircuitKey{
|
||||
ChanID: lnwire.NewShortChanIDFromInt(1),
|
||||
HtlcID: 2,
|
||||
}
|
||||
|
||||
// Test pushing a nil forward.
|
||||
require.Error(t, set.push(key, nil))
|
||||
|
||||
// Test pushing a forward.
|
||||
fwd := &interceptedForward{
|
||||
htlc: &lnwire.UpdateAddHTLC{},
|
||||
}
|
||||
require.NoError(t, set.push(key, fwd))
|
||||
|
||||
// Re-pushing should fail.
|
||||
require.Error(t, set.push(key, fwd))
|
||||
|
||||
// Test popping the fwd.
|
||||
poppedFwd, err := set.pop(key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fwd, poppedFwd)
|
||||
|
||||
_, err = set.pop(key)
|
||||
require.Error(t, err)
|
||||
|
||||
// Pushing the forward again.
|
||||
require.NoError(t, set.push(key, fwd))
|
||||
|
||||
// Test for each.
|
||||
var cbCalled bool
|
||||
set.forEach(func(_ InterceptedForward) {
|
||||
cbCalled = true
|
||||
|
||||
require.Equal(t, fwd, poppedFwd)
|
||||
})
|
||||
require.True(t, cbCalled)
|
||||
|
||||
// Test popping all forwards.
|
||||
cbCalled = false
|
||||
set.popAll(
|
||||
func(_ InterceptedForward) {
|
||||
cbCalled = true
|
||||
|
||||
require.Equal(t, fwd, poppedFwd)
|
||||
},
|
||||
)
|
||||
require.True(t, cbCalled)
|
||||
|
||||
_, err = set.pop(key)
|
||||
require.Error(t, err)
|
||||
}
|
@ -57,8 +57,8 @@ type InterceptableSwitch struct {
|
||||
// interceptor is the handler for intercepted packets.
|
||||
interceptor ForwardInterceptor
|
||||
|
||||
// holdForwards keeps track of outstanding intercepted forwards.
|
||||
holdForwards map[channeldb.CircuitKey]InterceptedForward
|
||||
// heldHtlcSet keeps track of outstanding intercepted forwards.
|
||||
heldHtlcSet *heldHtlcSet
|
||||
|
||||
// cltvRejectDelta defines the number of blocks before the expiry of the
|
||||
// htlc where we no longer intercept it and instead cancel it back.
|
||||
@ -152,7 +152,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch
|
||||
intercepted: make(chan *interceptedPackets),
|
||||
onchainIntercepted: make(chan InterceptedForward),
|
||||
interceptorRegistration: make(chan ForwardInterceptor),
|
||||
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
|
||||
heldHtlcSet: newHeldHtlcSet(),
|
||||
resolutionChan: make(chan *fwdResolution),
|
||||
requireInterceptor: cfg.RequireInterceptor,
|
||||
cltvRejectDelta: cfg.CltvRejectDelta,
|
||||
@ -231,7 +231,14 @@ func (s *InterceptableSwitch) run() error {
|
||||
case packets := <-s.intercepted:
|
||||
var notIntercepted []*htlcPacket
|
||||
for _, p := range packets.packets {
|
||||
if !s.interceptForward(p, packets.isReplay) {
|
||||
intercepted, err := s.interceptForward(
|
||||
p, packets.isReplay,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !intercepted {
|
||||
notIntercepted = append(
|
||||
notIntercepted, p,
|
||||
)
|
||||
@ -252,7 +259,9 @@ func (s *InterceptableSwitch) run() error {
|
||||
// already intercepted in the off-chain flow. And even
|
||||
// if not, it is safe to signal replay so that we won't
|
||||
// unexpectedly skip over this htlc.
|
||||
s.forward(fwd, true)
|
||||
if _, err := s.forward(fwd, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case res := <-s.resolutionChan:
|
||||
res.errChan <- s.resolve(res.resolution)
|
||||
@ -287,9 +296,7 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
|
||||
if interceptor != nil {
|
||||
log.Debugf("Interceptor connected")
|
||||
|
||||
for _, fwd := range s.holdForwards {
|
||||
s.sendForward(fwd)
|
||||
}
|
||||
s.heldHtlcSet.forEach(s.sendForward)
|
||||
|
||||
return
|
||||
}
|
||||
@ -305,20 +312,19 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
|
||||
// Interceptor is not required. Release held forwards.
|
||||
log.Infof("Interceptor disconnected, resolving held packets")
|
||||
|
||||
for _, fwd := range s.holdForwards {
|
||||
if err := fwd.Resume(); err != nil {
|
||||
s.heldHtlcSet.popAll(func(fwd InterceptedForward) {
|
||||
err := fwd.Resume()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to resume hold forward %v", err)
|
||||
}
|
||||
}
|
||||
s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
|
||||
intercepted, ok := s.holdForwards[res.Key]
|
||||
if !ok {
|
||||
return fmt.Errorf("fwd %v not found", res.Key)
|
||||
intercepted, err := s.heldHtlcSet.pop(res.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(s.holdForwards, res.Key)
|
||||
|
||||
switch res.Action {
|
||||
case FwdActionResume:
|
||||
@ -405,13 +411,13 @@ func (s *InterceptableSwitch) ForwardPacket(
|
||||
// interceptForward forwards the packet to the external interceptor after
|
||||
// checking the interception criteria.
|
||||
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
||||
isReplay bool) bool {
|
||||
isReplay bool) (bool, error) {
|
||||
|
||||
switch htlc := packet.htlc.(type) {
|
||||
case *lnwire.UpdateAddHTLC:
|
||||
// We are not interested in intercepting initiated payments.
|
||||
if packet.incomingChanID == hop.Source {
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
intercepted := &interceptedForward{
|
||||
@ -435,28 +441,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
||||
// will remain stuck and potentially force-close the
|
||||
// channel. But in the end, we should never get here, so
|
||||
// the actual return value doesn't matter that much.
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
if handled {
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return s.forward(intercepted, isReplay)
|
||||
|
||||
default:
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// forward records the intercepted htlc and forwards it to the interceptor.
|
||||
func (s *InterceptableSwitch) forward(
|
||||
fwd InterceptedForward, isReplay bool) bool {
|
||||
fwd InterceptedForward, isReplay bool) (bool, error) {
|
||||
|
||||
inKey := fwd.Packet().IncomingCircuit
|
||||
|
||||
// Ignore already held htlcs.
|
||||
if _, ok := s.holdForwards[inKey]; ok {
|
||||
return true
|
||||
if s.heldHtlcSet.exists(inKey) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// If there is no interceptor currently registered, configuration and packet
|
||||
@ -464,7 +470,7 @@ func (s *InterceptableSwitch) forward(
|
||||
if s.interceptor == nil {
|
||||
// Process normally if an interceptor is not required.
|
||||
if !s.requireInterceptor {
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// We are in interceptor-required mode. If this is a new packet, it is
|
||||
@ -478,23 +484,28 @@ func (s *InterceptableSwitch) forward(
|
||||
log.Errorf("Cannot fail packet: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// This packet is a replay. It is not safe to fail back, because the
|
||||
// interceptor may still signal otherwise upon reconnect. Keep the
|
||||
// packet in the queue until then.
|
||||
s.holdForwards[inKey] = fwd
|
||||
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// There is an interceptor registered. We can forward the packet right now.
|
||||
// Hold it in the queue too to track what is outstanding.
|
||||
s.holdForwards[inKey] = fwd
|
||||
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
s.sendForward(fwd)
|
||||
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleExpired checks that the htlc isn't too close to the channel
|
||||
|
Loading…
Reference in New Issue
Block a user