diff --git a/channeldb/kvdb/etcd/db.go b/channeldb/kvdb/etcd/db.go index a082e6109..3bd89c290 100644 --- a/channeldb/kvdb/etcd/db.go +++ b/channeldb/kvdb/etcd/db.go @@ -124,6 +124,9 @@ var _ walletdb.DB = (*db)(nil) // BackendConfig holds and etcd backend config and connection parameters. type BackendConfig struct { + // Ctx is the context we use to cancel operations upon exit. + Ctx context.Context + // Host holds the peer url of the etcd instance. Host string @@ -155,6 +158,10 @@ type BackendConfig struct { // newEtcdBackend returns a db object initialized with the passed backend // config. If etcd connection cannot be estabished, then returns error. func newEtcdBackend(config BackendConfig) (*db, error) { + if config.Ctx == nil { + config.Ctx = context.Background() + } + tlsInfo := transport.TLSInfo{ CertFile: config.CertFile, KeyFile: config.KeyFile, @@ -167,6 +174,7 @@ func newEtcdBackend(config BackendConfig) (*db, error) { } cli, err := clientv3.New(clientv3.Config{ + Context: config.Ctx, Endpoints: []string{config.Host}, DialTimeout: etcdConnectionTimeout, Username: config.User, @@ -192,7 +200,8 @@ func newEtcdBackend(config BackendConfig) (*db, error) { // getSTMOptions creats all STM options based on the backend config. func (db *db) getSTMOptions() []STMOptionFunc { - opts := []STMOptionFunc{} + opts := []STMOptionFunc{WithAbortContext(db.config.Ctx)} + if db.config.CollectCommitStats { opts = append(opts, WithCommitStatsCallback(db.commitStatsCollector.callback), @@ -257,9 +266,7 @@ func (db *db) BeginReadTx() (walletdb.ReadTx, error) { // start a read-only transaction to perform all operations. // This function is part of the walletdb.Db interface implementation. func (db *db) Copy(w io.Writer) error { - ctx := context.Background() - - ctx, cancel := context.WithTimeout(ctx, etcdLongTimeout) + ctx, cancel := context.WithTimeout(db.config.Ctx, etcdLongTimeout) defer cancel() readCloser, err := db.cli.Snapshot(ctx) diff --git a/channeldb/kvdb/etcd/db_test.go b/channeldb/kvdb/etcd/db_test.go index 69342207a..155d912ec 100644 --- a/channeldb/kvdb/etcd/db_test.go +++ b/channeldb/kvdb/etcd/db_test.go @@ -4,6 +4,7 @@ package etcd import ( "bytes" + "context" "testing" "github.com/btcsuite/btcwallet/walletdb" @@ -42,3 +43,33 @@ func TestCopy(t *testing.T) { } assert.Equal(t, expected, f.Dump()) } + +func TestAbortContext(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + + config := f.BackendConfig() + config.Ctx = ctx + + // Pass abort context and abort right away. + db, err := newEtcdBackend(config) + assert.NoError(t, err) + cancel() + + // Expect that the update will fail. + err = db.Update(func(tx walletdb.ReadWriteTx) error { + _, err := tx.CreateTopLevelBucket([]byte("bucket")) + assert.NoError(t, err) + + return nil + }) + + assert.Error(t, err, "context canceled") + + // No changes in the DB. + assert.Equal(t, map[string]string{}, f.Dump()) +} diff --git a/channeldb/kvdb/etcd/embed.go b/channeldb/kvdb/etcd/embed.go index f19363f35..96ea71ab5 100644 --- a/channeldb/kvdb/etcd/embed.go +++ b/channeldb/kvdb/etcd/embed.go @@ -3,6 +3,7 @@ package etcd import ( + "context" "fmt" "net" "net/url" @@ -63,6 +64,7 @@ func NewEmbeddedEtcdInstance(path string) (*BackendConfig, func(), error) { } connConfig := &BackendConfig{ + Ctx: context.Background(), Host: "http://" + peerURL, User: "user", Pass: "pass", diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index a3f8c2233..bf769d116 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -229,64 +229,51 @@ func makeSTM(cli *v3.Client, manual bool, so ...STMOptionFunc) *stm { // errors and handling commit. The loop will quit on every error except // CommitError which is used to indicate a necessary retry. func runSTM(s *stm, apply func(STM) error) error { - out := make(chan error, 1) + var ( + retries int + stats CommitStats + err error + ) - go func() { - var ( - retries int - stats CommitStats - ) +loop: + // In a loop try to apply and commit and roll back if the database has + // changed (CommitError). + for { + select { + // Check if the STM is aborted and break the retry loop if it is. + case <-s.options.ctx.Done(): + err = fmt.Errorf("aborted") + break loop - defer func() { - // Recover DatabaseError panics so - // we can return them. - if r := recover(); r != nil { - e, ok := r.(DatabaseError) - if !ok { - // Unknown panic. - panic(r) - } - - // Return the error. - out <- e.Unwrap() - } - }() - - var err error - - // In a loop try to apply and commit and roll back - // if the database has changed (CommitError). - for { - // Abort STM if there was an application error. - if err = apply(s); err != nil { - break - } - - stats, err = s.commit() - - // Re-apply only upon commit error - // (meaning the database was changed). - if _, ok := err.(CommitError); !ok { - // Anything that's not a CommitError - // aborts the STM run loop. - break - } - - // Rollback before trying to re-apply. - s.Rollback() - retries++ + default: } - if s.options.commitStatsCallback != nil { - stats.Retries = retries - s.options.commitStatsCallback(err == nil, stats) + // Apply the transaction closure and abort the STM if there was an + // application error. + if err = apply(s); err != nil { + break loop } - // Return the error to the caller. - out <- err - }() + stats, err = s.commit() - return <-out + // Re-apply only upon commit error (meaning the database was changed). + if _, ok := err.(CommitError); !ok { + // Anything that's not a CommitError + // aborts the STM run loop. + break loop + } + + // Rollback before trying to re-apply. + s.Rollback() + retries++ + } + + if s.options.commitStatsCallback != nil { + stats.Retries = retries + s.options.commitStatsCallback(err == nil, stats) + } + + return err } // add inserts a txn response to the read set. This is useful when the txn @@ -367,18 +354,10 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { s.options.ctx, key, append(opts, s.getOpts...)..., ) if err != nil { - dbErr := DatabaseError{ + return nil, DatabaseError{ msg: "stm.fetch() failed", err: err, } - - // Do not panic when executing a manual transaction. - if s.manual { - return nil, dbErr - } - - // Panic when executing inside the STM runloop. - panic(dbErr) } // Set revison and serializable options upon first fetch diff --git a/channeldb/kvdb/kvdb_etcd.go b/channeldb/kvdb/kvdb_etcd.go index 265e7daeb..523112141 100644 --- a/channeldb/kvdb/kvdb_etcd.go +++ b/channeldb/kvdb/kvdb_etcd.go @@ -3,6 +3,8 @@ package kvdb import ( + "context" + "github.com/lightningnetwork/lnd/channeldb/kvdb/etcd" ) @@ -12,10 +14,13 @@ const TestBackend = EtcdBackendName // GetEtcdBackend returns an etcd backend configured according to the // passed etcdConfig. -func GetEtcdBackend(prefix string, etcdConfig *EtcdConfig) (Backend, error) { +func GetEtcdBackend(ctx context.Context, prefix string, + etcdConfig *EtcdConfig) (Backend, error) { + // Config translation is needed here in order to keep the // etcd package fully independent from the rest of the source tree. backendConfig := etcd.BackendConfig{ + Ctx: ctx, Host: etcdConfig.Host, User: etcdConfig.User, Pass: etcdConfig.Pass, diff --git a/channeldb/kvdb/kvdb_no_etcd.go b/channeldb/kvdb/kvdb_no_etcd.go index ea5de4275..71090f475 100644 --- a/channeldb/kvdb/kvdb_no_etcd.go +++ b/channeldb/kvdb/kvdb_no_etcd.go @@ -3,6 +3,7 @@ package kvdb import ( + "context" "fmt" ) @@ -13,7 +14,9 @@ const TestBackend = BoltBackendName var errEtcdNotAvailable = fmt.Errorf("etcd backend not available") // GetEtcdBackend is a stub returning nil and errEtcdNotAvailable error. -func GetEtcdBackend(prefix string, etcdConfig *EtcdConfig) (Backend, error) { +func GetEtcdBackend(ctx context.Context, prefix string, + etcdConfig *EtcdConfig) (Backend, error) { + return nil, errEtcdNotAvailable } diff --git a/lncfg/db.go b/lncfg/db.go index a265f95f1..d63da8caf 100644 --- a/lncfg/db.go +++ b/lncfg/db.go @@ -1,6 +1,7 @@ package lncfg import ( + "context" "fmt" "github.com/lightningnetwork/lnd/channeldb/kvdb" @@ -50,12 +51,12 @@ func (db *DB) Validate() error { } // GetBackend returns a kvdb.Backend as set in the DB config. -func (db *DB) GetBackend(dbPath string, networkName string) ( - kvdb.Backend, error) { +func (db *DB) GetBackend(ctx context.Context, dbPath string, + networkName string) (kvdb.Backend, error) { if db.Backend == etcdBackend { // Prefix will separate key/values in the db. - return kvdb.GetEtcdBackend(networkName, db.Etcd) + return kvdb.GetEtcdBackend(ctx, networkName, db.Etcd) } return kvdb.GetBoltBackend(dbPath, dbName, db.Bolt.NoFreeListSync) diff --git a/lnd.go b/lnd.go index 105148997..cccf30e61 100644 --- a/lnd.go +++ b/lnd.go @@ -251,7 +251,11 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { "minutes...") startOpenTime := time.Now() - chanDbBackend, err := cfg.DB.GetBackend( + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + chanDbBackend, err := cfg.DB.GetBackend(ctx, cfg.localDatabaseDir(), cfg.networkName(), ) if err != nil { @@ -283,10 +287,6 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { ltndLog.Infof("Database now open (time_to_open=%v)!", openTime) // Only process macaroons if --no-macaroons isn't set. - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - tlsCfg, restCreds, restProxyDest, err := getTLSConfig(cfg) if err != nil { err := fmt.Errorf("unable to load TLS credentials: %v", err)