feature: add SetBit helper to set dependent bits

This commit adds a new SetBit helper function in the features package
along with a test for it.  SetBit sets the given feature bit on the
given feature bit vector along with any of its dependencies. If the bit
is required, then all the dependencies are also set to required,
otherwise, the optional dependency bits are set. Existing bits are only
upgraded from optional to required but never downgraded from required to
optional.
This commit is contained in:
Elle Mouton 2024-09-30 10:30:27 +02:00
parent f1207ef740
commit 49a87469db
No known key found for this signature in database
GPG Key ID: D7D916376026F177
2 changed files with 230 additions and 0 deletions

View File

@ -98,6 +98,58 @@ func ValidateDeps(fv *lnwire.FeatureVector) error {
return validateDeps(features, supported)
}
// SetBit sets the given feature bit on the given feature bit vector along with
// any of its dependencies. If the bit is required, then all the dependencies
// are also set to required, otherwise, the optional dependency bits are set.
// Existing bits are only upgraded from optional to required but never
// downgraded from required to optional.
func SetBit(vector *lnwire.FeatureVector,
bit lnwire.FeatureBit) *lnwire.FeatureVector {
fv := vector.Clone()
// Get the optional version of the bit since that is what the deps map
// uses.
optBit := mapToOptional(bit)
// If the bit we are setting is optional, then we set it (in its
// optional form) and also set all its dependents as optional if they
// are not already set (they may already be set in a required form in
// which case they should not be overridden).
if !bit.IsRequired() {
// Set the bit itself if it does not already exist. We use
// SafeSet here so that if the bit already exists in the
// required form, then this is not overwritten.
_ = fv.SafeSet(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
fv = SetBit(fv, depBit)
}
return fv
}
// The bit is required. In this case, we do want to override any
// existing optional bit for both the bit itself and for the dependent
// bits.
fv.Unset(optBit)
fv.Set(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
// The deps map only contains the optional versions of bits, so
// there is no need to first map the bit to the optional
// version.
fv.Unset(depBit)
// Set the required version of the bit instead.
fv = SetBit(fv, mapToRequired(depBit))
}
return fv
}
// validateDeps is a subroutine that recursively checks that the passed features
// have all of their associated dependencies in the supported map.
func validateDeps(features featureSet, supported supportedFeatures) error {
@ -157,3 +209,13 @@ func mapToOptional(bit lnwire.FeatureBit) lnwire.FeatureBit {
}
return bit
}
// mapToRequired returns the required variant of a given feature bit pair.
func mapToRequired(bit lnwire.FeatureBit) lnwire.FeatureBit {
if bit.IsRequired() {
return bit
}
bit ^= 0x01
return bit
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
type depTest struct {
@ -164,3 +165,170 @@ func testValidateDeps(t *testing.T, test depTest) {
test.expErr, err)
}
}
// TestSettingDepBits sets that the SetBit function correctly sets a bit along
// with its dependencies in a feature vector. Specifically, we want to check
// that any existing optional bits are upgraded to required if the main bit
// being set is required. Similarly, if the main bit is optional, then any
// existing bits that depend on it should not be downgraded from required to
// optional.
func TestSettingDepBits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existingVector *lnwire.RawFeatureVector
newBit lnwire.FeatureBit
expectedVector *lnwire.RawFeatureVector
}{
{
name: "Optional bit with no dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.ExplicitChannelTypeOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.ExplicitChannelTypeOptional,
),
},
{
name: "Required bit with no dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.ExplicitChannelTypeRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.ExplicitChannelTypeRequired,
),
},
{
name: "Optional bit with single " +
"level dependant",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.RouteBlindingOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadOptional,
),
},
{
name: "Required bit with single " +
"level dependant",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.RouteBlindingRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Optional bit with multi level " +
"dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.Bolt11BlindedPathsOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsOptional,
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadOptional,
),
},
{
name: "Required bit with multi level " +
"dependants",
existingVector: lnwire.NewRawFeatureVector(),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Existing required bit should not be " +
"overridden if new bit is optional",
existingVector: lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadRequired,
),
newBit: lnwire.Bolt11BlindedPathsOptional,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsOptional,
lnwire.RouteBlindingOptional,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Existing optional bit should be overridden if " +
"new bit is required",
existingVector: lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadOptional,
),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
{
name: "Unrelated bits should not be affected",
existingVector: lnwire.NewRawFeatureVector(
lnwire.AMPOptional,
lnwire.TLVOnionPayloadOptional,
),
newBit: lnwire.Bolt11BlindedPathsRequired,
expectedVector: lnwire.NewRawFeatureVector(
lnwire.AMPOptional,
lnwire.Bolt11BlindedPathsRequired,
lnwire.RouteBlindingRequired,
lnwire.TLVOnionPayloadRequired,
),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
fv := lnwire.NewFeatureVector(
test.existingVector, lnwire.Features,
)
resultFV := SetBit(fv, test.newBit)
require.Equal(
t, test.expectedVector,
resultFV.RawFeatureVector,
)
})
}
}
// TestSetBitNoCycles tests the SetBit call for each feature bit that we know of
// in both its optional and required form. This ensures that the SetBit call
// never gets stuck in a recursion cycle for any feature bit.
func TestSetBitNoCycles(t *testing.T) {
t.Parallel()
// For each feature-bit that we are aware of (both optional and
// required), we will create a feature vector that is empty, and then
// we will call SetBit with the given feature bit. We then check that
// all the dependent features are also added in the appropriate form
// (optional vs required). This test completing demonstrates that the
// recursion in SetBit is not a problem since no feature bits should
// create a dependency cycle.
for bit := range lnwire.Features {
fv := lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(), lnwire.Features,
)
resultFV := SetBit(fv, bit)
// Ensure that all the dependent feature bits are in fact set
// in the resulting set. Here we just check that some form
// (optional or required) is set. The expected type is asserted
// later on in the test.
for expectedBit := range deps[bit] {
require.True(t, resultFV.IsSet(expectedBit) ||
resultFV.IsSet(mapToRequired(expectedBit)))
}
// Make sure all the resulting feature bits have the correct
// form (optional vs required).
for depBit := range resultFV.Features() {
require.Equal(t, bit.IsRequired(), depBit.IsRequired())
}
}
}