mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
aliasmgr: avoid collision when requesting alias
With the new RPC calls that we are going to add in the next commits, it will be possible for users to add (local only, non-gossipped) SCID aliases for channels. Since those will be in the same range as the ones given out by RequestAlias, we need to make sure that when we generate a new one that it doesn't collide with an already existing one.
This commit is contained in:
parent
80dfaeb16d
commit
466f550ddb
@ -9,6 +9,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// UpdateLinkAliases is a function type for a function that locates the active
|
||||
@ -58,16 +59,16 @@ var (
|
||||
byteOrder = binary.BigEndian
|
||||
|
||||
// startBlockHeight is the starting block height of the alias range.
|
||||
startingBlockHeight = 16_000_000
|
||||
startingBlockHeight uint32 = 16_000_000
|
||||
|
||||
// endBlockHeight is the ending block height of the alias range.
|
||||
endBlockHeight = 16_250_000
|
||||
endBlockHeight uint32 = 16_250_000
|
||||
|
||||
// StartingAlias is the first alias ShortChannelID that will get
|
||||
// assigned by RequestAlias. The starting BlockHeight is chosen so that
|
||||
// legitimate SCIDs in integration tests aren't mistaken for an alias.
|
||||
StartingAlias = lnwire.ShortChannelID{
|
||||
BlockHeight: uint32(startingBlockHeight),
|
||||
BlockHeight: startingBlockHeight,
|
||||
TxIndex: 0,
|
||||
TxPosition: 0,
|
||||
}
|
||||
@ -506,6 +507,19 @@ func (m *Manager) GetPeerAlias(chanID lnwire.ChannelID) (lnwire.ShortChannelID,
|
||||
func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
|
||||
var nextAlias lnwire.ShortChannelID
|
||||
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
// haveAlias returns true if the passed alias is already assigned to a
|
||||
// channel in the baseToSet map.
|
||||
haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool {
|
||||
return fn.Any(func(aliasList []lnwire.ShortChannelID) bool {
|
||||
return fn.Any(func(alias lnwire.ShortChannelID) bool {
|
||||
return alias == maybeNextAlias
|
||||
}, aliasList)
|
||||
}, maps.Values(m.baseToSet))
|
||||
}
|
||||
|
||||
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
|
||||
bucket, err := tx.CreateTopLevelBucket(aliasAllocBucket)
|
||||
if err != nil {
|
||||
@ -518,6 +532,27 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
|
||||
// StartingAlias to it.
|
||||
nextAlias = StartingAlias
|
||||
|
||||
// If the very first alias is already assigned, we'll
|
||||
// keep incrementing until we find an unassigned alias.
|
||||
// This is to avoid collision with custom added SCID
|
||||
// aliases that fall into the same range as the ones we
|
||||
// generate here monotonically. Those custom SCIDs are
|
||||
// stored in a different bucket, but we can just check
|
||||
// the in-memory map for simplicity.
|
||||
for {
|
||||
if !haveAlias(nextAlias) {
|
||||
break
|
||||
}
|
||||
|
||||
nextAlias = getNextScid(nextAlias)
|
||||
|
||||
// Abort if we've reached the end of the range.
|
||||
if nextAlias.BlockHeight >= endBlockHeight {
|
||||
return fmt.Errorf("range for custom " +
|
||||
"aliases exhausted")
|
||||
}
|
||||
}
|
||||
|
||||
var scratch [8]byte
|
||||
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
|
||||
return bucket.Put(lastAliasKey, scratch[:])
|
||||
@ -532,6 +567,26 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
|
||||
)
|
||||
nextAlias = getNextScid(lastScid)
|
||||
|
||||
// If the next alias is already assigned, we'll keep
|
||||
// incrementing until we find an unassigned alias. This is to
|
||||
// avoid collision with custom added SCID aliases that fall into
|
||||
// the same range as the ones we generate here monotonically.
|
||||
// Those custom SCIDs are stored in a different bucket, but we
|
||||
// can just check the in-memory map for simplicity.
|
||||
for {
|
||||
if !haveAlias(nextAlias) {
|
||||
break
|
||||
}
|
||||
|
||||
nextAlias = getNextScid(nextAlias)
|
||||
|
||||
// Abort if we've reached the end of the range.
|
||||
if nextAlias.BlockHeight >= endBlockHeight {
|
||||
return fmt.Errorf("range for custom " +
|
||||
"aliases exhausted")
|
||||
}
|
||||
}
|
||||
|
||||
var scratch [8]byte
|
||||
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
|
||||
return bucket.Put(lastAliasKey, scratch[:])
|
||||
@ -614,6 +669,6 @@ func getNextScid(last lnwire.ShortChannelID) lnwire.ShortChannelID {
|
||||
// assigned by RequestAlias. These bounds only apply to aliases we generate.
|
||||
// Our peers are free to use any range they choose.
|
||||
func IsAlias(scid lnwire.ShortChannelID) bool {
|
||||
return scid.BlockHeight >= uint32(startingBlockHeight) &&
|
||||
scid.BlockHeight < uint32(endBlockHeight)
|
||||
return scid.BlockHeight >= startingBlockHeight &&
|
||||
scid.BlockHeight < endBlockHeight
|
||||
}
|
||||
|
@ -168,6 +168,24 @@ func TestAliasLifecycle(t *testing.T) {
|
||||
// Query the aliases and verify that none exists.
|
||||
aliasList = aliasStore.GetAliases(baseScid)
|
||||
require.Len(t, aliasList, 0)
|
||||
|
||||
// We now request an alias generated by the aliasStore. This should give
|
||||
// the first from the pre-defined list of allocated aliases.
|
||||
firstRequested, err := aliasStore.RequestAlias()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, StartingAlias, firstRequested)
|
||||
|
||||
// We now manually add the next alias from the range as a custom alias.
|
||||
secondAlias := getNextScid(firstRequested)
|
||||
err = aliasStore.AddLocalAlias(secondAlias, baseScid, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When we now request another alias from the allocation list, we expect
|
||||
// the third one (tx position 2) to be returned.
|
||||
thirdRequested, err := aliasStore.RequestAlias()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, getNextScid(secondAlias), thirdRequested)
|
||||
require.EqualValues(t, 2, thirdRequested.TxPosition)
|
||||
}
|
||||
|
||||
// TestGetNextScid tests that given a current lnwire.ShortChannelID,
|
||||
@ -182,7 +200,7 @@ func TestGetNextScid(t *testing.T) {
|
||||
name: "starting alias",
|
||||
current: StartingAlias,
|
||||
expected: lnwire.ShortChannelID{
|
||||
BlockHeight: uint32(startingBlockHeight),
|
||||
BlockHeight: startingBlockHeight,
|
||||
TxIndex: 0,
|
||||
TxPosition: 1,
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user