diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index b583d4a8d..ecbd74834 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -1,6 +1,7 @@ package protofsm import ( + "context" "fmt" "sync" "time" @@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct { // query the internal state machine state. stateQuery chan stateQuery[Event, Env] + wg fn.GoroutineManager + quit chan struct{} + startOnce sync.Once stopOnce sync.Once - - // TODO(roasbeef): also use that context guard here? - quit chan struct{} - wg sync.WaitGroup } // ErrorReporter is an interface that's used to report errors that occur during @@ -194,8 +194,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env] cfg: cfg, events: make(chan Event, 1), stateQuery: make(chan stateQuery[Event, Env]), - quit: make(chan struct{}), + wg: *fn.NewGoroutineManager(context.Background()), newStateEvents: fn.NewEventDistributor[State[Event, Env]](), + quit: make(chan struct{}), } } @@ -203,8 +204,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env] // the state machine to completion. func (s *StateMachine[Event, Env]) Start() { s.startOnce.Do(func() { - s.wg.Add(1) - go s.driveMachine() + _ = s.wg.Go(func(ctx context.Context) { + s.driveMachine() + }) }) } @@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() { func (s *StateMachine[Event, Env]) Stop() { s.stopOnce.Do(func() { close(s.quit) - s.wg.Wait() + s.wg.Stop() }) } @@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ // executeDaemonEvent executes a daemon event, which is a special type of event // that can be emitted as part of the state transition function of the state // machine. An error is returned if the type of event is unknown. -func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen +func (s *StateMachine[Event, Env]) executeDaemonEvent( event DaemonEvent) error { switch daemonEvent := event.(type) { @@ -342,14 +344,10 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen err) } - // If a post-send event was specified, then we'll - // funnel that back into the main state machine now as - // well. - daemonEvent.PostSendEvent.WhenSome(func(event Event) { - s.wg.Add(1) - go func() { - defer s.wg.Done() - + // If a post-send event was specified, then we'll funnel + // that back into the main state machine now as well. + return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll + return s.wg.Go(func(ctx context.Context) { log.Debugf("FSM(%v): sending "+ "post-send event: %v", s.cfg.Env.Name(), @@ -357,10 +355,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen ) s.SendEvent(event) - }() + }) }) - - return nil } // If this doesn't have a SendWhen predicate, then we can just @@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen // Otherwise, this has a SendWhen predicate, so we'll need // launch a goroutine to poll the SendWhen, then send only once // the predicate is true. - s.wg.Add(1) - go func() { - defer s.wg.Done() - + return s.wg.Go(func(ctx context.Context) { predicateTicker := time.NewTicker( s.cfg.CustomPollInterval.UnwrapOr(pollInterval), ) @@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen return } - case <-s.quit: + case <-ctx.Done(): return } } - }() - - return nil + }) // If this is a broadcast transaction event, then we'll broadcast with // the label attached. @@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen return fmt.Errorf("unable to register spend: %w", err) } - s.wg.Add(1) - go func() { - defer s.wg.Done() + return s.wg.Go(func(ctx context.Context) { for { select { case spend, ok := <-spendEvent.Spend: @@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen return - case <-s.quit: + case <-ctx.Done(): return } } - }() - - return nil + }) // The state machine has requested a new event to be sent once a // specified txid+pkScript pair has confirmed. @@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen return fmt.Errorf("unable to register conf: %w", err) } - s.wg.Add(1) - go func() { - defer s.wg.Done() + return s.wg.Go(func(ctx context.Context) { for { select { case <-confEvent.Confirmed: @@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen return - case <-s.quit: + case <-ctx.Done(): return } } - }() + }) } return fmt.Errorf("unknown daemon event: %T", event) @@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], // incoming events, and then drives the state machine forward until it reaches // a terminal state. func (s *StateMachine[Event, Env]) driveMachine() { - defer s.wg.Done() - log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name()) currentState := s.cfg.InitialState @@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() { // An outside caller is querying our state, so we'll return the // latest state. case stateQuery := <-s.stateQuery: - if !fn.SendOrQuit( - stateQuery.CurrentState, currentState, s.quit, - ) { - + if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll return } - case <-s.quit: - // TODO(roasbeef): logs, etc - // * something in env? + case <-s.wg.Done(): return } }