Merge pull request #3415 from joostjager/mpp

htlcswitch+invoices: allow settling invoices via multi-path payments
This commit is contained in:
Joost Jager 2019-12-11 19:30:15 +01:00 committed by GitHub
commit 62dadff291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1381 additions and 612 deletions

View File

@ -117,6 +117,7 @@ const (
resolveTimeType tlv.Type = 11
expiryHeightType tlv.Type = 13
htlcStateType tlv.Type = 15
mppTotalAmtType tlv.Type = 17
// A set of tlv type definitions used to serialize invoice bodiees.
//
@ -289,6 +290,10 @@ type InvoiceHTLC struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi
// AcceptHeight is the block height at which the invoice registry
// decided to accept this htlc as a payment to the invoice. At this
// height, the invoice cltv delay must have been met.
@ -323,6 +328,10 @@ type HtlcAcceptDesc struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi
// Expiry is the expiry height of this htlc.
Expiry uint32
@ -1018,6 +1027,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
// Encode the htlc in a tlv stream.
chanID := key.ChanID.ToUint64()
amt := uint64(htlc.Amt)
mppTotalAmt := uint64(htlc.MppTotalAmt)
acceptTime := uint64(htlc.AcceptTime.UnixNano())
resolveTime := uint64(htlc.ResolveTime.UnixNano())
state := uint8(htlc.State)
@ -1034,6 +1044,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
)
// Convert the custom records to tlv.Record types that are ready
@ -1193,7 +1204,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt uint64
amt, mppTotalAmt uint64
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
@ -1206,6 +1217,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
)
if err != nil {
return nil, err
@ -1221,6 +1233,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
htlc.ResolveTime = time.Unix(0, int64(resolveTime))
htlc.State = HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt)
htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
// Reconstruct the custom records fields from the parsed types
// map return from the tlv parser.
@ -1324,6 +1337,7 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucke
htlc := &InvoiceHTLC{
Amt: htlcUpdate.Amt,
MppTotalAmt: htlcUpdate.MppTotalAmt,
Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
AcceptTime: now,

View File

@ -790,9 +790,12 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
panic(err)
}
finalCltvRejectDelta := int32(5)
registry := invoices.NewRegistry(cdb, finalCltvRejectDelta)
registry := invoices.NewRegistry(
cdb,
&invoices.RegistryConfig{
FinalCltvRejectDelta: 5,
},
)
registry.Start()
return &mockInvoiceRegistry{

77
invoices/clock_test.go Normal file
View File

@ -0,0 +1,77 @@
package invoices
import (
"sync"
"time"
)
// testClock can be used in tests to mock time.
type testClock struct {
currentTime time.Time
timeChanMap map[time.Time][]chan time.Time
timeLock sync.Mutex
}
// newTestClock returns a new test clock.
func newTestClock(startTime time.Time) *testClock {
return &testClock{
currentTime: startTime,
timeChanMap: make(map[time.Time][]chan time.Time),
}
}
// now returns the current (test) time.
func (c *testClock) now() time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()
return c.currentTime
}
// tickAfter returns a channel that will receive a tick at the specified time.
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()
triggerTime := c.currentTime.Add(duration)
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
duration, triggerTime)
ch := make(chan time.Time, 1)
// If already expired, tick immediately.
if !triggerTime.After(c.currentTime) {
ch <- c.currentTime
return ch
}
// Otherwise store the channel until the trigger time is there.
chans := c.timeChanMap[triggerTime]
chans = append(chans, ch)
c.timeChanMap[triggerTime] = chans
return ch
}
// setTime sets the (test) time and triggers tick channels when they expire.
func (c *testClock) setTime(now time.Time) {
c.timeLock.Lock()
defer c.timeLock.Unlock()
c.currentTime = now
remainingChans := make(map[time.Time][]chan time.Time)
for triggerTime, chans := range c.timeChanMap {
// If the trigger time is still in the future, keep this channel
// in the channel map for later.
if triggerTime.After(now) {
remainingChans[triggerTime] = chans
continue
}
for _, c := range chans {
c <- now
}
}
c.timeChanMap = remainingChans
}

View File

