mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 18:10:34 +01:00
2d397b12b1
We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme.
166 lines
4.2 KiB
Go
166 lines
4.2 KiB
Go
package amp
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/lightningnetwork/lnd/lntypes"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/record"
|
|
"github.com/lightningnetwork/lnd/routing/shards"
|
|
)
|
|
|
|
// Shard is an implementation of the shards.PaymentShards interface specific
|
|
// to AMP payments.
|
|
type Shard struct {
|
|
child *Child
|
|
mpp *record.MPP
|
|
amp *record.AMP
|
|
}
|
|
|
|
// A compile time check to ensure Shard implements the shards.PaymentShard
|
|
// interface.
|
|
var _ shards.PaymentShard = (*Shard)(nil)
|
|
|
|
// Hash returns the hash used for the HTLC representing this AMP shard.
|
|
func (s *Shard) Hash() lntypes.Hash {
|
|
return s.child.Hash
|
|
}
|
|
|
|
// MPP returns any extra MPP records that should be set for the final hop on
|
|
// the route used by this shard.
|
|
func (s *Shard) MPP() *record.MPP {
|
|
return s.mpp
|
|
}
|
|
|
|
// AMP returns any extra AMP records that should be set for the final hop on
|
|
// the route used by this shard.
|
|
func (s *Shard) AMP() *record.AMP {
|
|
return s.amp
|
|
}
|
|
|
|
// ShardTracker is an implementation of the shards.ShardTracker interface
|
|
// that is able to generate payment shards according to the AMP splitting
|
|
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
|
|
// cancel shares used for failed payment shards.
|
|
type ShardTracker struct {
|
|
setID [32]byte
|
|
paymentAddr [32]byte
|
|
totalAmt lnwire.MilliSatoshi
|
|
|
|
sharer Sharer
|
|
|
|
shards map[uint64]*Child
|
|
sync.Mutex
|
|
}
|
|
|
|
// A compile time check to ensure ShardTracker implements the
|
|
// shards.ShardTracker interface.
|
|
var _ shards.ShardTracker = (*ShardTracker)(nil)
|
|
|
|
// NewShardTracker creates a new shard tracker to use for AMP payments. The
|
|
// root shard, setID, payment address and total amount must be correctly set in
|
|
// order for the TLV options to include with each shard to be created
|
|
// correctly.
|
|
func NewShardTracker(root, setID, payAddr [32]byte,
|
|
totalAmt lnwire.MilliSatoshi) *ShardTracker {
|
|
|
|
// Create a new seed sharer from this root.
|
|
rootShare := Share(root)
|
|
rootSharer := SeedSharerFromRoot(&rootShare)
|
|
|
|
return &ShardTracker{
|
|
setID: setID,
|
|
paymentAddr: payAddr,
|
|
totalAmt: totalAmt,
|
|
sharer: rootSharer,
|
|
shards: make(map[uint64]*Child),
|
|
}
|
|
}
|
|
|
|
// NewShard registers a new attempt with the ShardTracker and returns a
|
|
// new shard representing this attempt. This attempt's shard should be canceled
|
|
// if it ends up not being used by the overall payment, i.e. if the attempt
|
|
// fails.
|
|
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
|
|
error) {
|
|
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
// Use a random child index.
|
|
var childIndex [4]byte
|
|
if _, err := rand.Read(childIndex[:]); err != nil {
|
|
return nil, err
|
|
}
|
|
idx := binary.BigEndian.Uint32(childIndex[:])
|
|
|
|
// Depending on whether we are requesting the last shard or not, either
|
|
// split the current share into two, or get a Child directly from the
|
|
// current sharer.
|
|
var child *Child
|
|
if last {
|
|
child = s.sharer.Child(idx)
|
|
|
|
// If this was the last shard, set the current share to the
|
|
// zero share to indicate we cannot split it further.
|
|
s.sharer = s.sharer.Zero()
|
|
} else {
|
|
left, sharer, err := s.sharer.Split()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.sharer = sharer
|
|
child = left.Child(idx)
|
|
}
|
|
|
|
// Track the new child and return the shard.
|
|
s.shards[pid] = child
|
|
|
|
mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
|
|
amp := record.NewAMP(
|
|
child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
|
|
)
|
|
|
|
return &Shard{
|
|
child: child,
|
|
mpp: mpp,
|
|
amp: amp,
|
|
}, nil
|
|
}
|
|
|
|
// CancelShard cancel's the shard corresponding to the given attempt ID.
|
|
func (s *ShardTracker) CancelShard(pid uint64) error {
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
c, ok := s.shards[pid]
|
|
if !ok {
|
|
return fmt.Errorf("pid not found")
|
|
}
|
|
delete(s.shards, pid)
|
|
|
|
// Now that we are canceling this shard, we XOR the share back into our
|
|
// current share.
|
|
s.sharer = s.sharer.Merge(c)
|
|
return nil
|
|
}
|
|
|
|
// GetHash retrieves the hash used by the shard of the given attempt ID. This
|
|
// will return an error if the attempt ID is unknown.
|
|
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
c, ok := s.shards[pid]
|
|
if !ok {
|
|
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
|
|
"not found", pid)
|
|
}
|
|
|
|
return c.Hash, nil
|
|
}
|