lnd/watchtower/wtwire/init_test.go

108 lines
3.4 KiB
Go
Raw Normal View History

package wtwire_test
import (
"testing"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
var (
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
mainnetChainHash = *chaincfg.MainNetParams.GenesisHash
)
type checkRemoteInitTest struct {
name string
lFeatures *lnwire.RawFeatureVector
lHash chainhash.Hash
rFeatures *lnwire.RawFeatureVector
rHash chainhash.Hash
expErr error
}
var checkRemoteInitTests = []checkRemoteInitTest{
{
name: "same chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: testnetChainHash,
},
{
name: "same chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
rHash: testnetChainHash,
},
{
name: "different chain, local-optional remote-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
},
{
name: "different chain, local-required remote-optional",
lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
rHash: mainnetChainHash,
expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash),
},
{
name: "same chain, remote-unknown-required",
lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
rHash: testnetChainHash,
expErr: feature.NewErrUnknownRequired(
[]lnwire.FeatureBit{lnwire.GossipQueriesRequired},
),
},
}
// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with
// the remote party's Init message and the default wtwire.Features. We assert
// the validity of advertised features from the perspective of both client and
// server, as well as failure cases such as differing chain hashes or unknown
// required features.
func TestCheckRemoteInit(t *testing.T) {
for _, test := range checkRemoteInitTests {
t.Run(test.name, func(t *testing.T) {
testCheckRemoteInit(t, test)
})
}
}
func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) {
localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash)
remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash)
err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
switch {
// Both non-nil, pass.
case err == nil && test.expErr == nil:
return
// One is nil and one is non-nil, fail.
default:
t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err)
// Both non-nil, assert same error type.
case err != nil && test.expErr != nil:
}
// Compare error strings to assert same type.
if err.Error() != test.expErr.Error() {
t.Fatalf("error mismatch, want: %v, got: %v",
test.expErr.Error(), err.Error())
}
}