@ -2,8 +2,10 @@ package invoices
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
@ -26,6 +28,12 @@ var (
ErrShuttingDown = errors.New("invoice registry shutting down")
)
const (
// DefaultHtlcHoldDuration defines the default for how long mpp htlcs
// are held while waiting for the other set members to arrive.
DefaultHtlcHoldDuration = 120 * time.Second
)
// HodlEvent describes how an htlc should be resolved. If HodlEvent.Preimage is
// set, the event indicates a settle event. If Preimage is nil, it is a cancel
// event.
@ -41,6 +49,48 @@ type HodlEvent struct {
AcceptHeight int32
}
// RegistryConfig contains the configuration parameters for invoice registry.
type RegistryConfig struct {
// FinalCltvRejectDelta defines the number of blocks before the expiry
// of the htlc where we no longer settle it as an exit hop and instead
// cancel it back. Normally this value should be lower than the cltv
// expiry of any invoice we create and the code effectuating this should
// not be hit.
FinalCltvRejectDelta int32
// HtlcHoldDuration defines for how long mpp htlcs are held while
// waiting for the other set members to arrive.
HtlcHoldDuration time.Duration
// Now returns the current time.
Now func() time.Time
// TickAfter returns a channel that is sent on after the specified
// duration as passed.
TickAfter func(duration time.Duration) <-chan time.Time
}
// htlcReleaseEvent describes an htlc auto-release event. It is used to release
// mpp htlcs for which the complete set didn't arrive in time.
type htlcReleaseEvent struct {
// hash is the payment hash of the htlc to release.
hash lntypes.Hash
// key is the circuit key of the htlc to release.
key channeldb.CircuitKey
// releaseTime is the time at which to release the htlc.
releaseTime time.Time
}
// Less is used to order PriorityQueueItem's by their release time such that
// items with the older release time are at the top of the queue.
//
// NOTE: Part of the queue.PriorityQueueItem interface.
func (r *htlcReleaseEvent) Less(other queue.PriorityQueueItem) bool {
return r.releaseTime.Before(other.(*htlcReleaseEvent).releaseTime)
}
// InvoiceRegistry is a central registry of all the outstanding invoices
// created by the daemon. The registry is a thin wrapper around a map in order
// to ensure that all updates/reads are thread safe.
@ -49,6 +99,9 @@ type InvoiceRegistry struct {
cdb *channeldb.DB
// cfg contains the registry's configuration parameters.
cfg *RegistryConfig
clientMtx sync.Mutex
nextClientID uint32
notificationClients map[uint32]*InvoiceSubscription
@ -69,12 +122,9 @@ type InvoiceRegistry struct {
// subscriber. This is used to unsubscribe from all hashes efficiently.
hodlReverseSubscriptions map[chan<- interface{}]map[channeldb.CircuitKey]struct{}
// finalCltvRejectDelta defines the number of blocks before the expiry
// of the htlc where we no longer settle it as an exit hop and instead
// cancel it back. Normally this value should be lower than the cltv
// expiry of any invoice we create and the code effectuating this should
// not be hit.
finalCltvRejectDelta int32
// htlcAutoReleaseChan contains the new htlcs that need to be
// auto-released.
htlcAutoReleaseChan chan *htlcReleaseEvent
wg sync.WaitGroup
quit chan struct{}
@ -84,8 +134,7 @@ type InvoiceRegistry struct {
// wraps the persistent on-disk invoice storage with an additional in-memory
// layer. The in-memory layer is in place such that debug invoices can be added
// which are volatile yet available system wide within the daemon.
func NewRegistry(cdb *channeldb.DB, finalCltvRejectDelta int32) *InvoiceRegistry {
func NewRegistry(cdb *channeldb.DB, cfg *RegistryConfig) *InvoiceRegistry {
return &InvoiceRegistry{
cdb: cdb,
notificationClients: make(map[uint32]*InvoiceSubscription),
@ -95,7 +144,8 @@ func NewRegistry(cdb *channeldb.DB, finalCltvRejectDelta int32) *InvoiceRegistry
invoiceEvents: make(chan interface{}, 100),
hodlSubscriptions: make(map[channeldb.CircuitKey]map[chan<- interface{}]struct{}),
hodlReverseSubscriptions: make(map[chan<- interface{}]map[channeldb.CircuitKey]struct{}),
finalCltvRejectDelta: finalCltvRejectDelta,
cfg: cfg,
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
quit: make(chan struct{}),
}
}
@ -104,7 +154,7 @@ func NewRegistry(cdb *channeldb.DB, finalCltvRejectDelta int32) *InvoiceRegistry
func (i *InvoiceRegistry) Start() error {
i.wg.Add(1)
go i.invoiceEventNotifier()
go i.invoiceEventLoop()
return nil
}
@ -124,13 +174,31 @@ type invoiceEvent struct {
invoice *channeldb.Invoice
}
// invoiceEventNotifier is the dedicated goroutine responsible for accepting
// tickAt returns a channel that ticks at the specified time. If the time has
// already passed, it will tick immediately.
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
now := i.cfg.Now()
return i.cfg.TickAfter(t.Sub(now))
}
// invoiceEventLoop is the dedicated goroutine responsible for accepting
// new notification subscriptions, cancelling old subscriptions, and
// dispatching new invoice events.
func (i *InvoiceRegistry) invoiceEventNotifier() {
func (i *InvoiceRegistry) invoiceEventLoop() {
defer i.wg.Done()
// Set up a heap for htlc auto-releases.
autoReleaseHeap := &queue.PriorityQueue{}
for {
// If there is something to release, set up a release tick
// channel.
var nextReleaseTick <-chan time.Time
if autoReleaseHeap.Len() > 0 {
head := autoReleaseHeap.Top().(*htlcReleaseEvent)
nextReleaseTick = i.tickAt(head.releaseTime)
}
select {
// A new invoice subscription for all invoices has just arrived!
// We'll query for any backlog notifications, then add it to the
@ -196,6 +264,29 @@ func (i *InvoiceRegistry) invoiceEventNotifier() {
i.singleNotificationClients[e.id] = e
}
// A new htlc came in for auto-release.
case event := <-i.htlcAutoReleaseChan:
log.Debugf("Scheduling auto-release for htlc: "+
"hash=%v, key=%v at %v",
event.hash, event.key, event.releaseTime)
// We use an independent timer for every htlc rather
// than a set timer that is reset with every htlc coming
// in. Otherwise the sender could keep resetting the
// timer until the broadcast window is entered and our
// channel is force closed.
autoReleaseHeap.Push(event)
// The htlc at the top of the heap needs to be auto-released.
case <-nextReleaseTick:
event := autoReleaseHeap.Pop().(*htlcReleaseEvent)
err := i.cancelSingleHtlc(
event.hash, event.key,
)
if err != nil {
log.Errorf("HTLC timer: %v", err)
}
case <-i.quit:
return
}
@ -412,6 +503,114 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
return i.cdb.LookupInvoice(rHash)
}
// startHtlcTimer starts a new timer via the invoice registry main loop that
// cancels a single htlc on an invoice when the htlc hold duration has passed.
func (i *InvoiceRegistry) startHtlcTimer(hash lntypes.Hash,
key channeldb.CircuitKey, acceptTime time.Time) error {
releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration)
event := &htlcReleaseEvent{
hash: hash,
key: key,
releaseTime: releaseTime,
}
select {
case i.htlcAutoReleaseChan <- event:
return nil
case <-i.quit:
return ErrShuttingDown
}
}
// cancelSingleHtlc cancels a single accepted htlc on an invoice.
func (i *InvoiceRegistry) cancelSingleHtlc(hash lntypes.Hash,
key channeldb.CircuitKey) error {
i.Lock()
defer i.Unlock()
updateInvoice := func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) {
// Only allow individual htlc cancelation on open invoices.
if invoice.State != channeldb.ContractOpen {
log.Debugf("cancelSingleHtlc: invoice %v no longer "+
"open", hash)
return nil, nil
}
// Lookup the current status of the htlc in the database.
htlc, ok := invoice.Htlcs[key]
if !ok {
return nil, fmt.Errorf("htlc %v not found", key)
}
// Cancelation is only possible if the htlc wasn't already
// resolved.
if htlc.State != channeldb.HtlcStateAccepted {
log.Debugf("cancelSingleHtlc: htlc %v on invoice %v "+
"is already resolved", key, hash)
return nil, nil
}
log.Debugf("cancelSingleHtlc: cancelling htlc %v on invoice %v",
key, hash)
// Return an update descriptor that cancels htlc and keeps
// invoice open.
canceledHtlcs := map[channeldb.CircuitKey]struct{}{
key: {},
}
return &channeldb.InvoiceUpdateDesc{
CancelHtlcs: canceledHtlcs,
}, nil
}
// Try to mark the specified htlc as canceled in the invoice database.
// Intercept the update descriptor to set the local updated variable. If
// no invoice update is performed, we can return early.
var updated bool
invoice, err := i.cdb.UpdateInvoice(hash,
func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) {
updateDesc, err := updateInvoice(invoice)
if err != nil {
return nil, err
}
updated = updateDesc != nil
return updateDesc, err
},
)
if err != nil {
return err
}
if !updated {
return nil
}
// The invoice has been updated. Notify subscribers of the htlc
// resolution.
htlc, ok := invoice.Htlcs[key]
if !ok {
return fmt.Errorf("htlc %v not found", key)
}
if htlc.State == channeldb.HtlcStateCanceled {
i.notifyHodlSubscribers(HodlEvent{
CircuitKey: key,
AcceptHeight: int32(htlc.AcceptHeight),
Preimage: nil,
})
}
return nil
}
// NotifyExitHopHtlc attempts to mark an invoice as settled. The return value
// describes how the htlc should be resolved.
//
@ -422,6 +621,11 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
// to be taken on the htlc (settle or cancel). The caller needs to ensure that
// the channel is either buffered or received on from another goroutine to
// prevent deadlock.
//
// In the case that the htlc is part of a larger set of htlcs that pay to the
// same invoice (multi-path payment), the htlc is held until the set is
// complete. If the set doesn't fully arrive in time, a timer will cancel the
// held htlc.
func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
amtPaid lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
@ -430,9 +634,11 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
i.Lock()
defer i.Unlock()
mpp := payload.MultiPath()
debugLog := func(s string) {
log.Debugf("Invoice(%x): %v, amt=%v, expiry=%v, circuit=%v",
rHash[:], s, amtPaid, expiry, circuitKey)
log.Debugf("Invoice(%x): %v, amt=%v, expiry=%v, circuit=%v, "+
"mpp=%v", rHash[:], s, amtPaid, expiry, circuitKey, mpp)
}
// Create the update context containing the relevant details of the
@ -442,8 +648,9 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
amtPaid: amtPaid,
expiry: expiry,
currentHeight: currentHeight,
finalCltvRejectDelta: i.finalCltvRejectDelta,
finalCltvRejectDelta: i.cfg.FinalCltvRejectDelta,
customRecords: payload.CustomRecords(),
mpp: mpp,
}
// We'll attempt to settle an invoice matching this rHash on disk (if
@ -508,6 +715,21 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
}, nil
case channeldb.HtlcStateSettled:
// Also settle any previously accepted htlcs. The invoice state
// is leading. If an htlc is marked as settled, we should follow
// now and settle the htlc with our peer.
for key, htlc := range invoice.Htlcs {
if htlc.State != channeldb.HtlcStateSettled {
continue
}
i.notifyHodlSubscribers(HodlEvent{
CircuitKey: key,
Preimage: &invoice.Terms.PaymentPreimage,
AcceptHeight: int32(htlc.AcceptHeight),
})
}
return &HodlEvent{
CircuitKey: circuitKey,
Preimage: &invoice.Terms.PaymentPreimage,
@ -515,6 +737,19 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
}, nil
case channeldb.HtlcStateAccepted:
// (Re)start the htlc timer if the invoice is still open. It can
// only happen for mpp payments that there are htlcs in state
// Accepted while the invoice is Open.
if invoice.State == channeldb.ContractOpen {
err := i.startHtlcTimer(
rHash, circuitKey,
invoiceHtlc.AcceptTime,
)
if err != nil {
return nil, err
}
}
i.hodlSubscribe(hodlChan, circuitKey)
return nil, nil

