diff --git a/fn/goroutine_manager.go b/fn/goroutine_manager.go index b209cf93c..28b72233a 100644 --- a/fn/goroutine_manager.go +++ b/fn/goroutine_manager.go @@ -130,6 +130,56 @@ func (g *GoroutineManager) Go(ctx context.Context, return true } +// GoBlocking tries to start a new blocking goroutine and returns a boolean +// indicating its success. It returns true if the goroutine was successfully +// created and false otherwise. A goroutine will fail to be created iff the +// goroutine manager has stopped (Stop() was called and all the goroutines +// have finished). To make sure GoBlocking succeeds, call it right after +// creating a GoroutineManager (in Start() method of your object) or from +// another goroutine created by the same GoroutineManager. +// +// The difference from Go() is that GoroutineManager doesn't manage contexts so +// the goroutine can run as long as needed. GoroutineManager will still wait for +// its completion in the Stop() method. But it is the caller's responsibility to +// stop the launched goroutine and to pass a context to it if needed. +// +// This method is intended to perform shutdown of important tasks, where +// interruption is not desirable. +func (g *GoroutineManager) GoBlocking(f func()) bool { + // Protect the whole code of the method with the mutex, because we + // access quit and count. + g.mu.Lock() + defer g.mu.Unlock() + + // If the goroutine manager has completelly stopped, stop. This happens + // only if Stop() was called and all goroutines have finished. + select { + case <-g.quit: + if g.count == 0 { + return false + } + default: + } + + g.count++ + go func() { + defer func() { + g.mu.Lock() + g.count-- + g.mu.Unlock() + + // We use Signal() and not Broadcast(), because there + // could be only one user of g.cond.Wait(), because of + // g.stopped. + g.cond.Signal() + }() + + f() + }() + + return true +} + // Stop prevents new goroutines from being added and waits for all running // goroutines to finish. func (g *GoroutineManager) Stop() { diff --git a/fn/goroutine_manager_test.go b/fn/goroutine_manager_test.go index 4b3f1914e..231c1a7e3 100644 --- a/fn/goroutine_manager_test.go +++ b/fn/goroutine_manager_test.go @@ -106,6 +106,62 @@ func TestGoroutineManager(t *testing.T) { m.Stop() }) + t.Run("GoBlocking", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + m = NewGoroutineManager() + ) + + // Start a blocking task. + taskChan := make(chan struct{}) + require.True(t, m.GoBlocking(func() { + <-taskChan + })) + + // Start stopping GoroutineManager. + stopped := make(chan struct{}) + go func() { + m.Stop() + close(stopped) + }() + + // Make sure Stop() is waiting. + select { + case <-stopped: + t.Fatalf("The Stop() method must be waiting") + case <-time.After(time.Millisecond * 200): + } + + // Since the first goroutine is still running, we can launch + // another blocking goroutine. + secondBlockingTaskDone := make(chan struct{}) + require.True(t, m.GoBlocking(func() { + close(secondBlockingTaskDone) + })) + + // Make sure the second blocking goroutine has started and + // executed. + <-secondBlockingTaskDone + + // However we can't start a non-blocking goroutine. + require.False(t, m.Go(ctx, func(ctx context.Context) { + t.Fatalf("The goroutine should not have started") + })) + + // Now let the first goroutine finish. + close(taskChan) + + // And make sure Stop() unblocked. + <-stopped + + // Now we can't start a goroutine even if it is blocking. + require.False(t, m.GoBlocking(func() { + t.Fatalf("The goroutine should not have started") + })) + }) + // Start many goroutines while calling Stop. We do this to make sure // that the GoroutineManager does not crash when these calls are done in // parallel because of the potential race between Go() and Stop() when @@ -124,11 +180,12 @@ func TestGoroutineManager(t *testing.T) { close(stopChan) }) - // Start 100 goroutines sequentially. Sequential order is needed - // to keep counter low (0 or 1) to increase probability of the - // race condition triggered if it exists. If mutex is removed in + // Start 100 goroutines sequentially, both with Go() and + // GoBlocking(). Sequential order is needed to keep counter low + // (0 or 1) to increase probability of the race condition + // triggered if it exists. If mutex is removed in // the implementation, this test crashes under `-race`. - for i := 0; i < 100; i++ { + for i := 0; i < 50; i++ { taskChan := make(chan struct{}) ok := m.Go(ctx, func(ctx context.Context) { close(taskChan) @@ -137,6 +194,15 @@ func TestGoroutineManager(t *testing.T) { if ok { <-taskChan } + + taskChan = make(chan struct{}) + ok = m.GoBlocking(func() { + close(taskChan) + }) + // If goroutine was started, wait for its completion. + if ok { + <-taskChan + } } // Wait for Stop to complete.