lnd/routing/shards/shard_tracker.go
Johan T. Halseth 41ae3530a3
routing/payment_lifecycle: use ShardTracker to track shards
We'll let the payment's lifecycle register each shard it's sending with
the ShardTracker, canceling failed shards. This will be the foundation
for correct AMP derivation for each shard we'll send.
2021-04-27 09:43:40 +02:00

136 lines
4.0 KiB
Go

package shards
import (
"fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record"
)
// PaymentShard is an interface representing a shard tracked by the
// ShardTracker. It contains options that are specific to the given shard that
// might differ from the overall payment.
//
// The record accessors return pointers and may yield nil when a shard has no
// such record; the Shard type in this file always returns nil for both.
type PaymentShard interface {
// Hash returns the hash used for the HTLC representing this shard.
Hash() lntypes.Hash
// MPP returns any extra MPP records that should be set for the final
// hop on the route used by this shard.
MPP() *record.MPP
// AMP returns any extra AMP records that should be set for the final
// hop on the route used by this shard.
AMP() *record.AMP
}
// ShardTracker is an interface representing a tracker that keeps track of the
// inflight shards of a payment, and is able to assign new shards the correct
// options such as hash and extra records.
type ShardTracker interface {
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be
// canceled if it ends up not being used by the overall payment, i.e.
// if the attempt fails.
//
// The uint64 argument is the attempt ID, the same ID that is later
// passed to CancelShard and GetHash.
//
// NOTE(review): the meaning of the bool argument is not documented
// here and SimpleShardTracker ignores it; confirm against the other
// implementations of this interface.
NewShard(uint64, bool) (PaymentShard, error)
// CancelShard cancels the shard corresponding to the given attempt
// ID. This lets the ShardTracker free up any slots used by this shard,
// and in case of AMP payments return the share used by this shard to
// the root share.
CancelShard(uint64) error
// GetHash retrieves the hash used by the shard of the given attempt
// ID. This will return an error if the attempt ID is unknown.
GetHash(uint64) (lntypes.Hash, error)
}
// Shard is a struct used for simple shards where we only need to map the
// shard to a single hash. Its MPP and AMP accessors always return nil.
type Shard struct {
// hash is the value reported by Hash.
hash lntypes.Hash
}
// Hash reports the hash of the HTLC that represents this shard, i.e. the value
// stored on the shard when it was created.
func (s *Shard) Hash() lntypes.Hash {
	shardHash := s.hash

	return shardHash
}
// MPP reports the extra MPP record to set for the final hop on the route used
// by this shard. A simple Shard carries no such record, so the result is
// always nil.
func (s *Shard) MPP() *record.MPP {
	var noRecord *record.MPP

	return noRecord
}
// AMP reports the extra AMP record to set for the final hop on the route used
// by this shard. A simple Shard carries no such record, so the result is
// always nil.
func (s *Shard) AMP() *record.AMP {
	var noRecord *record.AMP

	return noRecord
}
// SimpleShardTracker is an implementation of the ShardTracker interface that
// simply maps attempt IDs to hashes. New shards will be given a static payment
// hash. This should be used for regular and MPP payments, in addition to
// resumed payments where all the attempt's hashes have already been created.
type SimpleShardTracker struct {
// hash is the static payment hash assigned to every shard created
// through NewShard.
hash lntypes.Hash
// shards maps an attempt ID to the hash registered for that attempt.
shards map[uint64]lntypes.Hash
// The embedded mutex is held by the methods of this type whenever
// they read or modify shards.
sync.Mutex
}
// A compile-time check to ensure *SimpleShardTracker implements the
// ShardTracker interface. The build fails here if a method is missing or has
// the wrong signature.
var _ ShardTracker = (*SimpleShardTracker)(nil)
// NewSimpleShardTracker creates a new instance of the SimpleShardTracker with
// the given payment hash and existing attempts.
//
// shards maps already-known attempt IDs to their hashes and may be nil, in
// which case the tracker starts out empty. A non-nil map is retained by the
// tracker rather than copied, so later changes made by the caller are visible
// to the tracker and are not protected by its mutex.
func NewSimpleShardTracker(paymentHash lntypes.Hash,
	shards map[uint64]lntypes.Hash) ShardTracker {

	tracker := &SimpleShardTracker{
		hash:   paymentHash,
		shards: shards,
	}

	// A nil map cannot be written to, so fall back to an empty one when
	// the caller has no existing attempts to hand over.
	if tracker.shards == nil {
		tracker.shards = make(map[uint64]lntypes.Hash)
	}

	return tracker
}
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be canceled
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
//
// The attempt ID is mapped to the tracker's static payment hash, replacing any
// hash previously registered under the same ID. The bool argument is ignored
// by this implementation. The returned error is always nil.
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
	m.Lock()
	defer m.Unlock()

	// Writing to a nil map panics, so lazily create it for trackers that
	// were not built through NewSimpleShardTracker (e.g. a zero value).
	if m.shards == nil {
		m.shards = make(map[uint64]lntypes.Hash)
	}
	m.shards[id] = m.hash

	return &Shard{hash: m.hash}, nil
}
// CancelShard cancels the shard corresponding to the given attempt ID by
// removing its entry from the tracker. Canceling an ID that is not registered
// is a no-op, and the returned error is always nil.
func (m *SimpleShardTracker) CancelShard(id uint64) error {
	m.Lock()
	defer m.Unlock()

	delete(m.shards, id)

	return nil
}
// GetHash looks up the hash registered for the shard of the given attempt ID.
// If the attempt ID is unknown, a zero hash and a non-nil error are returned.
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
	// Only the map lookup needs the lock; the result is handled after it
	// has been released.
	m.Lock()
	shardHash, found := m.shards[id]
	m.Unlock()

	if found {
		return shardHash, nil
	}

	return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
		"not found", id)
}