Merge pull request #9143 from ellemouton/rb-set-bit-required

feature+rpcserver: add SetBit helper to set dependent bits
This commit is contained in:
Olaoluwa Osuntokun 2024-10-04 19:16:34 -07:00 committed by GitHub
commit e9fbbe2528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 237 additions and 2 deletions

View File

@ -27,6 +27,10 @@
cause a nil pointer dereference during the probing of a payment request that cause a nil pointer dereference during the probing of a payment request that
does not contain a payment address. does not contain a payment address.
* [Use the required route blinding
feature-bit](https://github.com/lightningnetwork/lnd/pull/9143) for invoices
containing blinded paths.
# New Features # New Features
## Functional Enhancements ## Functional Enhancements
## RPC Additions ## RPC Additions

View File

@ -103,6 +103,58 @@ func ValidateDeps(fv *lnwire.FeatureVector) error {
return validateDeps(features, supported) return validateDeps(features, supported)
} }
// SetBit sets the given feature bit on the given feature bit vector along with
// any of its dependencies. If the bit is required, then all the dependencies
// are also set to required, otherwise, the optional dependency bits are set.
// Existing bits are only upgraded from optional to required but never
// downgraded from required to optional.
func SetBit(vector *lnwire.FeatureVector,
bit lnwire.FeatureBit) *lnwire.FeatureVector {
fv := vector.Clone()
// Get the optional version of the bit since that is what the deps map
// uses.
optBit := mapToOptional(bit)
// If the bit we are setting is optional, then we set it (in its
// optional form) and also set all its dependents as optional if they
// are not already set (they may already be set in a required form in
// which case they should not be overridden).
if !bit.IsRequired() {
// Set the bit itself if it does not already exist. We use
// SafeSet here so that if the bit already exists in the
// required form, then this is not overwritten.
_ = fv.SafeSet(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
fv = SetBit(fv, depBit)
}
return fv
}
// The bit is required. In this case, we do want to override any
// existing optional bit for both the bit itself and for the dependent
// bits.
fv.Unset(optBit)
fv.Set(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
// The deps map only contains the optional versions of bits, so
// there is no need to first map the bit to the optional
// version.
fv.Unset(depBit)
// Set the required version of the bit instead.
fv = SetBit(fv, mapToRequired(depBit))
}
return fv
}
// validateDeps is a subroutine that recursively checks that the passed features // validateDeps is a subroutine that recursively checks that the passed features
// have all of their associated dependencies in the supported map. // have all of their associated dependencies in the supported map.
func validateDeps(features featureSet, supported supportedFeatures) error { func validateDeps(features featureSet, supported supportedFeatures) error {
@ -162,3 +214,13 @@ func mapToOptional(bit lnwire.FeatureBit) lnwire.FeatureBit {
} }
return bit return bit
} }
// mapToRequired returns the required variant of a given feature bit pair.
func mapToRequired(bit lnwire.FeatureBit) lnwire.FeatureBit {
if bit.IsRequired() {
return bit
}
bit ^= 0x01
return bit
}

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
) )
type depTest struct { type depTest struct {
@ -164,3 +165,170 @@ func testValidateDeps(t *testing.T, test depTest) {
test.expErr, err) test.expErr, err)
} }
} }
// TestSettingDepBits sets that the SetBit function correctly sets a bit along
// with its dependencies in a feature vector. Specifically, we want to check
// that any existing optional bits are upgraded to required if the main bit
// being set is required. Similarly, if the main bit is optional, then any
// existing bits that depend on it should not be downgraded from required to
// optional.
func TestSettingDepBits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existingVector *lnwire.RawFeatureVector
newBit lnwire.FeatureBit
expectedVector *lnwire.RawFeatureVector
}{
{
name: "Optional bit with no dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.ExplicitChannelTypeOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.ExplicitChannelTypeOptional,
),
},
{
name: "Required bit with no dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.ExplicitChannelTypeRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.ExplicitChannelTypeRequired,
),
},
{
name: "Optional bit with single " +
"level dependant",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.RouteBlindingOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadOptional,
),
},
{
name: "Required bit with single " +
"level dependant",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.RouteBlindingRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Optional bit with multi level " +
"dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.Bolt11BlindedPathsOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsOptional,
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadOptional,
),
},
{
name: "Required bit with multi level " +
"dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Existing required bit should not be " +
"overridden if new bit is optional",
existingVector: lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadRequired,
),
newBit: lnwire.Bolt11BlindedPathsOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsOptional,
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Existing optional bit should be overridden if " +
"new bit is required",
existingVector: lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadOptional,
),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Unrelated bits should not be affected",
existingVector: lnwire.NewRawFeatureVector(
lnwire.AMPOptional,
lnwire.TLVOnionPayloadOptional,
),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.AMPOptional,
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
fv := lnwire.NewFeatureVector(
test.existingVector, lnwire.Features,
)
resultFV := SetBit(fv, test.newBit)
require.Equal(
t, test.expectedVector,
resultFV.RawFeatureVector,
)
})
}
}
// TestSetBitNoCycles tests the SetBit call for each feature bit that we know of
// in both its optional and required form. This ensures that the SetBit call
// never gets stuck in a recursion cycle for any feature bit.
func TestSetBitNoCycles(t *testing.T) {
t.Parallel()
// For each feature-bit that we are aware of (both optional and
// required), we will create a feature vector that is empty, and then
// we will call SetBit with the given feature bit. We then check that
// all the dependent features are also added in the appropriate form
// (optional vs required). This test completing demonstrates that the
// recursion in SetBit is not a problem since no feature bits should
// create a dependency cycle.
for bit := range lnwire.Features {
fv := lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(), lnwire.Features,
)
resultFV := SetBit(fv, bit)
// Ensure that all the dependent feature bits are in fact set
// in the resulting set. Here we just check that some form
// (optional or required) is set. The expected type is asserted
// later on in the test.
for expectedBit := range deps[bit] {
require.True(t, resultFV.IsSet(expectedBit) ||
resultFV.IsSet(mapToRequired(expectedBit)))
}
// Make sure all the resulting feature bits have the correct
// form (optional vs required).
for depBit := range resultFV.Features() {
require.Equal(t, bit.IsRequired(), depBit.IsRequired())
}
}
}

View File

@ -6082,8 +6082,9 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
// understand the new BOLT 11 tagged field // understand the new BOLT 11 tagged field
// containing the blinded path, so we switch // containing the blinded path, so we switch
// the bit to required. // the bit to required.
v.Unset(lnwire.Bolt11BlindedPathsOptional) v = feature.SetBit(
v.Set(lnwire.Bolt11BlindedPathsRequired) v, lnwire.Bolt11BlindedPathsRequired,
)
} }
return v return v