View File

@ -16,6 +16,8 @@ import (
var (
testTimeout = 5 * time.Second
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
preimage = lntypes.Preimage{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
@ -59,19 +61,29 @@ var (
type testContext struct {
registry *InvoiceRegistry
clock *testClock
cleanup func()
t *testing.T
}
func newTestContext(t *testing.T) *testContext {
clock := newTestClock(testTime)
cdb, cleanup, err := newDB()
if err != nil {
t.Fatal(err)
}
cdb.Now = clock.now
// Instantiate and start the invoice ctx.registry.
registry := NewRegistry(cdb, testFinalCltvRejectDelta)
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
HtlcHoldDuration: 30 * time.Second,
Now: clock.now,
TickAfter: clock.tickAfter,
}
registry := NewRegistry(cdb, &cfg)
err = registry.Start()
if err != nil {
@ -81,6 +93,7 @@ func newTestContext(t *testing.T) *testContext {
ctx := testContext{
registry: registry,
clock: clock,
t: t,
cleanup: func() {
registry.Stop()
@ -390,7 +403,10 @@ func TestSettleHoldInvoice(t *testing.T) {
defer cleanup()
// Instantiate and start the invoice ctx.registry.
registry := NewRegistry(cdb, testFinalCltvRejectDelta)
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
}
registry := NewRegistry(cdb, &cfg)
err = registry.Start()
if err != nil {
@ -558,7 +574,10 @@ func TestCancelHoldInvoice(t *testing.T) {
defer cleanup()
// Instantiate and start the invoice ctx.registry.
registry := NewRegistry(cdb, testFinalCltvRejectDelta)
cfg := RegistryConfig{
FinalCltvRejectDelta: testFinalCltvRejectDelta,
}
registry := NewRegistry(cdb, &cfg)
err = registry.Start()
if err != nil {
@ -674,3 +693,85 @@ func (p *mockPayload) MultiPath() *record.MPP {
func (p *mockPayload) CustomRecords() hop.CustomRecordSet {
return make(hop.CustomRecordSet)
}
// TestSettleMpp tests settling of an invoice with multiple partial payments.
func TestSettleMpp(t *testing.T) {
defer timeout(t)()
ctx := newTestContext(t)
defer ctx.cleanup()
// Add the invoice.
_, err := ctx.registry.AddInvoice(testInvoice, hash)
if err != nil {
t.Fatal(err)
}
mppPayload := &mockPayload{
mpp: record.NewMPP(testInvoiceAmt, [32]byte{}),
}
// Send htlc 1.
hodlChan1 := make(chan interface{}, 1)
event, err := ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2,
testHtlcExpiry,
testCurrentHeight, getCircuitKey(10), hodlChan1, mppPayload,
)
if err != nil {
t.Fatal(err)
}
if event != nil {
t.Fatal("expected no direct resolution")
}
// Simulate mpp timeout releasing htlc 1.
ctx.clock.setTime(testTime.Add(30 * time.Second))
hodlEvent := (<-hodlChan1).(HodlEvent)
if hodlEvent.Preimage != nil {
t.Fatal("expected cancel event")
}
// Send htlc 2.
hodlChan2 := make(chan interface{}, 1)
event, err = ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2,
testHtlcExpiry,
testCurrentHeight, getCircuitKey(11), hodlChan2, mppPayload,
)
if err != nil {
t.Fatal(err)
}
if event != nil {
t.Fatal("expected no direct resolution")
}
// Send htlc 3.
hodlChan3 := make(chan interface{}, 1)
event, err = ctx.registry.NotifyExitHopHtlc(
hash, testInvoice.Terms.Value/2,
testHtlcExpiry,
testCurrentHeight, getCircuitKey(12), hodlChan3, mppPayload,
)
if err != nil {
t.Fatal(err)
}
if event == nil {
t.Fatal("expected a settle event")
}
// Check that settled amount is equal to the sum of values of the htlcs
// 0 and 1.
inv, err := ctx.registry.LookupInvoice(hash)
if err != nil {
t.Fatal(err)
}
if inv.State != channeldb.ContractSettled {
t.Fatal("expected invoice to be settled")
}
if inv.AmtPaid != testInvoice.Terms.Value {
t.Fatalf("amount incorrect, expected %v but got %v",
testInvoice.Terms.Value, inv.AmtPaid)
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
// updateResult is the result of the invoice update call.
@ -24,6 +25,13 @@ const (
resultDuplicateToSettled
resultAccepted
resultSettled
resultInvoiceNotOpen
resultPartialAccepted
resultMppInProgress
resultAddressMismatch
resultHtlcSetTotalMismatch
resultHtlcSetTotalTooLow
resultHtlcSetOverpayment
)
// String returns a human-readable representation of the invoice update result.
@ -63,6 +71,27 @@ func (u updateResult) String() string {
case resultSettled:
return "settled"
case resultInvoiceNotOpen:
return "invoice no longer open"
case resultPartialAccepted:
return "partial payment accepted"
case resultMppInProgress:
return "mpp reception in progress"
case resultAddressMismatch:
return "payment address mismatch"
case resultHtlcSetTotalMismatch:
return "htlc total amt doesn't match set total"
case resultHtlcSetTotalTooLow:
return "set total too low for invoice"
case resultHtlcSetOverpayment:
return "mpp is overpaying set total"
default:
return "unknown"
}
@ -77,6 +106,7 @@ type invoiceUpdateCtx struct {
currentHeight int32
finalCltvRejectDelta int32
customRecords hop.CustomRecordSet
mpp *record.MPP
}
// updateInvoice is a callback for DB.UpdateInvoice that contains the invoice
@ -102,8 +132,125 @@ func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
}
}
// If the invoice is already canceled, there is no further checking to
// do.
if ctx.mpp == nil {
return updateLegacy(ctx, inv)
}
return updateMpp(ctx, inv)
}
// updateMpp is a callback for DB.UpdateInvoice that contains the invoice
// settlement logic for mpp payments.
func updateMpp(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, updateResult, error) {
// Start building the accept descriptor.
acceptDesc := &channeldb.HtlcAcceptDesc{
Amt: ctx.amtPaid,
Expiry: ctx.expiry,
AcceptHeight: ctx.currentHeight,
MppTotalAmt: ctx.mpp.TotalMsat(),
CustomRecords: ctx.customRecords,
}
// Only accept payments to open invoices. This behaviour differs from
// non-mpp payments that are accepted even after the invoice is settled.
// Because non-mpp payments don't have a payment address, this is needed
// to thwart probing.
if inv.State != channeldb.ContractOpen {
return nil, resultInvoiceNotOpen, nil
}
// Check the payment address that authorizes the payment.
if ctx.mpp.PaymentAddr() != inv.Terms.PaymentAddr {
return nil, resultAddressMismatch, nil
}
// Don't accept zero-valued sets.
if ctx.mpp.TotalMsat() == 0 {
return nil, resultHtlcSetTotalTooLow, nil
}
// Check that the total amt of the htlc set is high enough. In case this
// is a zero-valued invoice, it will always be enough.
if ctx.mpp.TotalMsat() < inv.Terms.Value {
return nil, resultHtlcSetTotalTooLow, nil
}
// Check whether total amt matches other htlcs in the set.
var newSetTotal lnwire.MilliSatoshi
for _, htlc := range inv.Htlcs {
// Only consider accepted mpp htlcs. It is possible that there
// are htlcs registered in the invoice database that previously
// timed out and are in the canceled state now.
if htlc.State != channeldb.HtlcStateAccepted {
continue
}
if ctx.mpp.TotalMsat() != htlc.MppTotalAmt {
return nil, resultHtlcSetTotalMismatch, nil
}
newSetTotal += htlc.Amt
}
// Add amount of new htlc.
newSetTotal += ctx.amtPaid
// Make sure the communicated set total isn't overpaid.
if newSetTotal > ctx.mpp.TotalMsat() {
return nil, resultHtlcSetOverpayment, nil
}
// The invoice is still open. Check the expiry.
if ctx.expiry < uint32(ctx.currentHeight+ctx.finalCltvRejectDelta) {
return nil, resultExpiryTooSoon, nil
}
if ctx.expiry < uint32(ctx.currentHeight+inv.Terms.FinalCltvDelta) {
return nil, resultExpiryTooSoon, nil
}
// Record HTLC in the invoice database.
newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{
ctx.circuitKey: acceptDesc,
}
update := channeldb.InvoiceUpdateDesc{
AddHtlcs: newHtlcs,
}
// If the invoice cannot be settled yet, only record the htlc.
setComplete := newSetTotal == ctx.mpp.TotalMsat()
if !setComplete {
return &update, resultPartialAccepted, nil
}
// Check to see if we can settle or this is an hold invoice and
// we need to wait for the preimage.
holdInvoice := inv.Terms.PaymentPreimage == channeldb.UnknownPreimage
if holdInvoice {
update.State = &channeldb.InvoiceStateUpdateDesc{
NewState: channeldb.ContractAccepted,
}
return &update, resultAccepted, nil
}
update.State = &channeldb.InvoiceStateUpdateDesc{
NewState: channeldb.ContractSettled,
Preimage: inv.Terms.PaymentPreimage,
}
return &update, resultSettled, nil
}
// updateLegacy is a callback for DB.UpdateInvoice that contains the invoice
// settlement logic for legacy payments.
func updateLegacy(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, updateResult, error) {
// If the invoice is already canceled, there is no further
// checking to do.
if inv.State == channeldb.ContractCanceled {
return nil, resultInvoiceAlreadyCanceled, nil
}
@ -116,6 +263,20 @@ func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
return nil, resultAmountTooLow, nil
}
// TODO(joostjager): Check invoice mpp required feature
// bit when feature becomes mandatory.
// Don't allow settling the invoice with an old style
// htlc if we are already in the process of gathering an
// mpp set.
for _, htlc := range inv.Htlcs {
if htlc.State == channeldb.HtlcStateAccepted &&
htlc.MppTotalAmt > 0 {
return nil, resultMppInProgress, nil
}
}
// The invoice is still open. Check the expiry.
if ctx.expiry < uint32(ctx.currentHeight+ctx.finalCltvRejectDelta) {
return nil, resultExpiryTooSoon, nil

View File

@ -75,14 +75,15 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
}
rpcHtlc := lnrpc.InvoiceHTLC{
ChanId: key.ChanID.ToUint64(),
HtlcIndex: key.HtlcID,
AcceptHeight: int32(htlc.AcceptHeight),
AcceptTime: htlc.AcceptTime.Unix(),
ExpiryHeight: int32(htlc.Expiry),
AmtMsat: uint64(htlc.Amt),
State: state,
CustomRecords: htlc.CustomRecords,
ChanId: key.ChanID.ToUint64(),
HtlcIndex: key.HtlcID,
AcceptHeight: int32(htlc.AcceptHeight),
AcceptTime: htlc.AcceptTime.Unix(),
ExpiryHeight: int32(htlc.Expiry),
AmtMsat: uint64(htlc.Amt),
State: state,
CustomRecords: htlc.CustomRecords,
MppTotalAmtMsat: uint64(htlc.MppTotalAmt),
}
// Only report resolved times if htlc is resolved.

File diff suppressed because it is too large Load Diff

View File

@ -2407,6 +2407,9 @@ message InvoiceHTLC {
/// Custom tlv records.
map<uint64, bytes> custom_records = 9 [json_name = "custom_records"];
/// The total amount of the mpp payment in msat.
uint64 mpp_total_amt_msat = 10 [json_name = "mpp_total_amt_msat"];
}
message AddInvoiceResponse {

View File

@ -2854,6 +2854,11 @@
"format": "byte"
},
"description": "/ Custom tlv records."
},
"mpp_total_amt_msat": {
"type": "string",
"format": "uint64",
"description": "/ The total amount of the mpp payment in msat."
}
},
"title": "/ Details of an HTLC that paid to an invoice"

View File

@ -4712,13 +4712,29 @@ func testSingleHopSendToRouteCase(net *lntest.NetworkHarness, t *harnessTest,
}
// Create invoices for Dave, which expect a payment from Carol.
_, rHashes, _, err := createPayReqs(
payReqs, rHashes, _, err := createPayReqs(
dave, paymentAmtSat, numPayments,
)
if err != nil {
t.Fatalf("unable to create pay reqs: %v", err)
}
// Reconstruct payment addresses.
var payAddrs [][]byte
for _, payReq := range payReqs {
ctx, _ := context.WithTimeout(
context.Background(), defaultTimeout,
)
resp, err := dave.DecodePayReq(
ctx,
&lnrpc.PayReqString{PayReq: payReq},
)
if err != nil {
t.Fatalf("decode pay req: %v", err)
}
payAddrs = append(payAddrs, resp.PaymentAddr)
}
// Query for routes to pay from Carol to Dave.
// We set FinalCltvDelta to 40 since by default QueryRoutes returns
// the last hop with a final cltv delta of 9 where as the default in
@ -4741,12 +4757,10 @@ func testSingleHopSendToRouteCase(net *lntest.NetworkHarness, t *harnessTest,
// Construct a closure that will set MPP fields on the route, which
// allows us to test MPP payments.
setMPPFields := func(i int) {
addr := [32]byte{byte(i)}
hop := r.Hops[len(r.Hops)-1]
hop.TlvPayload = true
hop.MppRecord = &lnrpc.MPPRecord{
PaymentAddr: addr[:],
PaymentAddr: payAddrs[i],
TotalAmtMsat: paymentAmtSat * 1000,
}
}
@ -4930,8 +4944,8 @@ func testSingleHopSendToRouteCase(net *lntest.NetworkHarness, t *harnessTest,
hop.MppRecord.TotalAmtMsat)
}
expAddr := [32]byte{byte(i)}
if !bytes.Equal(hop.MppRecord.PaymentAddr, expAddr[:]) {
expAddr := payAddrs[i]
if !bytes.Equal(hop.MppRecord.PaymentAddr, expAddr) {
t.Fatalf("incorrect mpp payment addr for payment %d "+
"want: %x, got: %x",
i, expAddr, hop.MppRecord.PaymentAddr)

74
queue/priority_queue.go Normal file
View File

@ -0,0 +1,74 @@
package queue
import (
"container/heap"
)
// PriorityQueueItem is an interface that represents items in a PriorityQueue.
// Users of PriorityQueue will need to define a Less function such that
// PriorityQueue will be able to use that to build and restore an underlying
// heap.
type PriorityQueueItem interface {
Less(other PriorityQueueItem) bool
}
type priorityQueue []PriorityQueueItem
// Len returns the length of the priorityQueue.
func (pq priorityQueue) Len() int { return len(pq) }
// Less is used to order PriorityQueueItem items in the queue.
func (pq priorityQueue) Less(i, j int) bool {
return pq[i].Less(pq[j])
}
// Swap swaps two items in the priorityQueue. Swap is used by heap.Interface.
func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
}
// Push adds a new item the the priorityQueue.
func (pq *priorityQueue) Push(x interface{}) {
item := x.(PriorityQueueItem)
*pq = append(*pq, item)
}
// Pop removes the top item from the priorityQueue.
func (pq *priorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
old[n-1] = nil
*pq = old[0 : n-1]
return item
}
// Priority wrap a standard heap in a more object-oriented structure.
type PriorityQueue struct {
queue priorityQueue
}
// Len returns the length of the queue.
func (pq *PriorityQueue) Len() int {
return len(pq.queue)
}
// Empty returns true if the queue is empty.
func (pq *PriorityQueue) Empty() bool {
return len(pq.queue) == 0
}
// Push adds an item to the priority queue.
func (pq *PriorityQueue) Push(item PriorityQueueItem) {
heap.Push(&pq.queue, item)
}
// Pop removes the top most item from the queue.
func (pq *PriorityQueue) Pop() PriorityQueueItem {
return heap.Pop(&pq.queue).(PriorityQueueItem)
}
// Top returns the top most item from the queue without removing it.
func (pq *PriorityQueue) Top() PriorityQueueItem {
return pq.queue[0]
}

View File

@ -0,0 +1,67 @@
package queue
import (
"math/rand"
"testing"
"time"
)
type testQueueItem struct {
Value int
Expiry time.Time
}
func (e testQueueItem) Less(other PriorityQueueItem) bool {
return e.Expiry.Before(other.(*testQueueItem).Expiry)
}
func TestExpiryQueue(t *testing.T) {
// The number of elements we push to the queue.
count := 100
// Generate a random permutation of a range [0, count)
array := rand.Perm(count)
// t0 holds a reference time point.
t0 := time.Date(1975, time.April, 5, 12, 0, 0, 0, time.UTC)
var testQueue PriorityQueue
if testQueue.Len() != 0 && !testQueue.Empty() {
t.Fatal("Expected the queue to be empty")
}
// Create elements with expiry of t0 + value * second.
for _, value := range array {
testQueue.Push(&testQueueItem{
Value: value,
Expiry: t0.Add(time.Duration(value) * time.Second),
})
}
// Now expect that we can retrieve elements in order of their expiry.
for i := 0; i < count; i++ {
expectedQueueLen := count - i
if testQueue.Len() != expectedQueueLen {
t.Fatalf("Expected the queue len %v, got %v",
expectedQueueLen, testQueue.Len())
}
if testQueue.Empty() {
t.Fatalf("Did not expect the queue to be empty")
}
top := testQueue.Top().(*testQueueItem)
if top.Value != i {
t.Fatalf("Expected queue top %v, got %v", i, top.Value)
}
popped := testQueue.Pop().(*testQueueItem)
if popped != top {
t.Fatalf("Expected queue top %v equal to popped: %v",
top, popped)
}
}
if testQueue.Len() != 0 || !testQueue.Empty() {
t.Fatalf("Expected the queue to be empty")
}
}

View File

@ -378,6 +378,13 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
return nil, err
}
registryConfig := invoices.RegistryConfig{
FinalCltvRejectDelta: defaultFinalCltvRejectDelta,
HtlcHoldDuration: invoices.DefaultHtlcHoldDuration,
Now: time.Now,
TickAfter: time.After,
}
s := &server{
chanDB: chanDB,
cc: cc,
@ -386,9 +393,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
readPool: readPool,
chansToRestore: chansToRestore,
invoices: invoices.NewRegistry(
chanDB, defaultFinalCltvRejectDelta,
),
invoices: invoices.NewRegistry(chanDB, &registryConfig),
channelNotifier: channelnotifier.New(chanDB),