package discovery import ( "bytes" "math/rand" "reflect" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" ) func createTestMessageStore(t *testing.T) *MessageStore { t.Helper() db, err := channeldb.Open(t.TempDir()) if err != nil { t.Fatalf("unable to open db: %v", err) } t.Cleanup(func() { db.Close() }) store, err := NewMessageStore(db) if err != nil { t.Fatalf("unable to initialize message store: %v", err) } return store } func randPubKey(t *testing.T) *btcec.PublicKey { priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") return priv.PubKey() } func randCompressedPubKey(t *testing.T) [33]byte { t.Helper() pubKey := randPubKey(t) var compressedPubKey [33]byte copy(compressedPubKey[:], pubKey.SerializeCompressed()) return compressedPubKey } func randAnnounceSignatures() *lnwire.AnnounceSignatures1 { return &lnwire.AnnounceSignatures1{ ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), ExtraOpaqueData: make([]byte, 0), } } func randChannelUpdate() *lnwire.ChannelUpdate1 { return &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), ExtraOpaqueData: make([]byte, 0), } } // TestMessageStoreMessages ensures that messages can be properly queried from // the store. func TestMessageStoreMessages(t *testing.T) { t.Parallel() // We'll start by creating our test message store. msgStore := createTestMessageStore(t) // We'll then create some test messages for two test peers, and none for // an additional test peer. channelUpdate1 := randChannelUpdate() announceSignatures1 := randAnnounceSignatures() peer1 := randCompressedPubKey(t) if err := msgStore.AddMessage(channelUpdate1, peer1); err != nil { t.Fatalf("unable to add message: %v", err) } if err := msgStore.AddMessage(announceSignatures1, peer1); err != nil { t.Fatalf("unable to add message: %v", err) } expectedPeerMsgs1 := map[uint64]lnwire.MessageType{ channelUpdate1.ShortChannelID.ToUint64(): channelUpdate1.MsgType(), announceSignatures1.ShortChannelID.ToUint64(): announceSignatures1.MsgType(), } channelUpdate2 := randChannelUpdate() peer2 := randCompressedPubKey(t) if err := msgStore.AddMessage(channelUpdate2, peer2); err != nil { t.Fatalf("unable to add message: %v", err) } expectedPeerMsgs2 := map[uint64]lnwire.MessageType{ channelUpdate2.ShortChannelID.ToUint64(): channelUpdate2.MsgType(), } peer3 := randCompressedPubKey(t) expectedPeerMsgs3 := map[uint64]lnwire.MessageType{} // assertPeerMsgs is a helper closure that we'll use to ensure we // retrieve the correct set of messages for a given peer. assertPeerMsgs := func(peerMsgs []lnwire.Message, expected map[uint64]lnwire.MessageType) { t.Helper() if len(peerMsgs) != len(expected) { t.Fatalf("expected %d pending messages, got %d", len(expected), len(peerMsgs)) } for _, msg := range peerMsgs { var shortChanID uint64 switch msg := msg.(type) { case *lnwire.AnnounceSignatures1: shortChanID = msg.ShortChannelID.ToUint64() case *lnwire.ChannelUpdate1: shortChanID = msg.ShortChannelID.ToUint64() default: t.Fatalf("found unexpected message type %T", msg) } msgType, ok := expected[shortChanID] if !ok { t.Fatalf("retrieved message with unexpected ID "+ "%d from store", shortChanID) } if msgType != msg.MsgType() { t.Fatalf("expected message of type %v, got %v", msg.MsgType(), msgType) } } } // Then, we'll query the store for the set of messages for each peer and // ensure it matches what we expect. peers := [][33]byte{peer1, peer2, peer3} expectedPeerMsgs := []map[uint64]lnwire.MessageType{ expectedPeerMsgs1, expectedPeerMsgs2, expectedPeerMsgs3, } for i, peer := range peers { peerMsgs, err := msgStore.MessagesForPeer(peer) if err != nil { t.Fatalf("unable to retrieve messages: %v", err) } assertPeerMsgs(peerMsgs, expectedPeerMsgs[i]) } // Finally, we'll query the store for all of its messages of every peer. // Again, each peer should have a set of messages that match what we // expect. // // We'll construct the expected response. Only the first two peers will // have messages. totalPeerMsgs := make(map[[33]byte]map[uint64]lnwire.MessageType, 2) for i := 0; i < 2; i++ { totalPeerMsgs[peers[i]] = expectedPeerMsgs[i] } msgs, err := msgStore.Messages() if err != nil { t.Fatalf("unable to retrieve all peers with pending messages: "+ "%v", err) } if len(msgs) != len(totalPeerMsgs) { t.Fatalf("expected %d peers with messages, got %d", len(totalPeerMsgs), len(msgs)) } for peer, peerMsgs := range msgs { expected, ok := totalPeerMsgs[peer] if !ok { t.Fatalf("expected to find pending messages for peer %x", peer) } assertPeerMsgs(peerMsgs, expected) } peerPubKeys, err := msgStore.Peers() if err != nil { t.Fatalf("unable to retrieve all peers with pending messages: "+ "%v", err) } if len(peerPubKeys) != len(totalPeerMsgs) { t.Fatalf("expected %d peers with messages, got %d", len(totalPeerMsgs), len(peerPubKeys)) } for peerPubKey := range peerPubKeys { if _, ok := totalPeerMsgs[peerPubKey]; !ok { t.Fatalf("expected to find peer %x", peerPubKey) } } } // TestMessageStoreUnsupportedMessage ensures that we are not able to add a // message which is unsupported, and if a message is found to be unsupported by // the current version of the store, that it is properly filtered out from the // response. func TestMessageStoreUnsupportedMessage(t *testing.T) { t.Parallel() // We'll start by creating our test message store. msgStore := createTestMessageStore(t) // Create a message that is known to not be supported by the store. peer := randCompressedPubKey(t) unsupportedMsg := &lnwire.Error{} // Attempting to add it to the store should result in // ErrUnsupportedMessage. err := msgStore.AddMessage(unsupportedMsg, peer) if err != ErrUnsupportedMessage { t.Fatalf("expected ErrUnsupportedMessage, got %v", err) } // We'll now pretend that the message is actually supported in a future // version of the store, so it's able to be added successfully. To // replicate this, we'll add the message manually rather than through // the existing AddMessage method. msgKey := peer[:] var rawMsg bytes.Buffer if _, err := lnwire.WriteMessage(&rawMsg, unsupportedMsg, 0); err != nil { t.Fatalf("unable to serialize message: %v", err) } err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) return messageStore.Put(msgKey, rawMsg.Bytes()) }, func() {}) require.NoError(t, err, "unable to add unsupported message to store") // Finally, we'll check that the store can properly filter out messages // that are currently unknown to it. We'll make sure this is done for // both Messages and MessagesForPeer. totalMsgs, err := msgStore.Messages() require.NoError(t, err, "unable to retrieve messages") if len(totalMsgs) != 0 { t.Fatalf("expected to filter out unsupported message") } peerMsgs, err := msgStore.MessagesForPeer(peer) require.NoError(t, err, "unable to retrieve peer messages") if len(peerMsgs) != 0 { t.Fatalf("expected to filter out unsupported message") } } // TestMessageStoreDeleteMessage ensures that we can properly delete messages // from the store. func TestMessageStoreDeleteMessage(t *testing.T) { t.Parallel() msgStore := createTestMessageStore(t) // assertMsg is a helper closure we'll use to ensure a message // does/doesn't exist within the store. assertMsg := func(msg lnwire.Message, peer [33]byte, exists bool) { t.Helper() storeMsgs, err := msgStore.MessagesForPeer(peer) if err != nil { t.Fatalf("unable to retrieve messages: %v", err) } found := false for _, storeMsg := range storeMsgs { if reflect.DeepEqual(msg, storeMsg) { found = true } } if found != exists { str := "find" if !exists { str = "not find" } t.Fatalf("expected to %v message %v", str, spew.Sdump(msg)) } } // An AnnounceSignatures message should exist within the store after // adding it, and should no longer exists after deleting it. peer := randCompressedPubKey(t) annSig := randAnnounceSignatures() if err := msgStore.AddMessage(annSig, peer); err != nil { t.Fatalf("unable to add message: %v", err) } assertMsg(annSig, peer, true) if err := msgStore.DeleteMessage(annSig, peer); err != nil { t.Fatalf("unable to delete message: %v", err) } assertMsg(annSig, peer, false) // The store allows overwriting ChannelUpdates, since there can be // multiple versions, so we'll test things slightly different. // // The ChannelUpdate message should exist within the store after adding // it. chanUpdate := randChannelUpdate() if err := msgStore.AddMessage(chanUpdate, peer); err != nil { t.Fatalf("unable to add message: %v", err) } assertMsg(chanUpdate, peer, true) // Now, we'll create a new version for the same ChannelUpdate message. // Adding this one to the store will overwrite the previous one, so only // the new one should exist. newChanUpdate := randChannelUpdate() newChanUpdate.ShortChannelID = chanUpdate.ShortChannelID newChanUpdate.Timestamp = chanUpdate.Timestamp + 1 if err := msgStore.AddMessage(newChanUpdate, peer); err != nil { t.Fatalf("unable to add message: %v", err) } assertMsg(chanUpdate, peer, false) assertMsg(newChanUpdate, peer, true) // Deleting the older message should act as a NOP and should NOT delete // the newer version as the older no longer exists. if err := msgStore.DeleteMessage(chanUpdate, peer); err != nil { t.Fatalf("unable to delete message: %v", err) } assertMsg(chanUpdate, peer, false) assertMsg(newChanUpdate, peer, true) // The newer version should no longer exist within the store after // deleting it. if err := msgStore.DeleteMessage(newChanUpdate, peer); err != nil { t.Fatalf("unable to delete message: %v", err) } assertMsg(newChanUpdate, peer, false) }