diff --git a/clock/test_clock.go b/clock/test_clock.go index f4319cee3..85e33d4fe 100644 --- a/clock/test_clock.go +++ b/clock/test_clock.go @@ -10,6 +10,7 @@ type TestClock struct { currentTime time.Time timeChanMap map[time.Time][]chan time.Time timeLock sync.Mutex + tickSignal chan time.Duration } // NewTestClock returns a new test clock. @@ -20,6 +21,19 @@ func NewTestClock(startTime time.Time) *TestClock { } } +// NewTestClockWithTickSignal will create a new test clock with an added +// channel which will be used to signal when a new ticker is registered. +// This is useful when creating a ticker on a separate goroutine and we'd +// like to wait for that to happen before advancing the test case. +func NewTestClockWithTickSignal(startTime time.Time, + tickSignal chan time.Duration) *TestClock { + + testClock := NewTestClock(startTime) + testClock.tickSignal = tickSignal + + return testClock +} + // Now returns the current (test) time. func (c *TestClock) Now() time.Time { c.timeLock.Lock() @@ -32,7 +46,14 @@ func (c *TestClock) Now() time.Time { // duration has passed passed by the user set test time. func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time { c.timeLock.Lock() - defer c.timeLock.Unlock() + defer func() { + c.timeLock.Unlock() + + // Signal that the ticker has been added. + if c.tickSignal != nil { + c.tickSignal <- duration + } + }() triggerTime := c.currentTime.Add(duration) ch := make(chan time.Time, 1) diff --git a/clock/test_clock_test.go b/clock/test_clock_test.go index 879cc8fd1..275f5c1a5 100644 --- a/clock/test_clock_test.go +++ b/clock/test_clock_test.go @@ -1,8 +1,11 @@ package clock import ( + "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" ) var ( @@ -42,6 +45,7 @@ func TestTickAfter(t *testing.T) { select { case <-ticker: tick = true + case <-time.After(time.Millisecond): } @@ -61,3 +65,34 @@ func TestTickAfter(t *testing.T) { tickOrTimeOut(ticker2, true) tickOrTimeOut(ticker3, false) } + +// TestTickSignal tests that TickAfter signals registration allowing +// safe time advancement. +func TestTickSignal(t *testing.T) { + const interval = time.Second + + ch := make(chan time.Duration) + c := NewTestClockWithTickSignal(testTime, ch) + err := make(chan error, 1) + + go func() { + select { + // TickAfter will signal registration but will not + // tick, unless we read the signal and set the time. + case <-c.TickAfter(interval): + err <- nil + + // Signal timeout if tick didn't happen. + case <-time.After(time.Second): + err <- fmt.Errorf("timeout") + } + }() + + tick := <-ch + // Expect that the interval is correctly passed over the channel. + assert.Equal(t, interval, tick) + + // Once the ticker is registered, set the time to make it fire. + c.SetTime(testTime.Add(time.Second)) + assert.NoError(t, <-err) +}