diff --git a/routing/validation_barrier_test.go b/routing/validation_barrier_test.go new file mode 100644 index 000000000..6fd3930b3 --- /dev/null +++ b/routing/validation_barrier_test.go @@ -0,0 +1,151 @@ +package routing_test + +import ( + "encoding/binary" + "testing" + "time" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing" +) + +// TestValidationBarrierSemaphore checks basic properties of the validation +// barrier's semaphore wrt. enqueuing/dequeuing. +func TestValidationBarrierSemaphore(t *testing.T) { + const ( + numTasks = 8 + numPendingTasks = 8 + timeout = 50 * time.Millisecond + ) + + quit := make(chan struct{}) + barrier := routing.NewValidationBarrier(numTasks, quit) + + // Saturate the semaphore with jobs. + for i := 0; i < numTasks; i++ { + barrier.InitJobDependencies(nil) + } + + // Spawn additional tasks that will signal completion when added. + jobAdded := make(chan struct{}) + for i := 0; i < numPendingTasks; i++ { + go func() { + barrier.InitJobDependencies(nil) + jobAdded <- struct{}{} + }() + } + + // Check that no jobs are added while semaphore is full. + select { + case <-time.After(timeout): + // Expected since no slots open. + case <-jobAdded: + t.Fatalf("job should not have been added") + } + + // Complete jobs one at a time and verify that they get added. + for i := 0; i < numPendingTasks; i++ { + barrier.CompleteJob() + + select { + case <-time.After(timeout): + t.Fatalf("timeout waiting for job to be added") + case <-jobAdded: + // Expected since one slot opened up. + } + } +} + +// TestValidationBarrierQuit checks that pending validation tasks will return an +// error from WaitForDependants if the barrier's quit signal is canceled. +func TestValidationBarrierQuit(t *testing.T) { + const ( + numTasks = 8 + timeout = 50 * time.Millisecond + ) + + quit := make(chan struct{}) + barrier := routing.NewValidationBarrier(2*numTasks, quit) + + // Create a set of unique channel announcements that we will prep for + // validation. + anns := make([]*lnwire.ChannelAnnouncement, 0, numTasks) + for i := 0; i < numTasks; i++ { + anns = append(anns, &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), + NodeID1: nodeIDFromInt(uint64(2 * i)), + NodeID2: nodeIDFromInt(uint64(2*i + 1)), + }) + barrier.InitJobDependencies(anns[i]) + } + + // Create a set of channel updates, that must wait until their + // associated channel announcement has been verified. + chanUpds := make([]*lnwire.ChannelUpdate, 0, numTasks) + for i := 0; i < numTasks; i++ { + chanUpds = append(chanUpds, &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), + }) + barrier.InitJobDependencies(chanUpds[i]) + } + + // Spawn additional tasks that will send the error returned after + // waiting for the announcements to finish. In the background, we will + // iteratively queue the channel updates, which will send back the error + // returned from waiting. + jobErrs := make(chan error) + for i := 0; i < numTasks; i++ { + go func(ii int) { + jobErrs <- barrier.WaitForDependants(chanUpds[ii]) + }(i) + } + + // Check that no jobs are added while semaphore is full. + select { + case <-time.After(timeout): + // Expected since no slots open. + case <-jobErrs: + t.Fatalf("job should not have been signaled") + } + + // Complete the first half of jobs, one at a time, verifying that they + // get signaled. Then, quit the barrier and check that all others exit + // with the correct error. + for i := 0; i < numTasks; i++ { + switch { + // First half, signal completion and task semaphore + case i < numTasks/2: + barrier.SignalDependants(anns[i]) + barrier.CompleteJob() + + // At midpoint, quit the validation barrier. + case i == numTasks/2: + close(quit) + } + + var err error + select { + case <-time.After(timeout): + t.Fatalf("timeout waiting for job to be signaled") + case err = <-jobErrs: + } + + switch { + // First half should return without failure. + case i < numTasks/2 && err != nil: + t.Fatalf("unexpected failure while waiting: %v", err) + + // Last half should return the shutdown error. + case i >= numTasks/2 && err != routing.ErrVBarrierShuttingDown: + t.Fatalf("expected failure after quitting: want %v, "+ + "got %v", routing.ErrVBarrierShuttingDown, err) + } + } +} + +// nodeIDFromInt creates a node ID by writing a uint64 to the first 8 bytes. +func nodeIDFromInt(i uint64) [33]byte { + var nodeID [33]byte + binary.BigEndian.PutUint64(nodeID[:8], i) + return nodeID +}