htlcswitch: add receiver-side inbound fee support

This commit is contained in:
Joost Jager 2022-09-19 12:06:34 +02:00
parent 3e6adbf1c0
commit e8c97deaef
10 changed files with 388 additions and 19 deletions

View file

@ -115,6 +115,9 @@ type ForwardingPolicy struct {
// used to compute the required fee for a given HTLC.
FeeRate lnwire.MilliSatoshi
// InboundFee is the fee that must be paid for incoming HTLCs.
InboundFee InboundFee
// TimeLockDelta is the absolute time-lock value, expressed in blocks,
// that will be subtracted from an incoming HTLC's timelock value to
// create the time-lock value for the forwarded outgoing HTLC. The

View file

@ -0,0 +1,53 @@
package models
import "github.com/lightningnetwork/lnd/lnwire"
const (
// maxFeeRate is the maximum fee rate that we allow. It is set to allow
// a variable fee component of up to 10x the payment amount.
maxFeeRate = 10 * feeRateParts
)
type InboundFee struct {
Base int32
Rate int32
}
// NewInboundFeeFromWire constructs an inbound fee structure from a wire fee.
func NewInboundFeeFromWire(fee lnwire.Fee) InboundFee {
return InboundFee{
Base: fee.BaseFee,
Rate: fee.FeeRate,
}
}
// ToWire converts the inbound fee to a wire fee structure.
func (i *InboundFee) ToWire() lnwire.Fee {
return lnwire.Fee{
BaseFee: i.Base,
FeeRate: i.Rate,
}
}
// CalcFee calculates what the inbound fee should minimally be for forwarding
// the given amount. This amount is the total of the outgoing amount plus the
// outbound fee, which is what the inbound fee is based on.
func (i *InboundFee) CalcFee(amt lnwire.MilliSatoshi) int64 {
fee := int64(i.Base)
rate := int64(i.Rate)
// Cap the rate to prevent overflows.
switch {
case rate > maxFeeRate:
rate = maxFeeRate
case rate < -maxFeeRate:
rate = -maxFeeRate
}
// Calculate proportional component. To keep the integer math simple,
// positive fees are rounded down while negative fees are rounded up.
fee += rate * int64(amt) / feeRateParts
return fee
}

View file

@ -0,0 +1,33 @@
package models
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInboundFee(t *testing.T) {
t.Parallel()
// Test positive fee.
i := InboundFee{
Base: 5,
Rate: 500000,
}
require.Equal(t, int64(6), i.CalcFee(2))
// Expect fee to be rounded down.
require.Equal(t, int64(6), i.CalcFee(3))
// Test negative fee.
i = InboundFee{
Base: -5,
Rate: -500000,
}
require.Equal(t, int64(-6), i.CalcFee(2))
// Expect fee to be rounded up.
require.Equal(t, int64(-6), i.CalcFee(3))
}

View file

@ -247,6 +247,7 @@ type ChannelLink interface {
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
amtToForward lnwire.MilliSatoshi,
incomingTimeout, outgoingTimeout uint32,
inboundFee models.InboundFee,
heightNow uint32, scid lnwire.ShortChannelID) *LinkError
// CheckHtlcTransit should return a nil error if the passed HTLC details

View file

@ -2780,28 +2780,43 @@ func (l *channelLink) UpdateForwardingPolicy(
func (l *channelLink) CheckHtlcForward(payHash [32]byte,
incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
incomingTimeout, outgoingTimeout uint32,
inboundFee models.InboundFee,
heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
l.RUnlock()
// Using the amount of the incoming HTLC, we'll calculate the expected
// fee this incoming HTLC must carry in order to satisfy the
// constraints of the outgoing link.
expectedFee := ExpectedFee(policy, amtToForward)
// Using the outgoing HTLC amount, we'll calculate the outgoing
// fee this incoming HTLC must carry in order to satisfy the constraints
// of the outgoing link.
outFee := ExpectedFee(policy, amtToForward)
// Then calculate the inbound fee that we charge based on the sum of
// outgoing HTLC amount and outgoing fee.
inFee := inboundFee.CalcFee(amtToForward + outFee)
// Add up both fee components. It is important to calculate both fees
// separately. An alternative way of calculating is to first determine
// an aggregate fee and apply that to the outgoing HTLC amount. However,
// rounding may cause the result to be slightly higher than in the case
// of separately rounded fee components. This potentially causes failed
// forwards for senders and is something to be avoided.
expectedFee := inFee + int64(outFee)
// If the actual fee is less than our expected fee, then we'll reject
// this HTLC as it didn't provide a sufficient amount of fees, or the
// values have been tampered with, or the send used incorrect/dated
// information to construct the forwarding information for this hop. In
// any case, we'll cancel this HTLC. We're checking for this case first
// to leak as little information as possible.
actualFee := incomingHtlcAmt - amtToForward
// any case, we'll cancel this HTLC.
actualFee := int64(incomingHtlcAmt) - int64(amtToForward)
if incomingHtlcAmt < amtToForward || actualFee < expectedFee {
l.log.Warnf("outgoing htlc(%x) has insufficient fee: "+
"expected %v, got %v",
payHash[:], int64(expectedFee), int64(actualFee))
"expected %v, got %v: incoming=%v, outgoing=%v, "+
"inboundFee=%v",
payHash[:], expectedFee, actualFee,
incomingHtlcAmt, amtToForward, inboundFee,
)
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data.
@ -3330,6 +3345,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// round of processing.
chanIterator.EncodeNextHop(buf)
inboundFee := l.cfg.FwrdingPolicy.InboundFee
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
@ -3342,6 +3359,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
}
switchPackets = append(
switchPackets, updatePacket,
@ -3394,6 +3412,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// have been added to switchPackets at the top of this
// section.
if fwdPkg.State == channeldb.FwdStateLockedIn {
inboundFee := l.cfg.FwrdingPolicy.InboundFee
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
@ -3406,6 +3426,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
}
fwdPkg.FwdFilter.Set(idx)

View file

@ -643,6 +643,206 @@ func testChannelLinkMultiHopPayment(t *testing.T,
}
}
func TestChannelLinkInboundFee(t *testing.T) {
t.Parallel()
t.Run("negative", func(t *testing.T) {
t.Parallel()
bobInboundFee := models.InboundFee{
Base: -500,
Rate: -100,
}
// Bob is supposed to sent Carol 1000000 msats. For this, he
// will charge an out fee of 1000 msat (the default hop network
// policy). Bob's inbound fee is based on the sum of outgoing
// htlc amount and the out fee that Bob charges. The value of
// this sum is 1001000. The proportional component of the
// inbound fee is -0.01% of the sum, which is -100 (rounded
// up). Added to this is the base inbound fee of -500, making
// for a total inbound fee of -600.
const expectedBobInFee = -600
testChannelLinkInboundFee(
t, bobInboundFee, expectedBobInFee, false,
)
})
t.Run("negative overpaid", func(t *testing.T) {
t.Parallel()
bobInboundFee := models.InboundFee{
Base: -500,
Rate: -100,
}
// Alice is not aware of the inbound discount and pays the full
// outbound fee.
const expectedBobInFee = 0
testChannelLinkInboundFee(
t, bobInboundFee, expectedBobInFee, false,
)
})
t.Run("negative total", func(t *testing.T) {
t.Parallel()
bobInboundFee := models.InboundFee{
Base: -5000,
}
const expectedBobInFee = -5000
// Bob's inbound discount exceeds his outbound fee. Forwards
// carrying a negative total fee should be rejected.
testChannelLinkInboundFee(
t, bobInboundFee, expectedBobInFee, true,
)
})
t.Run("positive", func(t *testing.T) {
t.Parallel()
bobInboundFee := models.InboundFee{
Base: 1_000,
Rate: 100_000,
}
const expectedBobInFee = 101_100
testChannelLinkInboundFee(
t, bobInboundFee, expectedBobInFee, false,
)
})
}
func testChannelLinkInboundFee(t *testing.T, //nolint:thelper
bobInboundFee models.InboundFee, expectedBobInFee int64,
expectedFail bool) {
channels, _, err := createClusterChannels(
t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
)
require.NoError(t, err, "unable to create channel")
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
channels.bobToCarol, channels.carolToBob, testStartingHeight)
require.NoError(t, n.start())
defer n.stop()
bobPolicy := n.globalPolicy
bobPolicy.InboundFee = bobInboundFee
n.firstBobChannelLink.UpdateForwardingPolicy(bobPolicy)
// Set an inbound fee for Carol. Because Carol is the payee, the fee
// should not be applied.
carolPolicy := n.globalPolicy
carolPolicy.InboundFee = models.InboundFee{
Base: -2_000,
Rate: -200_000,
}
n.carolChannelLink.UpdateForwardingPolicy(carolPolicy)
carolBandwidthBefore := n.carolChannelLink.Bandwidth()
firstBobBandwidthBefore := n.firstBobChannelLink.Bandwidth()
secondBobBandwidthBefore := n.secondBobChannelLink.Bandwidth()
aliceBandwidthBefore := n.aliceChannelLink.Bandwidth()
const (
expectedCarolInboundFee = 0
// Expect Bob's outbound fee to match the default hop network
// policy.
expectedBobOutboundFee = 1_000
)
amount := lnwire.MilliSatoshi(1_000_000)
htlcAmt := lnwire.MilliSatoshi(1_000_000 +
expectedCarolInboundFee + expectedBobOutboundFee +
expectedBobInFee,
)
totalTimelock := uint32(112)
hops := []*hop.Payload{
{
FwdInfo: hop.ForwardingInfo{
NextHop: n.carolChannelLink.
ShortChanID(),
AmountToForward: 1_000_000,
OutgoingCTLV: 106,
},
},
{
FwdInfo: hop.ForwardingInfo{
AmountToForward: 1_000_000,
OutgoingCTLV: 106,
},
},
}
receiver := n.carolServer
firstHop := n.firstBobChannelLink.ShortChanID()
rhash, err := makePayment(
n.aliceServer, n.carolServer, firstHop, hops, amount, htlcAmt,
totalTimelock,
).Wait(30 * time.Second)
if expectedFail {
require.Error(t, err)
return
}
require.NoError(t, err, "unable to send payment")
// Wait for Alice and Bob's second link to receive the revocation.
time.Sleep(2 * time.Second)
// Check that Carol invoice was settled and bandwidth of HTLC
// links were changed.
invoice, err := receiver.registry.LookupInvoice(
context.Background(), rhash,
)
require.NoError(t, err, "unable to get invoice")
require.Equal(t, invpkg.ContractSettled, invoice.State,
"carol invoice haven't been settled")
expectedAliceBandwidth := aliceBandwidthBefore - htlcAmt
require.Equalf(t,
expectedAliceBandwidth, n.aliceChannelLink.Bandwidth(),
"channel bandwidth incorrect: expected %v, got %v",
expectedAliceBandwidth, n.aliceChannelLink.Bandwidth(),
)
expectedBobBandwidth1 := firstBobBandwidthBefore + htlcAmt
require.Equalf(t,
expectedBobBandwidth1, n.firstBobChannelLink.Bandwidth(),
"channel bandwidth incorrect: expected %v, got %v",
expectedBobBandwidth1, n.firstBobChannelLink.Bandwidth(),
)
bobCarolDelta := lnwire.MilliSatoshi(
int64(amount) + expectedCarolInboundFee,
)
expectedBobBandwidth2 := secondBobBandwidthBefore - bobCarolDelta
require.Equalf(t,
expectedBobBandwidth2, n.secondBobChannelLink.Bandwidth(),
"channel bandwidth incorrect: expected %v, got %v",
expectedBobBandwidth2, n.secondBobChannelLink.Bandwidth(),
)
expectedCarolBandwidth := carolBandwidthBefore + bobCarolDelta
require.Equalf(t,
expectedCarolBandwidth, n.carolChannelLink.Bandwidth(),
"channel bandwidth incorrect: expected %v, got %v",
expectedCarolBandwidth, n.carolChannelLink.Bandwidth(),
)
}
// TestChannelLinkCancelFullCommitment tests the ability for links to cancel
// forwarded HTLCs once all of their commitment slots are full.
func TestChannelLinkCancelFullCommitment(t *testing.T) {
@ -5994,7 +6194,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("satisfied", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 150, 0, lnwire.ShortChannelID{})
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if result != nil {
t.Fatalf("expected policy to be satisfied")
}
@ -6002,7 +6204,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("below minhtlc", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 100, 50,
200, 150, 0, lnwire.ShortChannelID{})
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
t.Fatalf("expected FailAmountBelowMinimum failure code")
}
@ -6010,7 +6214,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("above maxhtlc", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1200,
200, 150, 0, lnwire.ShortChannelID{})
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
t.Fatalf("expected FailTemporaryChannelFailure failure code")
}
@ -6018,7 +6224,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("insufficient fee", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1005, 1000,
200, 150, 0, lnwire.ShortChannelID{})
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
t.Fatalf("expected FailFeeInsufficient failure code")
}
@ -6031,7 +6239,7 @@ func TestCheckHtlcForward(t *testing.T) {
result := link.CheckHtlcForward(
hash, 100005, 100000, 200,
150, 0, lnwire.ShortChannelID{},
150, models.InboundFee{}, 0, lnwire.ShortChannelID{},
)
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
require.True(t, ok, "expected FailFeeInsufficient failure code")
@ -6039,7 +6247,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("expiry too soon", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 150, 190, lnwire.ShortChannelID{})
200, 150, models.InboundFee{}, 190,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
t.Fatalf("expected FailExpiryTooSoon failure code")
}
@ -6047,7 +6257,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("incorrect cltv expiry", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 190, 0, lnwire.ShortChannelID{})
200, 190, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
}
@ -6057,11 +6269,37 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("cltv expiry too far in the future", func(t *testing.T) {
// Check that expiry isn't too far in the future.
result := link.CheckHtlcForward(hash, 1500, 1000,
10200, 10100, 0, lnwire.ShortChannelID{})
10200, 10100, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
t.Fatalf("expected FailExpiryTooFar failure code")
}
})
t.Run("inbound fee satisfied", func(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000,
200, 150, models.InboundFee{Base: -2, Rate: -1_000},
0, lnwire.ShortChannelID{})
if result != nil {
t.Fatalf("expected policy to be satisfied")
}
})
t.Run("inbound fee insufficient", func(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000,
200, 150, models.InboundFee{Base: -10, Rate: -100_000},
0, lnwire.ShortChannelID{})
msg := result.WireMessage()
if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok {
t.Fatalf("expected FailFeeInsufficient failure code")
}
})
}
// TestChannelLinkCanceledInvoice in this test checks the interaction

