diff --git a/feature/manager.go b/feature/manager.go index 26a3d4a31..f788f9892 100644 --- a/feature/manager.go +++ b/feature/manager.go @@ -1,11 +1,16 @@ package feature import ( + "errors" "fmt" "github.com/lightningnetwork/lnd/lnwire" ) +// ErrUnknownSet is returned if a proposed feature vector contains a set that +// is unknown to LND. +var ErrUnknownSet = errors.New("unknown feature bit set") + // Config houses any runtime modifications to the default set descriptors. For // our purposes, this typically means disabling certain features to test legacy // protocol interoperability or functionality. @@ -198,3 +203,40 @@ func (m *Manager) ListSets() []Set { return sets } + +// UpdateFeatureSets accepts a map of new feature vectors for each of the +// manager's known sets, validates that the update can be applied and modifies +// the feature manager's internal state. If a set is not included in the update +// map, it is left unchanged. The feature vectors provided are expected to +// include the current set of features, updated with desired bits added/removed. +func (m *Manager) UpdateFeatureSets( + updates map[Set]*lnwire.RawFeatureVector) error { + + for set, newFeatures := range updates { + if !set.valid() { + return fmt.Errorf("%w: set: %d", ErrUnknownSet, set) + } + + if err := newFeatures.ValidatePairs(); err != nil { + return err + } + + if err := m.Get(set).ValidateUpdate(newFeatures); err != nil { + return err + } + + fv := lnwire.NewFeatureVector(newFeatures, lnwire.Features) + if err := ValidateDeps(fv); err != nil { + return err + } + } + + // Only update the current feature sets once every proposed set has + // passed validation so that we don't partially update any sets then + // fail out on a later set's validation. + for set, features := range updates { + m.SetRaw(set, features.Clone()) + } + + return nil +} diff --git a/feature/manager_internal_test.go b/feature/manager_internal_test.go index 8debddcbe..04422cd9b 100644 --- a/feature/manager_internal_test.go +++ b/feature/manager_internal_test.go @@ -135,3 +135,113 @@ func testManager(t *testing.T, test managerTest) { assertSet(lnwire.StaticRemoteKeyOptional) } } + +// TestUpdateFeatureSets tests validation of the update of various features in +// each of our sets, asserting that the feature set is not partially modified +// if one set in incorrectly specified. +func TestUpdateFeatureSets(t *testing.T) { + t.Parallel() + + // Use a reduced set description to make reasoning about our sets + // easier. + setDesc := setDesc{ + lnwire.DataLossProtectRequired: { + SetInit: {}, // I + SetNodeAnn: {}, // N + }, + lnwire.GossipQueriesOptional: { + SetNodeAnn: {}, // N + }, + } + + testCases := []struct { + name string + features map[Set]*lnwire.RawFeatureVector + err error + }{ + { + name: "unknown set", + features: map[Set]*lnwire.RawFeatureVector{ + setSentinel + 1: lnwire.NewRawFeatureVector(), + }, + err: ErrUnknownSet, + }, + { + name: "invalid pairwise feature", + features: map[Set]*lnwire.RawFeatureVector{ + SetNodeAnn: lnwire.NewRawFeatureVector( + lnwire.FeatureBit(1000), + lnwire.FeatureBit(1001), + ), + }, + err: lnwire.ErrFeaturePairExists, + }, + { + name: "error in one set", + features: map[Set]*lnwire.RawFeatureVector{ + SetNodeAnn: lnwire.NewRawFeatureVector( + lnwire.FeatureBit(1000), + lnwire.FeatureBit(1001), + ), + SetInit: lnwire.NewRawFeatureVector( + lnwire.DataLossProtectRequired, + ), + }, + err: lnwire.ErrFeaturePairExists, + }, + { + name: "update existing sets ok", + features: map[Set]*lnwire.RawFeatureVector{ + SetInit: lnwire.NewRawFeatureVector( + lnwire.DataLossProtectRequired, + lnwire.FeatureBit(1001), + ), + SetNodeAnn: lnwire.NewRawFeatureVector( + lnwire.DataLossProtectRequired, + lnwire.GossipQueriesOptional, + lnwire.FeatureBit(1000), + ), + }, + }, + { + name: "update new, valid set ok", + features: map[Set]*lnwire.RawFeatureVector{ + SetInvoice: lnwire.NewRawFeatureVector( + lnwire.FeatureBit(1001), + ), + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + featureMgr, err := newManager(Config{}, setDesc) + require.NoError(t, err) + + err = featureMgr.UpdateFeatureSets(testCase.features) + require.ErrorIs(t, err, testCase.err) + + // Compare the feature manager's sets to the updated + // set if no error was hit, otherwise assert that it + // is unchanged. + expected := testCase.features + actual := featureMgr + if err != nil { + originalMgr, err := newManager( + Config{}, setDesc, + ) + require.NoError(t, err) + expected = originalMgr.fsets + } + + for set, expectedFeatures := range expected { + actualSet := actual.GetRaw(set) + require.True(t, + actualSet.Equals(expectedFeatures)) + } + }) + } +} diff --git a/feature/set.go b/feature/set.go index 435fac1a4..c70637031 100644 --- a/feature/set.go +++ b/feature/set.go @@ -26,8 +26,19 @@ const ( // SetInvoiceAmp identifies the features that should be advertised on // AMP invoices generated by the daemon. SetInvoiceAmp + + // setSentinel is used to mark the end of our known sets. This enum + // member must *always* be the last item in the iota list to ensure + // that validation works as expected. + setSentinel ) +// valid returns a boolean indicating whether a set value is one of our +// predefined feature sets. +func (s Set) valid() bool { + return s < setSentinel +} + // String returns a human-readable description of a Set. func (s Set) String() string { switch s {