mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-21 22:11:41 +01:00
htlcswitch: add receiver-side inbound fee support
This commit is contained in:
parent
3e6adbf1c0
commit
e8c97deaef
10 changed files with 388 additions and 19 deletions
|
@ -115,6 +115,9 @@ type ForwardingPolicy struct {
|
||||||
// used to compute the required fee for a given HTLC.
|
// used to compute the required fee for a given HTLC.
|
||||||
FeeRate lnwire.MilliSatoshi
|
FeeRate lnwire.MilliSatoshi
|
||||||
|
|
||||||
|
// InboundFee is the fee that must be paid for incoming HTLCs.
|
||||||
|
InboundFee InboundFee
|
||||||
|
|
||||||
// TimeLockDelta is the absolute time-lock value, expressed in blocks,
|
// TimeLockDelta is the absolute time-lock value, expressed in blocks,
|
||||||
// that will be subtracted from an incoming HTLC's timelock value to
|
// that will be subtracted from an incoming HTLC's timelock value to
|
||||||
// create the time-lock value for the forwarded outgoing HTLC. The
|
// create the time-lock value for the forwarded outgoing HTLC. The
|
||||||
|
|
53
channeldb/models/inbound_fee.go
Normal file
53
channeldb/models/inbound_fee.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxFeeRate is the maximum fee rate that we allow. It is set to allow
|
||||||
|
// a variable fee component of up to 10x the payment amount.
|
||||||
|
maxFeeRate = 10 * feeRateParts
|
||||||
|
)
|
||||||
|
|
||||||
|
type InboundFee struct {
|
||||||
|
Base int32
|
||||||
|
Rate int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInboundFeeFromWire constructs an inbound fee structure from a wire fee.
|
||||||
|
func NewInboundFeeFromWire(fee lnwire.Fee) InboundFee {
|
||||||
|
return InboundFee{
|
||||||
|
Base: fee.BaseFee,
|
||||||
|
Rate: fee.FeeRate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToWire converts the inbound fee to a wire fee structure.
|
||||||
|
func (i *InboundFee) ToWire() lnwire.Fee {
|
||||||
|
return lnwire.Fee{
|
||||||
|
BaseFee: i.Base,
|
||||||
|
FeeRate: i.Rate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalcFee calculates what the inbound fee should minimally be for forwarding
|
||||||
|
// the given amount. This amount is the total of the outgoing amount plus the
|
||||||
|
// outbound fee, which is what the inbound fee is based on.
|
||||||
|
func (i *InboundFee) CalcFee(amt lnwire.MilliSatoshi) int64 {
|
||||||
|
fee := int64(i.Base)
|
||||||
|
rate := int64(i.Rate)
|
||||||
|
|
||||||
|
// Cap the rate to prevent overflows.
|
||||||
|
switch {
|
||||||
|
case rate > maxFeeRate:
|
||||||
|
rate = maxFeeRate
|
||||||
|
|
||||||
|
case rate < -maxFeeRate:
|
||||||
|
rate = -maxFeeRate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate proportional component. To keep the integer math simple,
|
||||||
|
// positive fees are rounded down while negative fees are rounded up.
|
||||||
|
fee += rate * int64(amt) / feeRateParts
|
||||||
|
|
||||||
|
return fee
|
||||||
|
}
|
33
channeldb/models/inbound_fee_test.go
Normal file
33
channeldb/models/inbound_fee_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInboundFee(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test positive fee.
|
||||||
|
i := InboundFee{
|
||||||
|
Base: 5,
|
||||||
|
Rate: 500000,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int64(6), i.CalcFee(2))
|
||||||
|
|
||||||
|
// Expect fee to be rounded down.
|
||||||
|
require.Equal(t, int64(6), i.CalcFee(3))
|
||||||
|
|
||||||
|
// Test negative fee.
|
||||||
|
i = InboundFee{
|
||||||
|
Base: -5,
|
||||||
|
Rate: -500000,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int64(-6), i.CalcFee(2))
|
||||||
|
|
||||||
|
// Expect fee to be rounded up.
|
||||||
|
require.Equal(t, int64(-6), i.CalcFee(3))
|
||||||
|
}
|
|
@ -247,6 +247,7 @@ type ChannelLink interface {
|
||||||
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
|
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
|
||||||
amtToForward lnwire.MilliSatoshi,
|
amtToForward lnwire.MilliSatoshi,
|
||||||
incomingTimeout, outgoingTimeout uint32,
|
incomingTimeout, outgoingTimeout uint32,
|
||||||
|
inboundFee models.InboundFee,
|
||||||
heightNow uint32, scid lnwire.ShortChannelID) *LinkError
|
heightNow uint32, scid lnwire.ShortChannelID) *LinkError
|
||||||
|
|
||||||
// CheckHtlcTransit should return a nil error if the passed HTLC details
|
// CheckHtlcTransit should return a nil error if the passed HTLC details
|
||||||
|
|
|
@ -2780,28 +2780,43 @@ func (l *channelLink) UpdateForwardingPolicy(
|
||||||
func (l *channelLink) CheckHtlcForward(payHash [32]byte,
|
func (l *channelLink) CheckHtlcForward(payHash [32]byte,
|
||||||
incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
|
incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
|
||||||
incomingTimeout, outgoingTimeout uint32,
|
incomingTimeout, outgoingTimeout uint32,
|
||||||
|
inboundFee models.InboundFee,
|
||||||
heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
|
heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
|
||||||
|
|
||||||
l.RLock()
|
l.RLock()
|
||||||
policy := l.cfg.FwrdingPolicy
|
policy := l.cfg.FwrdingPolicy
|
||||||
l.RUnlock()
|
l.RUnlock()
|
||||||
|
|
||||||
// Using the amount of the incoming HTLC, we'll calculate the expected
|
// Using the outgoing HTLC amount, we'll calculate the outgoing
|
||||||
// fee this incoming HTLC must carry in order to satisfy the
|
// fee this incoming HTLC must carry in order to satisfy the constraints
|
||||||
// constraints of the outgoing link.
|
// of the outgoing link.
|
||||||
expectedFee := ExpectedFee(policy, amtToForward)
|
outFee := ExpectedFee(policy, amtToForward)
|
||||||
|
|
||||||
|
// Then calculate the inbound fee that we charge based on the sum of
|
||||||
|
// outgoing HTLC amount and outgoing fee.
|
||||||
|
inFee := inboundFee.CalcFee(amtToForward + outFee)
|
||||||
|
|
||||||
|
// Add up both fee components. It is important to calculate both fees
|
||||||
|
// separately. An alternative way of calculating is to first determine
|
||||||
|
// an aggregate fee and apply that to the outgoing HTLC amount. However,
|
||||||
|
// rounding may cause the result to be slightly higher than in the case
|
||||||
|
// of separately rounded fee components. This potentially causes failed
|
||||||
|
// forwards for senders and is something to be avoided.
|
||||||
|
expectedFee := inFee + int64(outFee)
|
||||||
|
|
||||||
// If the actual fee is less than our expected fee, then we'll reject
|
// If the actual fee is less than our expected fee, then we'll reject
|
||||||
// this HTLC as it didn't provide a sufficient amount of fees, or the
|
// this HTLC as it didn't provide a sufficient amount of fees, or the
|
||||||
// values have been tampered with, or the send used incorrect/dated
|
// values have been tampered with, or the send used incorrect/dated
|
||||||
// information to construct the forwarding information for this hop. In
|
// information to construct the forwarding information for this hop. In
|
||||||
// any case, we'll cancel this HTLC. We're checking for this case first
|
// any case, we'll cancel this HTLC.
|
||||||
// to leak as little information as possible.
|
actualFee := int64(incomingHtlcAmt) - int64(amtToForward)
|
||||||
actualFee := incomingHtlcAmt - amtToForward
|
|
||||||
if incomingHtlcAmt < amtToForward || actualFee < expectedFee {
|
if incomingHtlcAmt < amtToForward || actualFee < expectedFee {
|
||||||
l.log.Warnf("outgoing htlc(%x) has insufficient fee: "+
|
l.log.Warnf("outgoing htlc(%x) has insufficient fee: "+
|
||||||
"expected %v, got %v",
|
"expected %v, got %v: incoming=%v, outgoing=%v, "+
|
||||||
payHash[:], int64(expectedFee), int64(actualFee))
|
"inboundFee=%v",
|
||||||
|
payHash[:], expectedFee, actualFee,
|
||||||
|
incomingHtlcAmt, amtToForward, inboundFee,
|
||||||
|
)
|
||||||
|
|
||||||
// As part of the returned error, we'll send our latest routing
|
// As part of the returned error, we'll send our latest routing
|
||||||
// policy so the sending node obtains the most up to date data.
|
// policy so the sending node obtains the most up to date data.
|
||||||
|
@ -3330,6 +3345,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||||
// round of processing.
|
// round of processing.
|
||||||
chanIterator.EncodeNextHop(buf)
|
chanIterator.EncodeNextHop(buf)
|
||||||
|
|
||||||
|
inboundFee := l.cfg.FwrdingPolicy.InboundFee
|
||||||
|
|
||||||
updatePacket := &htlcPacket{
|
updatePacket := &htlcPacket{
|
||||||
incomingChanID: l.ShortChanID(),
|
incomingChanID: l.ShortChanID(),
|
||||||
incomingHTLCID: pd.HtlcIndex,
|
incomingHTLCID: pd.HtlcIndex,
|
||||||
|
@ -3342,6 +3359,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||||
incomingTimeout: pd.Timeout,
|
incomingTimeout: pd.Timeout,
|
||||||
outgoingTimeout: fwdInfo.OutgoingCTLV,
|
outgoingTimeout: fwdInfo.OutgoingCTLV,
|
||||||
customRecords: pld.CustomRecords(),
|
customRecords: pld.CustomRecords(),
|
||||||
|
inboundFee: inboundFee,
|
||||||
}
|
}
|
||||||
switchPackets = append(
|
switchPackets = append(
|
||||||
switchPackets, updatePacket,
|
switchPackets, updatePacket,
|
||||||
|
@ -3394,6 +3412,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||||
// have been added to switchPackets at the top of this
|
// have been added to switchPackets at the top of this
|
||||||
// section.
|
// section.
|
||||||
if fwdPkg.State == channeldb.FwdStateLockedIn {
|
if fwdPkg.State == channeldb.FwdStateLockedIn {
|
||||||
|
inboundFee := l.cfg.FwrdingPolicy.InboundFee
|
||||||
|
|
||||||
updatePacket := &htlcPacket{
|
updatePacket := &htlcPacket{
|
||||||
incomingChanID: l.ShortChanID(),
|
incomingChanID: l.ShortChanID(),
|
||||||
incomingHTLCID: pd.HtlcIndex,
|
incomingHTLCID: pd.HtlcIndex,
|
||||||
|
@ -3406,6 +3426,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||||
incomingTimeout: pd.Timeout,
|
incomingTimeout: pd.Timeout,
|
||||||
outgoingTimeout: fwdInfo.OutgoingCTLV,
|
outgoingTimeout: fwdInfo.OutgoingCTLV,
|
||||||
customRecords: pld.CustomRecords(),
|
customRecords: pld.CustomRecords(),
|
||||||
|
inboundFee: inboundFee,
|
||||||
}
|
}
|
||||||
|
|
||||||
fwdPkg.FwdFilter.Set(idx)
|
fwdPkg.FwdFilter.Set(idx)
|
||||||
|
|
|
@ -643,6 +643,206 @@ func testChannelLinkMultiHopPayment(t *testing.T,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChannelLinkInboundFee(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("negative", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
bobInboundFee := models.InboundFee{
|
||||||
|
Base: -500,
|
||||||
|
Rate: -100,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob is supposed to sent Carol 1000000 msats. For this, he
|
||||||
|
// will charge an out fee of 1000 msat (the default hop network
|
||||||
|
// policy). Bob's inbound fee is based on the sum of outgoing
|
||||||
|
// htlc amount and the out fee that Bob charges. The value of
|
||||||
|
// this sum is 1001000. The proportional component of the
|
||||||
|
// inbound fee is -0.01% of the sum, which is -100 (rounded
|
||||||
|
// up). Added to this is the base inbound fee of -500, making
|
||||||
|
// for a total inbound fee of -600.
|
||||||
|
const expectedBobInFee = -600
|
||||||
|
|
||||||
|
testChannelLinkInboundFee(
|
||||||
|
t, bobInboundFee, expectedBobInFee, false,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative overpaid", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
bobInboundFee := models.InboundFee{
|
||||||
|
Base: -500,
|
||||||
|
Rate: -100,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alice is not aware of the inbound discount and pays the full
|
||||||
|
// outbound fee.
|
||||||
|
const expectedBobInFee = 0
|
||||||
|
|
||||||
|
testChannelLinkInboundFee(
|
||||||
|
t, bobInboundFee, expectedBobInFee, false,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative total", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
bobInboundFee := models.InboundFee{
|
||||||
|
Base: -5000,
|
||||||
|
}
|
||||||
|
|
||||||
|
const expectedBobInFee = -5000
|
||||||
|
|
||||||
|
// Bob's inbound discount exceeds his outbound fee. Forwards
|
||||||
|
// carrying a negative total fee should be rejected.
|
||||||
|
testChannelLinkInboundFee(
|
||||||
|
t, bobInboundFee, expectedBobInFee, true,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("positive", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
bobInboundFee := models.InboundFee{
|
||||||
|
Base: 1_000,
|
||||||
|
Rate: 100_000,
|
||||||
|
}
|
||||||
|
|
||||||
|
const expectedBobInFee = 101_100
|
||||||
|
|
||||||
|
testChannelLinkInboundFee(
|
||||||
|
t, bobInboundFee, expectedBobInFee, false,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChannelLinkInboundFee(t *testing.T, //nolint:thelper
|
||||||
|
bobInboundFee models.InboundFee, expectedBobInFee int64,
|
||||||
|
expectedFail bool) {
|
||||||
|
|
||||||
|
channels, _, err := createClusterChannels(
|
||||||
|
t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "unable to create channel")
|
||||||
|
|
||||||
|
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
|
||||||
|
channels.bobToCarol, channels.carolToBob, testStartingHeight)
|
||||||
|
|
||||||
|
require.NoError(t, n.start())
|
||||||
|
defer n.stop()
|
||||||
|
|
||||||
|
bobPolicy := n.globalPolicy
|
||||||
|
bobPolicy.InboundFee = bobInboundFee
|
||||||
|
n.firstBobChannelLink.UpdateForwardingPolicy(bobPolicy)
|
||||||
|
|
||||||
|
// Set an inbound fee for Carol. Because Carol is the payee, the fee
|
||||||
|
// should not be applied.
|
||||||
|
carolPolicy := n.globalPolicy
|
||||||
|
carolPolicy.InboundFee = models.InboundFee{
|
||||||
|
Base: -2_000,
|
||||||
|
Rate: -200_000,
|
||||||
|
}
|
||||||
|
n.carolChannelLink.UpdateForwardingPolicy(carolPolicy)
|
||||||
|
|
||||||
|
carolBandwidthBefore := n.carolChannelLink.Bandwidth()
|
||||||
|
firstBobBandwidthBefore := n.firstBobChannelLink.Bandwidth()
|
||||||
|
secondBobBandwidthBefore := n.secondBobChannelLink.Bandwidth()
|
||||||
|
aliceBandwidthBefore := n.aliceChannelLink.Bandwidth()
|
||||||
|
|
||||||
|
const (
|
||||||
|
expectedCarolInboundFee = 0
|
||||||
|
|
||||||
|
// Expect Bob's outbound fee to match the default hop network
|
||||||
|
// policy.
|
||||||
|
expectedBobOutboundFee = 1_000
|
||||||
|
)
|
||||||
|
|
||||||
|
amount := lnwire.MilliSatoshi(1_000_000)
|
||||||
|
htlcAmt := lnwire.MilliSatoshi(1_000_000 +
|
||||||
|
expectedCarolInboundFee + expectedBobOutboundFee +
|
||||||
|
expectedBobInFee,
|
||||||
|
)
|
||||||
|
totalTimelock := uint32(112)
|
||||||
|
|
||||||
|
hops := []*hop.Payload{
|
||||||
|
{
|
||||||
|
FwdInfo: hop.ForwardingInfo{
|
||||||
|
NextHop: n.carolChannelLink.
|
||||||
|
ShortChanID(),
|
||||||
|
AmountToForward: 1_000_000,
|
||||||
|
OutgoingCTLV: 106,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FwdInfo: hop.ForwardingInfo{
|
||||||
|
AmountToForward: 1_000_000,
|
||||||
|
OutgoingCTLV: 106,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
receiver := n.carolServer
|
||||||
|
firstHop := n.firstBobChannelLink.ShortChanID()
|
||||||
|
rhash, err := makePayment(
|
||||||
|
n.aliceServer, n.carolServer, firstHop, hops, amount, htlcAmt,
|
||||||
|
totalTimelock,
|
||||||
|
).Wait(30 * time.Second)
|
||||||
|
|
||||||
|
if expectedFail {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "unable to send payment")
|
||||||
|
|
||||||
|
// Wait for Alice and Bob's second link to receive the revocation.
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
// Check that Carol invoice was settled and bandwidth of HTLC
|
||||||
|
// links were changed.
|
||||||
|
invoice, err := receiver.registry.LookupInvoice(
|
||||||
|
context.Background(), rhash,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "unable to get invoice")
|
||||||
|
require.Equal(t, invpkg.ContractSettled, invoice.State,
|
||||||
|
"carol invoice haven't been settled")
|
||||||
|
|
||||||
|
expectedAliceBandwidth := aliceBandwidthBefore - htlcAmt
|
||||||
|
require.Equalf(t,
|
||||||
|
expectedAliceBandwidth, n.aliceChannelLink.Bandwidth(),
|
||||||
|
"channel bandwidth incorrect: expected %v, got %v",
|
||||||
|
expectedAliceBandwidth, n.aliceChannelLink.Bandwidth(),
|
||||||
|
)
|
||||||
|
|
||||||
|
expectedBobBandwidth1 := firstBobBandwidthBefore + htlcAmt
|
||||||
|
require.Equalf(t,
|
||||||
|
expectedBobBandwidth1, n.firstBobChannelLink.Bandwidth(),
|
||||||
|
"channel bandwidth incorrect: expected %v, got %v",
|
||||||
|
expectedBobBandwidth1, n.firstBobChannelLink.Bandwidth(),
|
||||||
|
)
|
||||||
|
|
||||||
|
bobCarolDelta := lnwire.MilliSatoshi(
|
||||||
|
int64(amount) + expectedCarolInboundFee,
|
||||||
|
)
|
||||||
|
|
||||||
|
expectedBobBandwidth2 := secondBobBandwidthBefore - bobCarolDelta
|
||||||
|
require.Equalf(t,
|
||||||
|
expectedBobBandwidth2, n.secondBobChannelLink.Bandwidth(),
|
||||||
|
"channel bandwidth incorrect: expected %v, got %v",
|
||||||
|
expectedBobBandwidth2, n.secondBobChannelLink.Bandwidth(),
|
||||||
|
)
|
||||||
|
|
||||||
|
expectedCarolBandwidth := carolBandwidthBefore + bobCarolDelta
|
||||||
|
require.Equalf(t,
|
||||||
|
expectedCarolBandwidth, n.carolChannelLink.Bandwidth(),
|
||||||
|
"channel bandwidth incorrect: expected %v, got %v",
|
||||||
|
expectedCarolBandwidth, n.carolChannelLink.Bandwidth(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// TestChannelLinkCancelFullCommitment tests the ability for links to cancel
|
// TestChannelLinkCancelFullCommitment tests the ability for links to cancel
|
||||||
// forwarded HTLCs once all of their commitment slots are full.
|
// forwarded HTLCs once all of their commitment slots are full.
|
||||||
func TestChannelLinkCancelFullCommitment(t *testing.T) {
|
func TestChannelLinkCancelFullCommitment(t *testing.T) {
|
||||||
|
@ -5994,7 +6194,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("satisfied", func(t *testing.T) {
|
t.Run("satisfied", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 1500, 1000,
|
result := link.CheckHtlcForward(hash, 1500, 1000,
|
||||||
200, 150, 0, lnwire.ShortChannelID{})
|
200, 150, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if result != nil {
|
if result != nil {
|
||||||
t.Fatalf("expected policy to be satisfied")
|
t.Fatalf("expected policy to be satisfied")
|
||||||
}
|
}
|
||||||
|
@ -6002,7 +6204,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("below minhtlc", func(t *testing.T) {
|
t.Run("below minhtlc", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 100, 50,
|
result := link.CheckHtlcForward(hash, 100, 50,
|
||||||
200, 150, 0, lnwire.ShortChannelID{})
|
200, 150, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
|
||||||
t.Fatalf("expected FailAmountBelowMinimum failure code")
|
t.Fatalf("expected FailAmountBelowMinimum failure code")
|
||||||
}
|
}
|
||||||
|
@ -6010,7 +6214,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("above maxhtlc", func(t *testing.T) {
|
t.Run("above maxhtlc", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 1500, 1200,
|
result := link.CheckHtlcForward(hash, 1500, 1200,
|
||||||
200, 150, 0, lnwire.ShortChannelID{})
|
200, 150, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
|
||||||
t.Fatalf("expected FailTemporaryChannelFailure failure code")
|
t.Fatalf("expected FailTemporaryChannelFailure failure code")
|
||||||
}
|
}
|
||||||
|
@ -6018,7 +6224,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("insufficient fee", func(t *testing.T) {
|
t.Run("insufficient fee", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 1005, 1000,
|
result := link.CheckHtlcForward(hash, 1005, 1000,
|
||||||
200, 150, 0, lnwire.ShortChannelID{})
|
200, 150, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
|
||||||
t.Fatalf("expected FailFeeInsufficient failure code")
|
t.Fatalf("expected FailFeeInsufficient failure code")
|
||||||
}
|
}
|
||||||
|
@ -6031,7 +6239,7 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
result := link.CheckHtlcForward(
|
result := link.CheckHtlcForward(
|
||||||
hash, 100005, 100000, 200,
|
hash, 100005, 100000, 200,
|
||||||
150, 0, lnwire.ShortChannelID{},
|
150, models.InboundFee{}, 0, lnwire.ShortChannelID{},
|
||||||
)
|
)
|
||||||
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
|
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
|
||||||
require.True(t, ok, "expected FailFeeInsufficient failure code")
|
require.True(t, ok, "expected FailFeeInsufficient failure code")
|
||||||
|
@ -6039,7 +6247,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("expiry too soon", func(t *testing.T) {
|
t.Run("expiry too soon", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 1500, 1000,
|
result := link.CheckHtlcForward(hash, 1500, 1000,
|
||||||
200, 150, 190, lnwire.ShortChannelID{})
|
200, 150, models.InboundFee{}, 190,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
|
||||||
t.Fatalf("expected FailExpiryTooSoon failure code")
|
t.Fatalf("expected FailExpiryTooSoon failure code")
|
||||||
}
|
}
|
||||||
|
@ -6047,7 +6257,9 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
|
|
||||||
t.Run("incorrect cltv expiry", func(t *testing.T) {
|
t.Run("incorrect cltv expiry", func(t *testing.T) {
|
||||||
result := link.CheckHtlcForward(hash, 1500, 1000,
|
result := link.CheckHtlcForward(hash, 1500, 1000,
|
||||||
200, 190, 0, lnwire.ShortChannelID{})
|
200, 190, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
|
||||||
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
|
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
|
||||||
}
|
}
|
||||||
|
@ -6057,11 +6269,37 @@ func TestCheckHtlcForward(t *testing.T) {
|
||||||
t.Run("cltv expiry too far in the future", func(t *testing.T) {
|
t.Run("cltv expiry too far in the future", func(t *testing.T) {
|
||||||
// Check that expiry isn't too far in the future.
|
// Check that expiry isn't too far in the future.
|
||||||
result := link.CheckHtlcForward(hash, 1500, 1000,
|
result := link.CheckHtlcForward(hash, 1500, 1000,
|
||||||
10200, 10100, 0, lnwire.ShortChannelID{})
|
10200, 10100, models.InboundFee{}, 0,
|
||||||
|
lnwire.ShortChannelID{},
|
||||||
|
)
|
||||||
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
|
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
|
||||||
t.Fatalf("expected FailExpiryTooFar failure code")
|
t.Fatalf("expected FailExpiryTooFar failure code")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("inbound fee satisfied", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000,
|
||||||
|
200, 150, models.InboundFee{Base: -2, Rate: -1_000},
|
||||||
|
0, lnwire.ShortChannelID{})
|
||||||
|
if result != nil {
|
||||||
|
t.Fatalf("expected policy to be satisfied")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inbound fee insufficient", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000,
|
||||||
|
200, 150, models.InboundFee{Base: -10, Rate: -100_000},
|
||||||
|
0, lnwire.ShortChannelID{})
|
||||||
|
|
||||||
|
msg := result.WireMessage()
|
||||||
|
if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok {
|
||||||
|
t.Fatalf("expected FailFeeInsufficient failure code")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestChannelLinkCanceledInvoice in this test checks the interaction
|
// TestChannelLinkCanceledInvoice in this test checks the interaction
|
||||||
|
|
|
@ -834,7 +834,7 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
|
||||||
func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
|
func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
|
||||||
}
|
}
|
||||||
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
|
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
|
||||||
lnwire.MilliSatoshi, uint32, uint32, uint32,
|
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
|
||||||
lnwire.ShortChannelID) *LinkError {
|
lnwire.ShortChannelID) *LinkError {
|
||||||
|
|
||||||
return f.checkHtlcForwardResult
|
return f.checkHtlcForwardResult
|
||||||
|
|
|
@ -2,6 +2,7 @@ package htlcswitch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
@ -103,6 +104,9 @@ type htlcPacket struct {
|
||||||
// but receives a channel_update with the alias SCID. Instead, the
|
// but receives a channel_update with the alias SCID. Instead, the
|
||||||
// payer should receive a channel_update with the public SCID.
|
// payer should receive a channel_update with the public SCID.
|
||||||
originalOutgoingChanID lnwire.ShortChannelID
|
originalOutgoingChanID lnwire.ShortChannelID
|
||||||
|
|
||||||
|
// inboundFee is the fee schedule of the incoming channel.
|
||||||
|
inboundFee models.InboundFee
|
||||||
}
|
}
|
||||||
|
|
||||||
// inKey returns the circuit key used to identify the incoming htlc.
|
// inKey returns the circuit key used to identify the incoming htlc.
|
||||||
|
|
|
@ -1178,7 +1178,9 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||||
failure = link.CheckHtlcForward(
|
failure = link.CheckHtlcForward(
|
||||||
htlc.PaymentHash, packet.incomingAmount,
|
htlc.PaymentHash, packet.incomingAmount,
|
||||||
packet.amount, packet.incomingTimeout,
|
packet.amount, packet.incomingTimeout,
|
||||||
packet.outgoingTimeout, currentHeight,
|
packet.outgoingTimeout,
|
||||||
|
packet.inboundFee,
|
||||||
|
currentHeight,
|
||||||
packet.originalOutgoingChanID,
|
packet.originalOutgoingChanID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -961,12 +961,26 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
|
||||||
// routing policy into a forwarding policy.
|
// routing policy into a forwarding policy.
|
||||||
var forwardingPolicy *models.ForwardingPolicy
|
var forwardingPolicy *models.ForwardingPolicy
|
||||||
if selfPolicy != nil {
|
if selfPolicy != nil {
|
||||||
|
var inboundWireFee lnwire.Fee
|
||||||
|
_, err := selfPolicy.ExtraOpaqueData.ExtractRecords(
|
||||||
|
&inboundWireFee,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inboundFee := models.NewInboundFeeFromWire(
|
||||||
|
inboundWireFee,
|
||||||
|
)
|
||||||
|
|
||||||
forwardingPolicy = &models.ForwardingPolicy{
|
forwardingPolicy = &models.ForwardingPolicy{
|
||||||
MinHTLCOut: selfPolicy.MinHTLC,
|
MinHTLCOut: selfPolicy.MinHTLC,
|
||||||
MaxHTLC: selfPolicy.MaxHTLC,
|
MaxHTLC: selfPolicy.MaxHTLC,
|
||||||
BaseFee: selfPolicy.FeeBaseMSat,
|
BaseFee: selfPolicy.FeeBaseMSat,
|
||||||
FeeRate: selfPolicy.FeeProportionalMillionths,
|
FeeRate: selfPolicy.FeeProportionalMillionths,
|
||||||
TimeLockDelta: uint32(selfPolicy.TimeLockDelta),
|
TimeLockDelta: uint32(selfPolicy.TimeLockDelta),
|
||||||
|
|
||||||
|
InboundFee: inboundFee,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
p.log.Warnf("Unable to find our forwarding policy "+
|
p.log.Warnf("Unable to find our forwarding policy "+
|
||||||
|
|
Loading…
Add table
Reference in a new issue