diff --git a/config.go b/config.go index 70aed41a3..e9717c6e1 100644 --- a/config.go +++ b/config.go @@ -990,6 +990,12 @@ func loadConfig() (*config, error) { "minbackoff") } + // Assert that all worker pools will have a positive number of + // workers, otherwise the pools will rendered useless. + if err := cfg.Workers.Validate(); err != nil { + return nil, err + } + // Finally, ensure that the user's color is correctly formatted, // otherwise the server will not be able to start after the unlocking // the wallet. diff --git a/lncfg/workers.go b/lncfg/workers.go index cf3a8724c..fec57ddda 100644 --- a/lncfg/workers.go +++ b/lncfg/workers.go @@ -1,5 +1,7 @@ package lncfg +import "fmt" + const ( // DefaultReadWorkers is the default maximum number of concurrent // workers used by the daemon's read pool. @@ -26,3 +28,22 @@ type Workers struct { // Sig is the maximum number of concurrent sig pool workers. Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers."` } + +// Validate checks the Workers configuration to ensure that the input values are +// sane. +func (w *Workers) Validate() error { + if w.Read <= 0 { + return fmt.Errorf("number of read workers (%d) must be "+ + "positive", w.Read) + } + if w.Write <= 0 { + return fmt.Errorf("number of write workers (%d) must be "+ + "positive", w.Write) + } + if w.Sig <= 0 { + return fmt.Errorf("number of sig workers (%d) must be "+ + "positive", w.Sig) + } + + return nil +} diff --git a/lncfg/workers_test.go b/lncfg/workers_test.go new file mode 100644 index 000000000..cc32202d3 --- /dev/null +++ b/lncfg/workers_test.go @@ -0,0 +1,102 @@ +package lncfg_test + +import ( + "testing" + + "github.com/lightningnetwork/lnd/lncfg" +) + +const ( + maxUint = ^uint(0) + maxInt = int(maxUint >> 1) + minInt = -maxInt - 1 +) + +// TestValidateWorkers asserts that validating the Workers config only succeeds +// if all fields specify a positive number of workers. +func TestValidateWorkers(t *testing.T) { + tests := []struct { + name string + cfg *lncfg.Workers + valid bool + }{ + { + name: "min valid", + cfg: &lncfg.Workers{ + Read: 1, + Write: 1, + Sig: 1, + }, + valid: true, + }, + { + name: "max valid", + cfg: &lncfg.Workers{ + Read: maxInt, + Write: maxInt, + Sig: maxInt, + }, + valid: true, + }, + { + name: "read max invalid", + cfg: &lncfg.Workers{ + Read: 0, + Write: 1, + Sig: 1, + }, + }, + { + name: "write max invalid", + cfg: &lncfg.Workers{ + Read: 1, + Write: 0, + Sig: 1, + }, + }, + { + name: "sig max invalid", + cfg: &lncfg.Workers{ + Read: 1, + Write: 1, + Sig: 0, + }, + }, + { + name: "read min invalid", + cfg: &lncfg.Workers{ + Read: minInt, + Write: 1, + Sig: 1, + }, + }, + { + name: "write min invalid", + cfg: &lncfg.Workers{ + Read: 1, + Write: minInt, + Sig: 1, + }, + }, + { + name: "sig min invalid", + cfg: &lncfg.Workers{ + Read: 1, + Write: 1, + Sig: minInt, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := test.cfg.Validate() + switch { + case test.valid && err != nil: + t.Fatalf("valid config was invalid: %v", err) + case !test.valid && err == nil: + t.Fatalf("invalid config was valid") + } + }) + } +}