mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
brontide/noise_test: add TestFlush
This commit is contained in:
parent
333caac09c
commit
e3728da478
@ -31,6 +31,10 @@ const (
|
||||
// length of a message payload.
|
||||
lengthHeaderSize = 2
|
||||
|
||||
// encHeaderSize is the number of bytes required to hold an encrypted
|
||||
// header and it's MAC.
|
||||
encHeaderSize = lengthHeaderSize + macSize
|
||||
|
||||
// keyRotationInterval is the number of messages sent on a single
|
||||
// cipher stream before the keys are rotated forwards.
|
||||
keyRotationInterval = 1000
|
||||
@ -370,7 +374,7 @@ type Machine struct {
|
||||
// nextCipherHeader is a static buffer that we'll use to read in the
|
||||
// next ciphertext header from the wire. The header is a 2 byte length
|
||||
// (of the next ciphertext), followed by a 16 byte MAC.
|
||||
nextCipherHeader [lengthHeaderSize + macSize]byte
|
||||
nextCipherHeader [encHeaderSize]byte
|
||||
|
||||
// nextHeaderSend holds a reference to the remaining header bytes to
|
||||
// write out for a pending message. This allows us to tolerate timeout
|
||||
|
@ -561,3 +561,152 @@ func (t *timeoutWriter) Write(p []byte) (int, error) {
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
const payloadSize = 10
|
||||
|
||||
type flushChunk struct {
|
||||
errAfter int64
|
||||
expN int
|
||||
expErr error
|
||||
}
|
||||
|
||||
type flushTest struct {
|
||||
name string
|
||||
chunks []flushChunk
|
||||
}
|
||||
|
||||
var flushTests = []flushTest{
|
||||
{
|
||||
name: "partial header write",
|
||||
chunks: []flushChunk{
|
||||
// Write 18-byte header in two parts, 16 then 2.
|
||||
{
|
||||
errAfter: encHeaderSize - 2,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
{
|
||||
errAfter: 2,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write payload and MAC in one go.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: payloadSize,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full payload then full mac",
|
||||
chunks: []flushChunk{
|
||||
// Write entire header and entire payload w/o MAC.
|
||||
{
|
||||
errAfter: encHeaderSize + payloadSize,
|
||||
expN: payloadSize,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write the entire MAC.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload-only, straddle, mac-only",
|
||||
chunks: []flushChunk{
|
||||
// Write header and all but last byte of payload.
|
||||
{
|
||||
errAfter: encHeaderSize + payloadSize - 1,
|
||||
expN: payloadSize - 1,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write last byte of payload and first byte of MAC.
|
||||
{
|
||||
errAfter: 2,
|
||||
expN: 1,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write 10 bytes of the MAC.
|
||||
{
|
||||
errAfter: 10,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write the remaining 5 MAC bytes.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestFlush asserts a Machine's ability to handle timeouts during Flush that
|
||||
// cause partial writes, and that the machine can properly resume writes on
|
||||
// subsequent calls to Flush.
|
||||
func TestFlush(t *testing.T) {
|
||||
// Run each test individually, to assert that they pass in isolation.
|
||||
for _, test := range flushTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var (
|
||||
w bytes.Buffer
|
||||
b Machine
|
||||
)
|
||||
b.split()
|
||||
testFlush(t, test, &b, &w)
|
||||
})
|
||||
}
|
||||
|
||||
// Finally, run the tests serially as if all on one connection.
|
||||
t.Run("flush serial", func(t *testing.T) {
|
||||
var (
|
||||
w bytes.Buffer
|
||||
b Machine
|
||||
)
|
||||
b.split()
|
||||
for _, test := range flushTests {
|
||||
testFlush(t, test, &b, &w)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testFlush buffers a message on the Machine, then flushes it to the io.Writer
|
||||
// in chunks. Once complete, a final call to flush is made to assert that Write
|
||||
// is not called again.
|
||||
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
|
||||
payload := make([]byte, payloadSize)
|
||||
if err := b.WriteMessage(payload); err != nil {
|
||||
t.Fatalf("unable to write message: %v", err)
|
||||
}
|
||||
|
||||
for _, chunk := range test.chunks {
|
||||
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
|
||||
}
|
||||
|
||||
// We should always be able to call Flush after a message has been
|
||||
// successfully written, and it should result in a NOP.
|
||||
assertFlush(t, b, w, 0, 0, nil)
|
||||
}
|
||||
|
||||
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
|
||||
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
|
||||
// n bytes. The method asserts that the returned error matches expErr and that
|
||||
// the number of bytes written by Flush matches expN.
|
||||
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
|
||||
expErr error) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
if n >= 0 {
|
||||
w = NewTimeoutWriter(w, n)
|
||||
}
|
||||
nn, err := b.Flush(w)
|
||||
if err != expErr {
|
||||
t.Fatalf("expected flush err: %v, got: %v", expErr, err)
|
||||
}
|
||||
if nn != expN {
|
||||
t.Fatalf("expected n: %d, got: %d", expN, nn)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user