From 0410ea7374fef6da238cd94a31002cbd42e793f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christopher=20J=C3=A4mthagen?= Date: Mon, 9 Jan 2017 19:03:19 +0100 Subject: [PATCH] test: Add table driven tests for script_utils Add table-driven tests for testing GetStateHint and SetStateHint in package lnwallet. --- lnwallet/script_utils_test.go | 91 ++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/lnwallet/script_utils_test.go b/lnwallet/script_utils_test.go index 62f2da7eb..f4890cdc6 100644 --- a/lnwallet/script_utils_test.go +++ b/lnwallet/script_utils_test.go @@ -563,47 +563,72 @@ func TestHTLCReceiverSpendValidation(t *testing.T) { } } +var stateHintTests = []struct { + name string + from uint64 + to uint64 + inputs int + shouldFail bool +}{ + { + name: "states 0 to 1000", + from: 0, + to: 1000, + inputs: 1, + shouldFail: false, + }, + { + name: "states 'maxStateHint-1000' to 'maxStateHint'", + from: maxStateHint - 1000, + to: maxStateHint, + inputs: 1, + shouldFail: false, + }, + { + name: "state 'maxStateHint+1'", + from: maxStateHint + 1, + to: maxStateHint + 10, + inputs: 1, + shouldFail: true, + }, + { + name: "commit transaction with two inputs", + inputs: 2, + shouldFail: true, + }, +} + func TestCommitTxStateHint(t *testing.T) { - commitTx := wire.NewMsgTx(2) - commitTx.AddTxIn(&wire.TxIn{}) var obsfucator [StateHintSize]byte copy(obsfucator[:], testHdSeed[:StateHintSize]) - for i := 0; i < 10000; i++ { - stateNum := uint64(i) + for _, test := range stateHintTests { + commitTx := wire.NewMsgTx(2) - err := SetStateNumHint(commitTx, stateNum, obsfucator) - if err != nil { - t.Fatalf("unable to set state num %v: %v", i, err) + // Add supplied number of inputs to the commitment transaction. + for i := 0; i < test.inputs; i++ { + commitTx.AddTxIn(&wire.TxIn{}) } - extractedStateNum := GetStateNumHint(commitTx, obsfucator) - if extractedStateNum != stateNum { - t.Fatalf("state number mismatched, expected %v, got %v", - stateNum, extractedStateNum) + for i := test.from; i <= test.to; i++ { + stateNum := uint64(i) + + err := SetStateNumHint(commitTx, stateNum, obsfucator) + if err != nil && !test.shouldFail { + t.Fatalf("unable to set state num %v: %v", i, err) + } else if err == nil && test.shouldFail { + t.Fatalf("Failed(%v): test should fail but did not", test.name) + } + + extractedStateNum := GetStateNumHint(commitTx, obsfucator) + if extractedStateNum != stateNum && !test.shouldFail { + t.Fatalf("state number mismatched, expected %v, got %v", + stateNum, extractedStateNum) + } else if extractedStateNum == stateNum && test.shouldFail { + t.Fatalf("Failed(%v): test should fail but did not", test.name) + } } - - //Test from maximum allowed state - stateNum = uint64(maxStateHint - i) - err = SetStateNumHint(commitTx, stateNum, obsfucator) - if err != nil { - t.Fatalf("unable to set state num %v: %v", i, err) - } - - extractedStateNum = GetStateNumHint(commitTx, obsfucator) - if extractedStateNum != stateNum { - t.Fatalf("state number mismatched, expected %v, got %v", - stateNum, extractedStateNum) - } - } - - if err := SetStateNumHint(commitTx, maxStateHint+1, obsfucator); err == nil { - t.Fatalf("state number should not exceed %X", maxStateHint) - } - - commitTx.AddTxIn(&wire.TxIn{}) - if err := SetStateNumHint(commitTx, 0, obsfucator); err == nil { - t.Fatalf("more than one input in commit transaction should not be valid") + t.Logf("Passed: %v", test.name) } }