View file

@ -834,7 +834,7 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
}
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
lnwire.MilliSatoshi, uint32, uint32, uint32,
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
lnwire.ShortChannelID) *LinkError {
return f.checkHtlcForwardResult

View file

@ -2,6 +2,7 @@ package htlcswitch
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
@ -103,6 +104,9 @@ type htlcPacket struct {
// but receives a channel_update with the alias SCID. Instead, the
// payer should receive a channel_update with the public SCID.
originalOutgoingChanID lnwire.ShortChannelID
// inboundFee is the fee schedule of the incoming channel.
inboundFee models.InboundFee
}
// inKey returns the circuit key used to identify the incoming htlc.

View file

@ -1178,7 +1178,9 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
failure = link.CheckHtlcForward(
htlc.PaymentHash, packet.incomingAmount,
packet.amount, packet.incomingTimeout,
packet.outgoingTimeout, currentHeight,
packet.outgoingTimeout,
packet.inboundFee,
currentHeight,
packet.originalOutgoingChanID,
)
}

View file

@ -961,12 +961,26 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
// routing policy into a forwarding policy.
var forwardingPolicy *models.ForwardingPolicy
if selfPolicy != nil {
var inboundWireFee lnwire.Fee
_, err := selfPolicy.ExtraOpaqueData.ExtractRecords(
&inboundWireFee,
)
if err != nil {
return nil, err
}
inboundFee := models.NewInboundFeeFromWire(
inboundWireFee,
)
forwardingPolicy = &models.ForwardingPolicy{
MinHTLCOut: selfPolicy.MinHTLC,
MaxHTLC: selfPolicy.MaxHTLC,
BaseFee: selfPolicy.FeeBaseMSat,
FeeRate: selfPolicy.FeeProportionalMillionths,
TimeLockDelta: uint32(selfPolicy.TimeLockDelta),
InboundFee: inboundFee,
}
} else {
p.log.Warnf("Unable to find our forwarding policy "+