mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-15 03:51:23 +01:00
fn: Remove ctx from GoroutineManager constructor
This commit is contained in:
parent
d6eeaec246
commit
51eeb9ece3
2 changed files with 238 additions and 167 deletions
|
@ -3,51 +3,123 @@ package fn
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GoroutineManager is used to launch goroutines until context expires or the
|
// GoroutineManager is used to launch goroutines until context expires or the
|
||||||
// manager is stopped. The Stop method blocks until all started goroutines stop.
|
// manager is stopped. The Stop method blocks until all started goroutines stop.
|
||||||
type GoroutineManager struct {
|
type GoroutineManager struct {
|
||||||
wg sync.WaitGroup
|
// id is used to generate unique ids for each goroutine.
|
||||||
mu sync.Mutex
|
id atomic.Uint32
|
||||||
ctx context.Context
|
|
||||||
cancel func()
|
// cancelFns is a map of cancel functions that can be used to cancel the
|
||||||
|
// context of a goroutine. The mutex must be held when accessing this
|
||||||
|
// map. The key is the id of the goroutine.
|
||||||
|
cancelFns map[uint32]context.CancelFunc
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
stopped sync.Once
|
||||||
|
quit chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGoroutineManager constructs and returns a new instance of
|
// NewGoroutineManager constructs and returns a new instance of
|
||||||
// GoroutineManager.
|
// GoroutineManager.
|
||||||
func NewGoroutineManager(ctx context.Context) *GoroutineManager {
|
func NewGoroutineManager() *GoroutineManager {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
return &GoroutineManager{
|
return &GoroutineManager{
|
||||||
ctx: ctx,
|
cancelFns: make(map[uint32]context.CancelFunc),
|
||||||
cancel: cancel,
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go tries to start a new goroutine and returns a boolean indicating its
|
// addCancelFn adds a context cancel function to the manager and returns an id
|
||||||
// success. It fails iff the goroutine manager is stopping or its context passed
|
// that can can be used to cancel the context later on when the goroutine is
|
||||||
// to NewGoroutineManager has expired.
|
// done.
|
||||||
func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
|
func (g *GoroutineManager) addCancelFn(cancel context.CancelFunc) uint32 {
|
||||||
// Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race
|
|
||||||
// condition, since it is not clear should Wait() block or not. This
|
|
||||||
// kind of race condition is detected by Go runtime and results in a
|
|
||||||
// crash if running with `-race`. To prevent this, whole Go method is
|
|
||||||
// protected with a mutex. The call to wg.Wait() inside Stop() can still
|
|
||||||
// run in parallel with Go, but in that case g.ctx is in expired state,
|
|
||||||
// because cancel() was called in Stop, so Go returns before wg.Add(1)
|
|
||||||
// call.
|
|
||||||
g.mu.Lock()
|
g.mu.Lock()
|
||||||
defer g.mu.Unlock()
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
if g.ctx.Err() != nil {
|
id := g.id.Add(1)
|
||||||
|
g.cancelFns[id] = cancel
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel cancels the context associated with the passed id.
|
||||||
|
func (g *GoroutineManager) cancel(id uint32) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
g.cancelUnsafe(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancelUnsafe cancels the context associated with the passed id without
|
||||||
|
// acquiring the mutex.
|
||||||
|
func (g *GoroutineManager) cancelUnsafe(id uint32) {
|
||||||
|
fn, ok := g.cancelFns[id]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
delete(g.cancelFns, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Go tries to start a new goroutine and returns a boolean indicating its
|
||||||
|
// success. It returns true if the goroutine was successfully created and false
|
||||||
|
// otherwise. A goroutine will fail to be created iff the goroutine manager is
|
||||||
|
// stopping or the passed context has already expired. The passed call-back
|
||||||
|
// function must exit if the passed context expires.
|
||||||
|
func (g *GoroutineManager) Go(ctx context.Context,
|
||||||
|
f func(ctx context.Context)) bool {
|
||||||
|
|
||||||
|
// Derive a cancellable context from the passed context and store its
|
||||||
|
// cancel function in the manager. The context will be cancelled when
|
||||||
|
// either the parent context is cancelled or the quit channel is closed
|
||||||
|
// which will call the stored cancel function.
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
id := g.addCancelFn(cancel)
|
||||||
|
|
||||||
|
// Calling wg.Add(1) and wg.Wait() when the wg's counter is 0 is a race
|
||||||
|
// condition, since it is not clear if Wait() should block or not. This
|
||||||
|
// kind of race condition is detected by Go runtime and results in a
|
||||||
|
// crash if running with `-race`. To prevent this, we protect the calls
|
||||||
|
// to wg.Add(1) and wg.Wait() with a mutex. If we block here because
|
||||||
|
// Stop is running first, then Stop will close the quit channel which
|
||||||
|
// will cause the context to be cancelled, and we will exit before
|
||||||
|
// calling wg.Add(1). If we grab the mutex here before Stop does, then
|
||||||
|
// Stop will block until after we call wg.Add(1).
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
// Before continuing to start the goroutine, we need to check if the
|
||||||
|
// context has already expired. This could be the case if the parent
|
||||||
|
// context has already expired or if Stop has been called.
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
g.cancelUnsafe(id)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that the goroutine is not started if the manager has stopped.
|
||||||
|
select {
|
||||||
|
case <-g.quit:
|
||||||
|
g.cancelUnsafe(id)
|
||||||
|
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
g.wg.Add(1)
|
g.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer g.wg.Done()
|
defer func() {
|
||||||
f(g.ctx)
|
g.cancel(id)
|
||||||
|
g.wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
f(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
@ -56,20 +128,30 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
|
||||||
// Stop prevents new goroutines from being added and waits for all running
|
// Stop prevents new goroutines from being added and waits for all running
|
||||||
// goroutines to finish.
|
// goroutines to finish.
|
||||||
func (g *GoroutineManager) Stop() {
|
func (g *GoroutineManager) Stop() {
|
||||||
g.mu.Lock()
|
g.stopped.Do(func() {
|
||||||
g.cancel()
|
// Closing the quit channel will prevent any new goroutines from
|
||||||
g.mu.Unlock()
|
// starting.
|
||||||
|
g.mu.Lock()
|
||||||
|
close(g.quit)
|
||||||
|
for _, cancel := range g.cancelFns {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
// Wait for all goroutines to finish. Note that this wg.Wait() call is
|
// Wait for all goroutines to finish. Note that this wg.Wait()
|
||||||
// safe, since it can't run in parallel with wg.Add(1) call in Go, since
|
// call is safe, since it can't run in parallel with wg.Add(1)
|
||||||
// we just cancelled the context and even if Go call starts running here
|
// call in Go, since we just cancelled the context and even if
|
||||||
// after acquiring the mutex, it would see that the context has expired
|
// Go call starts running here after acquiring the mutex, it
|
||||||
// and return false instead of calling wg.Add(1).
|
// would see that the context has expired and return false
|
||||||
g.wg.Wait()
|
// instead of calling wg.Add(1).
|
||||||
|
g.wg.Wait()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done returns a channel which is closed when either the context passed to
|
// Done returns a channel which is closed once Stop has been called and the
|
||||||
// NewGoroutineManager expires or when Stop is called.
|
// quit channel closed. Note that the channel closing indicates that shutdown
|
||||||
|
// of the GoroutineManager has started but not necessarily that the Stop method
|
||||||
|
// has finished.
|
||||||
func (g *GoroutineManager) Done() <-chan struct{} {
|
func (g *GoroutineManager) Done() <-chan struct{} {
|
||||||
return g.ctx.Done()
|
return g.quit
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,156 +2,145 @@ package fn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestGoroutineManager tests that the GoroutineManager starts goroutines until
|
// TestGoroutineManager tests the behaviour of the GoroutineManager.
|
||||||
// ctx expires. It also makes sure it fails to start new goroutines after the
|
|
||||||
// context expired and the GoroutineManager is in the process of waiting for
|
|
||||||
// already started goroutines in the Stop method.
|
|
||||||
func TestGoroutineManager(t *testing.T) {
|
func TestGoroutineManager(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
m := NewGoroutineManager(context.Background())
|
// Here we test that the GoroutineManager starts goroutines until it has
|
||||||
|
// been stopped.
|
||||||
|
t.Run("GM is stopped", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
taskChan := make(chan struct{})
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
m = NewGoroutineManager()
|
||||||
|
taskChan = make(chan struct{})
|
||||||
|
)
|
||||||
|
|
||||||
require.True(t, m.Go(func(ctx context.Context) {
|
// The gm has not stopped yet and the passed in context has not
|
||||||
<-taskChan
|
// expired, so we expect the goroutine to start. The taskChan is
|
||||||
}))
|
// blocking, so this goroutine will be live for a while.
|
||||||
|
require.True(t, m.Go(ctx, func(ctx context.Context) {
|
||||||
|
<-taskChan
|
||||||
|
}))
|
||||||
|
|
||||||
t1 := time.Now()
|
t1 := time.Now()
|
||||||
|
|
||||||
// Close taskChan in 1s, causing the goroutine to stop.
|
// Close taskChan in 1s, causing the goroutine to stop.
|
||||||
time.AfterFunc(time.Second, func() {
|
time.AfterFunc(time.Second, func() {
|
||||||
close(taskChan)
|
|
||||||
})
|
|
||||||
|
|
||||||
m.Stop()
|
|
||||||
stopDelay := time.Since(t1)
|
|
||||||
|
|
||||||
// Make sure Stop was waiting for the goroutine to stop.
|
|
||||||
require.Greater(t, stopDelay, time.Second)
|
|
||||||
|
|
||||||
// Make sure new goroutines do not start after Stop.
|
|
||||||
require.False(t, m.Go(func(ctx context.Context) {}))
|
|
||||||
|
|
||||||
// When Stop() is called, the internal context expires and m.Done() is
|
|
||||||
// closed. Test this.
|
|
||||||
select {
|
|
||||||
case <-m.Done():
|
|
||||||
default:
|
|
||||||
t.Errorf("Done() channel must be closed at this point")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGoroutineManagerContextExpires tests the effect of context expiry.
|
|
||||||
func TestGoroutineManagerContextExpires(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
m := NewGoroutineManager(ctx)
|
|
||||||
|
|
||||||
require.True(t, m.Go(func(ctx context.Context) {
|
|
||||||
<-ctx.Done()
|
|
||||||
}))
|
|
||||||
|
|
||||||
// The Done channel of the manager should not be closed, so the
|
|
||||||
// following call must block.
|
|
||||||
select {
|
|
||||||
case <-m.Done():
|
|
||||||
t.Errorf("Done() channel must not be closed at this point")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
// The Done channel of the manager should be closed, so the following
|
|
||||||
// call must not block.
|
|
||||||
select {
|
|
||||||
case <-m.Done():
|
|
||||||
default:
|
|
||||||
t.Errorf("Done() channel must be closed at this point")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure new goroutines do not start after context expiry.
|
|
||||||
require.False(t, m.Go(func(ctx context.Context) {}))
|
|
||||||
|
|
||||||
// Stop will wait for all goroutines to stop.
|
|
||||||
m.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGoroutineManagerStress starts many goroutines while calling Stop. It
|
|
||||||
// is needed to make sure the GoroutineManager does not crash if this happen.
|
|
||||||
// If the mutex was not used, it would crash because of a race condition between
|
|
||||||
// wg.Add(1) and wg.Wait().
|
|
||||||
func TestGoroutineManagerStress(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewGoroutineManager(context.Background())
|
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
|
||||||
|
|
||||||
time.AfterFunc(1*time.Millisecond, func() {
|
|
||||||
m.Stop()
|
|
||||||
close(stopChan)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Starts 100 goroutines sequentially. Sequential order is needed to
|
|
||||||
// keep wg.counter low (0 or 1) to increase probability of race
|
|
||||||
// condition to be caught if it exists. If mutex is removed in the
|
|
||||||
// implementation, this test crashes under `-race`.
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
taskChan := make(chan struct{})
|
|
||||||
ok := m.Go(func(ctx context.Context) {
|
|
||||||
close(taskChan)
|
close(taskChan)
|
||||||
})
|
})
|
||||||
// If goroutine was started, wait for its completion.
|
|
||||||
if ok {
|
m.Stop()
|
||||||
<-taskChan
|
stopDelay := time.Since(t1)
|
||||||
|
|
||||||
|
// Make sure Stop was waiting for the goroutine to stop.
|
||||||
|
require.Greater(t, stopDelay, time.Second)
|
||||||
|
|
||||||
|
// Make sure new goroutines do not start after Stop.
|
||||||
|
require.False(t, m.Go(ctx, func(ctx context.Context) {}))
|
||||||
|
|
||||||
|
// When Stop() is called, gm quit channel has been closed and so
|
||||||
|
// Done() should return.
|
||||||
|
select {
|
||||||
|
case <-m.Done():
|
||||||
|
default:
|
||||||
|
t.Errorf("Done() channel must be closed at this point")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for Stop to complete.
|
|
||||||
<-stopChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGoroutineManagerStopsStress launches many Stop() calls in parallel with a
|
|
||||||
// task exiting. It attempts to catch a race condition between wg.Done() and
|
|
||||||
// wg.Wait() calls. According to documentation of wg.Wait() this is acceptable,
|
|
||||||
// therefore this test passes even with -race.
|
|
||||||
func TestGoroutineManagerStopsStress(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewGoroutineManager(context.Background())
|
|
||||||
|
|
||||||
// jobChan is used to make the task to finish.
|
|
||||||
jobChan := make(chan struct{})
|
|
||||||
|
|
||||||
// Start a task and wait inside it until we start calling Stop() method.
|
|
||||||
ok := m.Go(func(ctx context.Context) {
|
|
||||||
<-jobChan
|
|
||||||
})
|
})
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
// Now launch many gorotines calling Stop() method in parallel.
|
// Test that the GoroutineManager fails to start a goroutine or exits a
|
||||||
var wg sync.WaitGroup
|
// goroutine if the caller context has expired.
|
||||||
for i := 0; i < 100; i++ {
|
t.Run("Caller context expires", func(t *testing.T) {
|
||||||
wg.Add(1)
|
t.Parallel()
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
m = NewGoroutineManager()
|
||||||
|
taskChan = make(chan struct{})
|
||||||
|
)
|
||||||
|
|
||||||
|
// Derive a child context with a cancel function.
|
||||||
|
ctxc, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
// The gm has not stopped yet and the passed in context has not
|
||||||
|
// expired, so we expect the goroutine to start.
|
||||||
|
require.True(t, m.Go(ctxc, func(ctx context.Context) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-taskChan:
|
||||||
|
t.Fatalf("The task was performed when it " +
|
||||||
|
"should not have")
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Give the GM a little bit of time to start the goroutine so
|
||||||
|
// that we can be sure that it is already listening on the
|
||||||
|
// ctx and taskChan before calling cancel.
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
// Cancel the context so that the goroutine exits.
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Attempt to send a signal on the task channel, nothing should
|
||||||
|
// happen since the goroutine has already exited.
|
||||||
|
select {
|
||||||
|
case taskChan <- struct{}{}:
|
||||||
|
case <-time.After(time.Millisecond * 200):
|
||||||
|
}
|
||||||
|
|
||||||
|
// Again attempt to add a goroutine with the same cancelled
|
||||||
|
// context. This should fail since the context has already
|
||||||
|
// expired.
|
||||||
|
require.False(t, m.Go(ctxc, func(ctx context.Context) {
|
||||||
|
t.Fatalf("The goroutine should not have started")
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Stop the goroutine manager.
|
||||||
|
m.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start many goroutines while calling Stop. We do this to make sure
|
||||||
|
// that the GoroutineManager does not crash when these calls are done in
|
||||||
|
// parallel because of the potential race between wg.Add() and
|
||||||
|
// wg.Done() when the wg counter is 0.
|
||||||
|
t.Run("Stress test", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
m = NewGoroutineManager()
|
||||||
|
stopChan = make(chan struct{})
|
||||||
|
)
|
||||||
|
|
||||||
|
time.AfterFunc(1*time.Millisecond, func() {
|
||||||
m.Stop()
|
m.Stop()
|
||||||
}()
|
close(stopChan)
|
||||||
}
|
})
|
||||||
|
|
||||||
// Exit the task in parallel with Stop() calls.
|
// Start 100 goroutines sequentially. Sequential order is
|
||||||
close(jobChan)
|
// needed to keep wg.counter low (0 or 1) to increase
|
||||||
|
// probability of the race condition to triggered if it exists.
|
||||||
|
// If mutex is removed in the implementation, this test crashes
|
||||||
|
// under `-race`.
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
taskChan := make(chan struct{})
|
||||||
|
ok := m.Go(ctx, func(ctx context.Context) {
|
||||||
|
close(taskChan)
|
||||||
|
})
|
||||||
|
// If goroutine was started, wait for its completion.
|
||||||
|
if ok {
|
||||||
|
<-taskChan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Wait until all the Stop() calls complete.
|
// Wait for Stop to complete.
|
||||||
wg.Wait()
|
<-stopChan
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue