lnd/pool/worker_test.go
Michael Rooke 78d9996620
trivial: Fix spelling errors
- Fixes some spelling in code comments and a couple of function names
2023-09-21 22:35:33 -04:00

360 lines
8.4 KiB
Go

package pool_test
import (
"bytes"
crand "crypto/rand"
"fmt"
"io"
"math/rand"
"testing"
"time"
"github.com/lightningnetwork/lnd/buffer"
"github.com/lightningnetwork/lnd/pool"
"github.com/stretchr/testify/require"
)
// workerPoolTest describes one worker pool implementation that the shared
// test suite should be run against.
type workerPoolTest struct {
	// name labels the pool under test and is used as the prefix of every
	// subtest name.
	name string

	// numWorkers is the worker count handed to the submission scenarios,
	// which use it to decide how many tasks to queue.
	numWorkers int

	// newPool returns a new, not yet started pool instance. The value is
	// passed around untyped and resolved with a type switch by the generic
	// helpers.
	newPool func() interface{}
}
// TestConcreteWorkerPools asserts the behavior of any concrete implementations
// of worker pools provided by the pool package. Currently this tests the
// pool.Read and pool.Write instances.
func TestConcreteWorkerPools(t *testing.T) {
const (
gcInterval = time.Second
expiryInterval = 250 * time.Millisecond
numWorkers = 5
workerTimeout = 500 * time.Millisecond
)
tests := []workerPoolTest{
{
name: "write pool",
newPool: func() interface{} {
bp := pool.NewWriteBuffer(
gcInterval, expiryInterval,
)
return pool.NewWrite(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
{
name: "read pool",
newPool: func() interface{} {
bp := pool.NewReadBuffer(
gcInterval, expiryInterval,
)
return pool.NewRead(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
}
for _, test := range tests {
testWorkerPool(t, test)
}
}
// testWorkerPool runs the three submission scenarios (non blocking, blocking
// and partial blocking) against the pool described by test. Every scenario is
// a parallel subtest that creates and starts its own pool, and registers a
// cleanup that stops it again.
func testWorkerPool(t *testing.T, test workerPoolTest) {
	scenarios := []struct {
		suffix string
		run    func(t *testing.T, p interface{}, nWorkers int)
	}{
		{
			suffix: " non blocking",
			run:    submitNonblockingGeneric,
		},
		{
			suffix: " blocking",
			run:    submitBlockingGeneric,
		},
		{
			suffix: " partial blocking",
			run:    submitPartialBlockingGeneric,
		},
	}

	for _, scenario := range scenarios {
		// Pin the loop variable, the subtest body runs after the loop
		// has moved on because of t.Parallel.
		scenario := scenario

		t.Run(test.name+scenario.suffix, func(t *testing.T) {
			t.Parallel()

			workerPool := test.newPool()
			startGeneric(t, workerPool)
			t.Cleanup(func() {
				stopGeneric(t, workerPool)
			})

			scenario.run(t, workerPool, test.numWorkers)
		})
	}
}
// submitNonblockingGeneric queues a batch of tasks, releases all of them at the
// same moment, and asserts that every one of them then finishes promptly and
// without error.
func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	// The batch holds twice as many tasks as there are workers.
	numTasks := 2 * nWorkers

	// Each submission runs in its own goroutine and reports its outcome on
	// results. All tasks wait on the same release channel.
	results := make(chan error)
	release := make(chan struct{})
	submit := func() {
		results <- submitGeneric(p, release)
	}
	for queued := 0; queued < numTasks; queued++ {
		go submit()
	}

	// Nothing has been released yet, so no task should have completed.
	pullNothing(t, results)

	// Closing the channel releases every task at once. We expect exactly
	// numTasks successful results and nothing after that.
	close(release)
	pullParallel(t, numTasks, results)
	pullNothing(t, results)
}
// submitBlockingGeneric queues a batch of blocking tasks and releases them one
// by one, asserting that each release yields exactly one completed task.
func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	// The batch holds twice as many tasks as there are workers.
	numTasks := 2 * nWorkers

	// Each submission runs in its own goroutine and reports its outcome on
	// results. A task stays parked until it receives a token on release.
	results := make(chan error)
	release := make(chan struct{})
	submit := func() {
		results <- submitGeneric(p, release)
	}
	for queued := 0; queued < numTasks; queued++ {
		go submit()
	}

	// No token has been handed out yet, so no task should have completed.
	pullNothing(t, results)

	// Hand out one token at a time and collect the matching result. Once
	// all tasks are accounted for, nothing further should arrive.
	pullSequential(t, numTasks, results, release)
	pullNothing(t, results)
}
// submitPartialBlockingGeneric asserts that while at least one worker is left
// unblocked, further tasks submitted to the pool can still run to completion.
func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	// The first nWorkers-1 tasks are held back until the very end. The
	// remaining tasks, which bring the total to 2*nWorkers, are released
	// together. Only after those have finished do we release the held
	// tasks, one at a time.
	numHeld := nWorkers - 1
	numFree := 2*nWorkers - numHeld

	// All submissions, held or not, report their outcome on results.
	results := make(chan error)

	// Start with the tasks that are going to be held back.
	heldSem := make(chan struct{})
	submitHeld := func() {
		results <- submitGeneric(p, heldSem)
	}
	for i := 0; i < numHeld; i++ {
		go submitHeld()
	}

	// All of them are parked, so no result should show up.
	pullNothing(t, results)

	// Next, queue the second batch on a semaphore of its own.
	freeSem := make(chan struct{})
	submitFree := func() {
		results <- submitGeneric(p, freeSem)
	}
	for i := 0; i < numFree; i++ {
		go submitFree()
	}

	// The second batch has not been released either, so still no result
	// should show up.
	pullNothing(t, results)

	// Release the second batch in one go and collect all of its results.
	// Afterwards the channel should be quiet again.
	close(freeSem)
	pullParallel(t, numFree, results)
	pullNothing(t, results)

	// Finally, release the held tasks one by one and make sure nothing
	// else arrives once they are done.
	pullSequential(t, numHeld, results, heldSem)
	pullNothing(t, results)
}
// pullNothing asserts that errChan stays silent for one second. Receiving any
// value within that window, nil or not, fails the test immediately.
func pullNothing(t *testing.T, errChan chan error) {
	t.Helper()

	quiet := time.NewTimer(time.Second)
	defer quiet.Stop()

	select {
	case <-quiet.C:
		// Nothing arrived within the window, which is what we want.

	case err := <-errChan:
		t.Fatalf("received unexpected error before semaphore "+
			"release: %v", err)
	}
}
// pullParallel receives n results from errChan, allowing up to one second for
// each of them. The test fails on the first non-nil result or on the first
// result that does not arrive in time.
func pullParallel(t *testing.T, n int, errChan chan error) {
	t.Helper()

	for task := 0; task < n; task++ {
		deadline := time.After(time.Second)

		select {
		case <-deadline:
			t.Fatalf("task %d was not processed in time", task)

		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		}
	}
}
// pullSequential releases n tasks one at a time. For each task it sends a
// token on semChan and then waits for the task's result on errChan. Both steps
// are given one second, and the test fails if either step times out or if the
// result is a non-nil error.
func pullSequential(
	t *testing.T, n int, errChan chan error, semChan chan struct{},
) {

	t.Helper()

	for task := 0; task < n; task++ {
		// Hand a single token to whichever task is waiting next.
		select {
		case <-time.After(time.Second):
			t.Fatalf("task %d was not unblocked", task)

		case semChan <- struct{}{}:
		}

		// Collect the result of the released task, which is expected
		// to be a nil error.
		select {
		case <-time.After(time.Second):
			t.Fatalf("task %d was not processed in time", task)

		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		}
	}
}
// startGeneric starts the worker pool p, which must be either a *pool.Write or
// a *pool.Read. The test fails if p has any other type or if Start returns an
// error.
func startGeneric(t *testing.T, p interface{}) {
	t.Helper()

	const failMsg = "unable to start worker pool"

	switch workerPool := p.(type) {
	case *pool.Write:
		require.NoError(t, workerPool.Start(), failMsg)

	case *pool.Read:
		require.NoError(t, workerPool.Start(), failMsg)

	default:
		t.Fatalf("unknown worker pool type: %T", p)
	}
}
// stopGeneric stops the worker pool p, which must be either a *pool.Write or a
// *pool.Read. The test fails if p has any other type or if Stop returns an
// error.
func stopGeneric(t *testing.T, p interface{}) {
	t.Helper()

	const failMsg = "unable to stop worker pool"

	switch workerPool := p.(type) {
	case *pool.Write:
		require.NoError(t, workerPool.Stop(), failMsg)

	case *pool.Read:
		require.NoError(t, workerPool.Stop(), failMsg)

	default:
		t.Fatalf("unknown worker pool type: %T", p)
	}
}
// submitGeneric submits a single task to the worker pool p and reports any
// failure. The pool must be a *pool.Write or a *pool.Read, any other type
// yields an "unknown worker pool type" error.
//
// The submitted task first checks that the buffer it was handed is clean, then
// dirties it with random bytes, and finally parks on sem until a value is sent
// on it or it is closed. A failed check is returned right away, without
// waiting on sem.
//
// A non-nil error from Submit is wrapped with %w, so callers can still inspect
// the underlying cause with errors.Is or errors.As.
func submitGeneric(p interface{}, sem <-chan struct{}) error {
	var err error
	switch pp := p.(type) {
	case *pool.Write:
		err = pp.Submit(func(buf *bytes.Buffer) error {
			// A buffer handed out by the pool must have been reset
			// to zero length.
			if buf.Len() != 0 {
				return fmt.Errorf("buf should be length zero, "+
					"instead has length %d", buf.Len())
			}

			// Its capacity must match buffer.WriteSize exactly.
			if buf.Cap() != buffer.WriteSize {
				return fmt.Errorf("buf should have capacity "+
					"%d, instead has capacity %d",
					buffer.WriteSize, buf.Cap())
			}

			// Sample a random amount of random bytes, strictly
			// fewer than the capacity, to dirty the buffer with.
			b := make([]byte, rand.Intn(buf.Cap()))
			_, err := io.ReadFull(crand.Reader, b)
			if err != nil {
				return err
			}

			// Write the random bytes to the buffer. The outcome is
			// only reported once the task has been released.
			_, err = buf.Write(b)

			// Wait until this task is signaled to exit.
			<-sem

			return err
		})

	case *pool.Read:
		err = pp.Submit(func(buf *buffer.Read) error {
			// Every byte of the array must be zero, showing that
			// the buffer was reset between uses.
			for i := range buf[:] {
				if buf[i] != 0x00 {
					return fmt.Errorf("byte %d of "+
						"buffer.Read should be "+
						"0, instead is %d", i, buf[i])
				}
			}

			// Fill the whole buffer with random bytes. The outcome
			// is only reported once the task has been released.
			_, err := io.ReadFull(crand.Reader, buf[:])

			// Wait until this task is signaled to exit.
			<-sem

			return err
		})

	default:
		return fmt.Errorf("unknown worker pool type: %T", p)
	}

	if err != nil {
		// Wrap rather than flatten, so the error chain is preserved.
		return fmt.Errorf("unable to submit task: %w", err)
	}

	return nil
}