package fn import ( "context" "sync" "sync/atomic" "time" ) var ( // DefaultTimeout is the default timeout used for context operations. DefaultTimeout = 30 * time.Second ) // ContextGuard is a struct that provides a wait group and main quit channel // that can be used to create guarded contexts. type ContextGuard struct { mu sync.Mutex wg sync.WaitGroup quit chan struct{} stopped sync.Once // id is used to generate unique ids for each context that should be // cancelled when the main quit signal is triggered. id atomic.Uint32 // cancelFns is a map of cancel functions that can be used to cancel // any context that should be cancelled when the main quit signal is // triggered. The key is the id of the context. The mutex must be held // when accessing this map. cancelFns map[uint32]context.CancelFunc } // NewContextGuard constructs and returns a new instance of ContextGuard. func NewContextGuard() *ContextGuard { return &ContextGuard{ quit: make(chan struct{}), cancelFns: make(map[uint32]context.CancelFunc), } } // Quit is used to signal the main quit channel, which will cancel all // non-blocking contexts derived from the ContextGuard. func (g *ContextGuard) Quit() { g.stopped.Do(func() { g.mu.Lock() defer g.mu.Unlock() for _, cancel := range g.cancelFns { cancel() } close(g.quit) }) } // Done returns a channel that will be closed when the main quit signal is // triggered. func (g *ContextGuard) Done() <-chan struct{} { return g.quit } // WgAdd is used to add delta to the internal wait group of the ContextGuard. func (g *ContextGuard) WgAdd(delta int) { g.wg.Add(delta) } // WgDone is used to decrement the internal wait group of the ContextGuard. func (g *ContextGuard) WgDone() { g.wg.Done() } // WgWait is used to block until the internal wait group of the ContextGuard is // empty. func (g *ContextGuard) WgWait() { g.wg.Wait() } // ctxGuardOptions is used to configure the behaviour of the context derived // via the WithCtx method of the ContextGuard. type ctxGuardOptions struct { blocking bool withTimeout bool timeout time.Duration } // ContextGuardOption defines the signature of a functional option that can be // used to configure the behaviour of the context derived via the WithCtx method // of the ContextGuard. type ContextGuardOption func(*ctxGuardOptions) // WithBlockingCG is used to create a cancellable context that will NOT be // cancelled if the main quit signal is triggered, to block shutdown of // important tasks. func WithBlockingCG() ContextGuardOption { return func(o *ctxGuardOptions) { o.blocking = true } } // WithCustomTimeoutCG is used to create a cancellable context with a custom // timeout. Such a context will be cancelled if either the parent context is // cancelled, the timeout is reached or, if the Blocking option is not provided, // the main quit signal is triggered. func WithCustomTimeoutCG(timeout time.Duration) ContextGuardOption { return func(o *ctxGuardOptions) { o.withTimeout = true o.timeout = timeout } } // WithTimeoutCG is used to create a cancellable context with a default timeout. // Such a context will be cancelled if either the parent context is cancelled, // the timeout is reached or, if the Blocking option is not provided, the main // quit signal is triggered. func WithTimeoutCG() ContextGuardOption { return func(o *ctxGuardOptions) { o.withTimeout = true o.timeout = DefaultTimeout } } // Create is used to derive a cancellable context from the parent. Various // options can be provided to configure the behaviour of the derived context. func (g *ContextGuard) Create(ctx context.Context, options ...ContextGuardOption) (context.Context, context.CancelFunc) { // Exit early if the parent context has already been cancelled. select { case <-ctx.Done(): return ctx, func() {} default: } var opts ctxGuardOptions for _, o := range options { o(&opts) } g.mu.Lock() defer g.mu.Unlock() var cancel context.CancelFunc if opts.withTimeout { ctx, cancel = context.WithTimeout(ctx, opts.timeout) } else { ctx, cancel = context.WithCancel(ctx) } if opts.blocking { g.ctxBlocking(ctx, cancel) return ctx, cancel } // If the call is non-blocking, then we can exit early if the main quit // signal has been triggered. select { case <-g.quit: cancel() return ctx, cancel default: } cancel = g.ctxQuitUnsafe(ctx, cancel) return ctx, cancel } // ctxQuitUnsafe spins off a goroutine that will block until the passed context // is cancelled or until the quit channel has been signaled after which it will // call the passed cancel function and decrement the wait group. // // NOTE: the caller must hold the ContextGuard's mutex before calling this // function. func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context, cancel context.CancelFunc) context.CancelFunc { cancel = g.addCancelFnUnsafe(cancel) g.wg.Add(1) go func() { defer cancel() defer g.wg.Done() select { case <-g.quit: case <-ctx.Done(): } }() return cancel } // ctxBlocking spins off a goroutine that will block until the passed context // is cancelled after which it will call the passed cancel function and // decrement the wait group. func (g *ContextGuard) ctxBlocking(ctx context.Context, cancel context.CancelFunc) { g.wg.Add(1) go func() { defer cancel() defer g.wg.Done() select { case <-ctx.Done(): } }() } // addCancelFnUnsafe adds a context cancel function to the manager and returns a // call-back which can safely be used to cancel the context. // // NOTE: the caller must hold the ContextGuard's mutex before calling this // function. func (g *ContextGuard) addCancelFnUnsafe( cancel context.CancelFunc) context.CancelFunc { id := g.id.Add(1) g.cancelFns[id] = cancel return g.cancelCtxFn(id) } // cancelCtxFn returns a call-back that can be used to cancel the context // associated with the passed id. func (g *ContextGuard) cancelCtxFn(id uint32) context.CancelFunc { return func() { g.mu.Lock() fn, ok := g.cancelFns[id] if !ok { g.mu.Unlock() return } delete(g.cancelFns, id) g.mu.Unlock() fn() } }