protofsm: use new fn.GoroutineManager to manage goroutines

This fixes an isuse that can occur when we have concurrent calls to
`Stop` while the state machine is driving forward.
This commit is contained in:
Olaoluwa Osuntokun 2024-11-13 17:10:30 -08:00
parent 675b8b5f61
commit 66815e5000
No known key found for this signature in database
GPG Key ID: 90525F7DEEE0AD86

View File

@ -1,6 +1,7 @@
package protofsm
import (
"context"
"fmt"
"sync"
"time"
@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
// query the internal state machine state.
stateQuery chan stateQuery[Event, Env]
wg fn.GoroutineManager
quit chan struct{}
startOnce sync.Once
stopOnce sync.Once
// TODO(roasbeef): also use that context guard here?
quit chan struct{}
wg sync.WaitGroup
}
// ErrorReporter is an interface that's used to report errors that occur during
@ -194,8 +194,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
cfg: cfg,
events: make(chan Event, 1),
stateQuery: make(chan stateQuery[Event, Env]),
quit: make(chan struct{}),
wg: *fn.NewGoroutineManager(context.Background()),
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
quit: make(chan struct{}),
}
}
@ -203,8 +204,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
// the state machine to completion.
func (s *StateMachine[Event, Env]) Start() {
s.startOnce.Do(func() {
s.wg.Add(1)
go s.driveMachine()
_ = s.wg.Go(func(ctx context.Context) {
s.driveMachine()
})
})
}
@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
func (s *StateMachine[Event, Env]) Stop() {
s.stopOnce.Do(func() {
close(s.quit)
s.wg.Wait()
s.wg.Stop()
})
}
@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
// executeDaemonEvent executes a daemon event, which is a special type of event
// that can be emitted as part of the state transition function of the state
// machine. An error is returned if the type of event is unknown.
func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
func (s *StateMachine[Event, Env]) executeDaemonEvent(
event DaemonEvent) error {
switch daemonEvent := event.(type) {
@ -342,14 +344,10 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
err)
}
// If a post-send event was specified, then we'll
// funnel that back into the main state machine now as
// well.
daemonEvent.PostSendEvent.WhenSome(func(event Event) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
// If a post-send event was specified, then we'll funnel
// that back into the main state machine now as well.
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll
return s.wg.Go(func(ctx context.Context) {
log.Debugf("FSM(%v): sending "+
"post-send event: %v",
s.cfg.Env.Name(),
@ -357,10 +355,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
)
s.SendEvent(event)
}()
})
return nil
})
}
// If this doesn't have a SendWhen predicate, then we can just
@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
// Otherwise, this has a SendWhen predicate, so we'll need
// launch a goroutine to poll the SendWhen, then send only once
// the predicate is true.
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
predicateTicker := time.NewTicker(
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
)
@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
}
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
return nil
})
// If this is a broadcast transaction event, then we'll broadcast with
// the label attached.
@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register spend: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
for {
select {
case spend, ok := <-spendEvent.Spend:
@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
return nil
})
// The state machine has requested a new event to be sent once a
// specified txid+pkScript pair has confirmed.
@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register conf: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
for {
select {
case <-confEvent.Confirmed:
@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
})
}
return fmt.Errorf("unknown daemon event: %T", event)
@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
// incoming events, and then drives the state machine forward until it reaches
// a terminal state.
func (s *StateMachine[Event, Env]) driveMachine() {
defer s.wg.Done()
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
currentState := s.cfg.InitialState
@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
// An outside caller is querying our state, so we'll return the
// latest state.
case stateQuery := <-s.stateQuery:
if !fn.SendOrQuit(
stateQuery.CurrentState, currentState, s.quit,
) {
if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll
return
}
case <-s.quit:
// TODO(roasbeef): logs, etc
// * something in env?
case <-s.wg.Done():
return
}
}