mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 18:10:34 +01:00
htlcswitch: add heldHtlcSet
Isolation of the set logic so that it will be easier to add watchdog functionality later.
This commit is contained in:
parent
a6df9567ba
commit
9c063db698
75
htlcswitch/held_htlc_set.go
Normal file
75
htlcswitch/held_htlc_set.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package htlcswitch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// heldHtlcSet keeps track of outstanding intercepted forwards. It exposes
|
||||||
|
// several methods to manipulate the underlying map structure in a consistent
|
||||||
|
// way.
|
||||||
|
type heldHtlcSet struct {
|
||||||
|
set map[channeldb.CircuitKey]InterceptedForward
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHeldHtlcSet() *heldHtlcSet {
|
||||||
|
return &heldHtlcSet{
|
||||||
|
set: make(map[channeldb.CircuitKey]InterceptedForward),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forEach iterates over all held forwards and calls the given callback for each
|
||||||
|
// of them.
|
||||||
|
func (h *heldHtlcSet) forEach(cb func(InterceptedForward)) {
|
||||||
|
for _, fwd := range h.set {
|
||||||
|
cb(fwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// popAll calls the callback for each forward and removes them from the set.
|
||||||
|
func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) {
|
||||||
|
for _, fwd := range h.set {
|
||||||
|
cb(fwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.set = make(map[channeldb.CircuitKey]InterceptedForward)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pop returns the specified forward and removes it from the set.
|
||||||
|
func (h *heldHtlcSet) pop(key channeldb.CircuitKey) (InterceptedForward, error) {
|
||||||
|
intercepted, ok := h.set[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("fwd %v not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(h.set, key)
|
||||||
|
|
||||||
|
return intercepted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// exists tests whether the specified forward is part of the set.
|
||||||
|
func (h *heldHtlcSet) exists(key channeldb.CircuitKey) bool {
|
||||||
|
_, ok := h.set[key]
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// push adds the specified forward to the set. An error is returned if the
|
||||||
|
// forward exists already.
|
||||||
|
func (h *heldHtlcSet) push(key channeldb.CircuitKey,
|
||||||
|
fwd InterceptedForward) error {
|
||||||
|
|
||||||
|
if fwd == nil {
|
||||||
|
return errors.New("nil fwd pushed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.exists(key) {
|
||||||
|
return errors.New("htlc already exists in set")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.set[key] = fwd
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
80
htlcswitch/held_htlc_set_test.go
Normal file
80
htlcswitch/held_htlc_set_test.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package htlcswitch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeldHtlcSetEmpty(t *testing.T) {
|
||||||
|
set := newHeldHtlcSet()
|
||||||
|
|
||||||
|
// Test operations on an empty set.
|
||||||
|
require.False(t, set.exists(channeldb.CircuitKey{}))
|
||||||
|
|
||||||
|
_, err := set.pop(channeldb.CircuitKey{})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
set.popAll(
|
||||||
|
func(_ InterceptedForward) {
|
||||||
|
require.Fail(t, "unexpected fwd")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeldHtlcSet(t *testing.T) {
|
||||||
|
set := newHeldHtlcSet()
|
||||||
|
|
||||||
|
key := channeldb.CircuitKey{
|
||||||
|
ChanID: lnwire.NewShortChanIDFromInt(1),
|
||||||
|
HtlcID: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test pushing a nil forward.
|
||||||
|
require.Error(t, set.push(key, nil))
|
||||||
|
|
||||||
|
// Test pushing a forward.
|
||||||
|
fwd := &interceptedForward{
|
||||||
|
htlc: &lnwire.UpdateAddHTLC{},
|
||||||
|
}
|
||||||
|
require.NoError(t, set.push(key, fwd))
|
||||||
|
|
||||||
|
// Re-pushing should fail.
|
||||||
|
require.Error(t, set.push(key, fwd))
|
||||||
|
|
||||||
|
// Test popping the fwd.
|
||||||
|
poppedFwd, err := set.pop(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, fwd, poppedFwd)
|
||||||
|
|
||||||
|
_, err = set.pop(key)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Pushing the forward again.
|
||||||
|
require.NoError(t, set.push(key, fwd))
|
||||||
|
|
||||||
|
// Test for each.
|
||||||
|
var cbCalled bool
|
||||||
|
set.forEach(func(_ InterceptedForward) {
|
||||||
|
cbCalled = true
|
||||||
|
|
||||||
|
require.Equal(t, fwd, poppedFwd)
|
||||||
|
})
|
||||||
|
require.True(t, cbCalled)
|
||||||
|
|
||||||
|
// Test popping all forwards.
|
||||||
|
cbCalled = false
|
||||||
|
set.popAll(
|
||||||
|
func(_ InterceptedForward) {
|
||||||
|
cbCalled = true
|
||||||
|
|
||||||
|
require.Equal(t, fwd, poppedFwd)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.True(t, cbCalled)
|
||||||
|
|
||||||
|
_, err = set.pop(key)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
@ -57,8 +57,8 @@ type InterceptableSwitch struct {
|
|||||||
// interceptor is the handler for intercepted packets.
|
// interceptor is the handler for intercepted packets.
|
||||||
interceptor ForwardInterceptor
|
interceptor ForwardInterceptor
|
||||||
|
|
||||||
// holdForwards keeps track of outstanding intercepted forwards.
|
// heldHtlcSet keeps track of outstanding intercepted forwards.
|
||||||
holdForwards map[channeldb.CircuitKey]InterceptedForward
|
heldHtlcSet *heldHtlcSet
|
||||||
|
|
||||||
// cltvRejectDelta defines the number of blocks before the expiry of the
|
// cltvRejectDelta defines the number of blocks before the expiry of the
|
||||||
// htlc where we no longer intercept it and instead cancel it back.
|
// htlc where we no longer intercept it and instead cancel it back.
|
||||||
@ -152,7 +152,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch
|
|||||||
intercepted: make(chan *interceptedPackets),
|
intercepted: make(chan *interceptedPackets),
|
||||||
onchainIntercepted: make(chan InterceptedForward),
|
onchainIntercepted: make(chan InterceptedForward),
|
||||||
interceptorRegistration: make(chan ForwardInterceptor),
|
interceptorRegistration: make(chan ForwardInterceptor),
|
||||||
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
|
heldHtlcSet: newHeldHtlcSet(),
|
||||||
resolutionChan: make(chan *fwdResolution),
|
resolutionChan: make(chan *fwdResolution),
|
||||||
requireInterceptor: cfg.RequireInterceptor,
|
requireInterceptor: cfg.RequireInterceptor,
|
||||||
cltvRejectDelta: cfg.CltvRejectDelta,
|
cltvRejectDelta: cfg.CltvRejectDelta,
|
||||||
@ -231,7 +231,14 @@ func (s *InterceptableSwitch) run() error {
|
|||||||
case packets := <-s.intercepted:
|
case packets := <-s.intercepted:
|
||||||
var notIntercepted []*htlcPacket
|
var notIntercepted []*htlcPacket
|
||||||
for _, p := range packets.packets {
|
for _, p := range packets.packets {
|
||||||
if !s.interceptForward(p, packets.isReplay) {
|
intercepted, err := s.interceptForward(
|
||||||
|
p, packets.isReplay,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !intercepted {
|
||||||
notIntercepted = append(
|
notIntercepted = append(
|
||||||
notIntercepted, p,
|
notIntercepted, p,
|
||||||
)
|
)
|
||||||
@ -252,7 +259,9 @@ func (s *InterceptableSwitch) run() error {
|
|||||||
// already intercepted in the off-chain flow. And even
|
// already intercepted in the off-chain flow. And even
|
||||||
// if not, it is safe to signal replay so that we won't
|
// if not, it is safe to signal replay so that we won't
|
||||||
// unexpectedly skip over this htlc.
|
// unexpectedly skip over this htlc.
|
||||||
s.forward(fwd, true)
|
if _, err := s.forward(fwd, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
case res := <-s.resolutionChan:
|
case res := <-s.resolutionChan:
|
||||||
res.errChan <- s.resolve(res.resolution)
|
res.errChan <- s.resolve(res.resolution)
|
||||||
@ -287,9 +296,7 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
|
|||||||
if interceptor != nil {
|
if interceptor != nil {
|
||||||
log.Debugf("Interceptor connected")
|
log.Debugf("Interceptor connected")
|
||||||
|
|
||||||
for _, fwd := range s.holdForwards {
|
s.heldHtlcSet.forEach(s.sendForward)
|
||||||
s.sendForward(fwd)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -305,20 +312,19 @@ func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
|
|||||||
// Interceptor is not required. Release held forwards.
|
// Interceptor is not required. Release held forwards.
|
||||||
log.Infof("Interceptor disconnected, resolving held packets")
|
log.Infof("Interceptor disconnected, resolving held packets")
|
||||||
|
|
||||||
for _, fwd := range s.holdForwards {
|
s.heldHtlcSet.popAll(func(fwd InterceptedForward) {
|
||||||
if err := fwd.Resume(); err != nil {
|
err := fwd.Resume()
|
||||||
|
if err != nil {
|
||||||
log.Errorf("Failed to resume hold forward %v", err)
|
log.Errorf("Failed to resume hold forward %v", err)
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
|
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
|
||||||
intercepted, ok := s.holdForwards[res.Key]
|
intercepted, err := s.heldHtlcSet.pop(res.Key)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return fmt.Errorf("fwd %v not found", res.Key)
|
return err
|
||||||
}
|
}
|
||||||
delete(s.holdForwards, res.Key)
|
|
||||||
|
|
||||||
switch res.Action {
|
switch res.Action {
|
||||||
case FwdActionResume:
|
case FwdActionResume:
|
||||||
@ -405,13 +411,13 @@ func (s *InterceptableSwitch) ForwardPacket(
|
|||||||
// interceptForward forwards the packet to the external interceptor after
|
// interceptForward forwards the packet to the external interceptor after
|
||||||
// checking the interception criteria.
|
// checking the interception criteria.
|
||||||
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
||||||
isReplay bool) bool {
|
isReplay bool) (bool, error) {
|
||||||
|
|
||||||
switch htlc := packet.htlc.(type) {
|
switch htlc := packet.htlc.(type) {
|
||||||
case *lnwire.UpdateAddHTLC:
|
case *lnwire.UpdateAddHTLC:
|
||||||
// We are not interested in intercepting initiated payments.
|
// We are not interested in intercepting initiated payments.
|
||||||
if packet.incomingChanID == hop.Source {
|
if packet.incomingChanID == hop.Source {
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
intercepted := &interceptedForward{
|
intercepted := &interceptedForward{
|
||||||
@ -435,28 +441,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
|||||||
// will remain stuck and potentially force-close the
|
// will remain stuck and potentially force-close the
|
||||||
// channel. But in the end, we should never get here, so
|
// channel. But in the end, we should never get here, so
|
||||||
// the actual return value doesn't matter that much.
|
// the actual return value doesn't matter that much.
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
if handled {
|
if handled {
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.forward(intercepted, isReplay)
|
return s.forward(intercepted, isReplay)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// forward records the intercepted htlc and forwards it to the interceptor.
|
// forward records the intercepted htlc and forwards it to the interceptor.
|
||||||
func (s *InterceptableSwitch) forward(
|
func (s *InterceptableSwitch) forward(
|
||||||
fwd InterceptedForward, isReplay bool) bool {
|
fwd InterceptedForward, isReplay bool) (bool, error) {
|
||||||
|
|
||||||
inKey := fwd.Packet().IncomingCircuit
|
inKey := fwd.Packet().IncomingCircuit
|
||||||
|
|
||||||
// Ignore already held htlcs.
|
// Ignore already held htlcs.
|
||||||
if _, ok := s.holdForwards[inKey]; ok {
|
if s.heldHtlcSet.exists(inKey) {
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is no interceptor currently registered, configuration and packet
|
// If there is no interceptor currently registered, configuration and packet
|
||||||
@ -464,7 +470,7 @@ func (s *InterceptableSwitch) forward(
|
|||||||
if s.interceptor == nil {
|
if s.interceptor == nil {
|
||||||
// Process normally if an interceptor is not required.
|
// Process normally if an interceptor is not required.
|
||||||
if !s.requireInterceptor {
|
if !s.requireInterceptor {
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are in interceptor-required mode. If this is a new packet, it is
|
// We are in interceptor-required mode. If this is a new packet, it is
|
||||||
@ -478,23 +484,28 @@ func (s *InterceptableSwitch) forward(
|
|||||||
log.Errorf("Cannot fail packet: %v", err)
|
log.Errorf("Cannot fail packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// This packet is a replay. It is not safe to fail back, because the
|
// This packet is a replay. It is not safe to fail back, because the
|
||||||
// interceptor may still signal otherwise upon reconnect. Keep the
|
// interceptor may still signal otherwise upon reconnect. Keep the
|
||||||
// packet in the queue until then.
|
// packet in the queue until then.
|
||||||
s.holdForwards[inKey] = fwd
|
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// There is an interceptor registered. We can forward the packet right now.
|
// There is an interceptor registered. We can forward the packet right now.
|
||||||
// Hold it in the queue too to track what is outstanding.
|
// Hold it in the queue too to track what is outstanding.
|
||||||
s.holdForwards[inKey] = fwd
|
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
s.sendForward(fwd)
|
s.sendForward(fwd)
|
||||||
|
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleExpired checks that the htlc isn't too close to the channel
|
// handleExpired checks that the htlc isn't too close to the channel
|
||||||
|
Loading…
Reference in New Issue
Block a user