diff --git a/discovery/message_store.go b/discovery/message_store.go new file mode 100644 index 000000000..e0c10a865 --- /dev/null +++ b/discovery/message_store.go @@ -0,0 +1,294 @@ +package discovery + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // messageStoreBucket is a key used to create a top level bucket in the + // gossiper database, used for storing messages that are to be sent to + // peers. Upon restarts, these messages will be read and resent to their + // respective peers. + // + // maps: + // pubKey (33 bytes) + msgShortChanID (8 bytes) + msgType (2 bytes) -> msg + messageStoreBucket = []byte("message-store") + + // ErrUnsupportedMessage is an error returned when we attempt to add a + // message to the store that is not supported. + ErrUnsupportedMessage = errors.New("unsupported message type") + + // ErrCorruptedMessageStore indicates that the on-disk bucketing + // structure has altered since the gossip message store instance was + // initialized. + ErrCorruptedMessageStore = errors.New("gossip message store has been " + + "corrupted") +) + +// GossipMessageStore is a store responsible for storing gossip messages which +// we should reliably send to our peers. +type GossipMessageStore interface { + // AddMessage adds a message to the store for this peer. + AddMessage(lnwire.Message, [33]byte) error + + // DeleteMessage deletes a message from the store for this peer. + DeleteMessage(lnwire.Message, [33]byte) error + + // Messages returns the total set of messages that exist within the + // store for all peers. + Messages() (map[[33]byte][]lnwire.Message, error) + + // Peers returns the public key of all peers with messages within the + // store. + Peers() (map[[33]byte]struct{}, error) + + // MessagesForPeer returns the set of messages that exists within the + // store for the given peer. + MessagesForPeer([33]byte) ([]lnwire.Message, error) +} + +// MessageStore is an implementation of the GossipMessageStore interface backed +// by a channeldb instance. By design, this store will only keep the latest +// version of a message (like in the case of multiple ChannelUpdate's) for a +// channel with a peer. +type MessageStore struct { + db *channeldb.DB +} + +// A compile-time assertion to ensure messageStore implements the +// GossipMessageStore interface. +var _ GossipMessageStore = (*MessageStore)(nil) + +// NewMessageStore creates a new message store backed by a channeldb instance. +func NewMessageStore(db *channeldb.DB) (*MessageStore, error) { + err := db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(messageStoreBucket) + return err + }) + if err != nil { + return nil, fmt.Errorf("unable to create required buckets: %v", + err) + } + + return &MessageStore{db}, nil +} + +// msgShortChanID retrieves the short channel ID of the message. +func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) { + var shortChanID lnwire.ShortChannelID + switch msg := msg.(type) { + case *lnwire.AnnounceSignatures: + shortChanID = msg.ShortChannelID + case *lnwire.ChannelUpdate: + shortChanID = msg.ShortChannelID + default: + return shortChanID, ErrUnsupportedMessage + } + + return shortChanID, nil +} + +// messageStoreKey constructs the database key for the message to be stored. +func messageStoreKey(msg lnwire.Message, peerPubKey [33]byte) ([]byte, error) { + shortChanID, err := msgShortChanID(msg) + if err != nil { + return nil, err + } + + var k [33 + 8 + 2]byte + copy(k[:33], peerPubKey[:]) + binary.BigEndian.PutUint64(k[33:41], shortChanID.ToUint64()) + binary.BigEndian.PutUint16(k[41:43], uint16(msg.MsgType())) + + return k[:], nil +} + +// AddMessage adds a message to the store for this peer. +func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error { + // Construct the key for which we'll find this message with in the store. + msgKey, err := messageStoreKey(msg, peerPubKey) + if err != nil { + return err + } + + // Serialize the message with its wire encoding. + var b bytes.Buffer + if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + return err + } + + return s.db.Batch(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return ErrCorruptedMessageStore + } + + return messageStore.Put(msgKey, b.Bytes()) + }) +} + +// DeleteMessage deletes a message from the store for this peer. +func (s *MessageStore) DeleteMessage(msg lnwire.Message, + peerPubKey [33]byte) error { + + // Construct the key for which we'll find this message with in the + // store. + msgKey, err := messageStoreKey(msg, peerPubKey) + if err != nil { + return err + } + + return s.db.Batch(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return ErrCorruptedMessageStore + } + + // In the event that we're attempting to delete a ChannelUpdate + // from the store, we'll make sure that we're actually deleting + // the correct one as it can be overwritten. + if msg, ok := msg.(*lnwire.ChannelUpdate); ok { + // Deleting a value from a bucket that doesn't exist + // acts as a NOP, so we'll return if a message doesn't + // exist under this key. + v := messageStore.Get(msgKey) + if v == nil { + return nil + } + + dbMsg, err := lnwire.ReadMessage(bytes.NewReader(v), 0) + if err != nil { + return err + } + + // If the timestamps don't match, then the update stored + // should be the latest one, so we'll avoid deleting it. + if msg.Timestamp != dbMsg.(*lnwire.ChannelUpdate).Timestamp { + return nil + } + } + + return messageStore.Delete(msgKey) + }) +} + +// readMessage reads a message from its serialized form and ensures its +// supported by the current version of the message store. +func readMessage(msgBytes []byte) (lnwire.Message, error) { + msg, err := lnwire.ReadMessage(bytes.NewReader(msgBytes), 0) + if err != nil { + return nil, err + } + + // Check if the message is supported by the store. We can reuse the + // check for ShortChannelID as its a dependency on messages stored. + if _, err := msgShortChanID(msg); err != nil { + return nil, err + } + + return msg, nil +} + +// Messages returns the total set of messages that exist within the store for +// all peers. +func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { + msgs := make(map[[33]byte][]lnwire.Message) + err := s.db.View(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return ErrCorruptedMessageStore + } + + return messageStore.ForEach(func(k, v []byte) error { + var pubKey [33]byte + copy(pubKey[:], k[:33]) + + // Deserialize the message from its raw bytes and filter + // out any which are not currently supported by the + // store. + msg, err := readMessage(v) + if err == ErrUnsupportedMessage { + return nil + } + if err != nil { + return err + } + + msgs[pubKey] = append(msgs[pubKey], msg) + return nil + }) + }) + if err != nil { + return nil, err + } + + return msgs, nil +} + +// MessagesForPeer returns the set of messages that exists within the store for +// the given peer. +func (s *MessageStore) MessagesForPeer( + peerPubKey [33]byte) ([]lnwire.Message, error) { + + var msgs []lnwire.Message + err := s.db.View(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return ErrCorruptedMessageStore + } + + c := messageStore.Cursor() + k, v := c.Seek(peerPubKey[:]) + for ; bytes.HasPrefix(k, peerPubKey[:]); k, v = c.Next() { + // Deserialize the message from its raw bytes and filter + // out any which are not currently supported by the + // store. + msg, err := readMessage(v) + if err == ErrUnsupportedMessage { + continue + } + if err != nil { + return err + } + + msgs = append(msgs, msg) + } + + return nil + }) + if err != nil { + return nil, err + } + + return msgs, nil +} + +// Peers returns the public key of all peers with messages within the store. +func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { + peers := make(map[[33]byte]struct{}) + err := s.db.View(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return ErrCorruptedMessageStore + } + + return messageStore.ForEach(func(k, _ []byte) error { + var pubKey [33]byte + copy(pubKey[:], k[:33]) + peers[pubKey] = struct{}{} + return nil + }) + }) + if err != nil { + return nil, err + } + + return peers, nil +} diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go new file mode 100644 index 000000000..a106ad225 --- /dev/null +++ b/discovery/message_store_test.go @@ -0,0 +1,351 @@ +package discovery + +import ( + "bytes" + "io/ioutil" + "math/rand" + "os" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +func createTestMessageStore(t *testing.T) (*MessageStore, func()) { + t.Helper() + + tempDir, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + db, err := channeldb.Open(tempDir) + if err != nil { + os.RemoveAll(tempDir) + t.Fatalf("unable to open db: %v", err) + } + + cleanUp := func() { + db.Close() + os.RemoveAll(tempDir) + } + + store, err := NewMessageStore(db) + if err != nil { + cleanUp() + t.Fatalf("unable to initialize message store: %v", err) + } + + return store, cleanUp +} + +func randPubKey(t *testing.T) *btcec.PublicKey { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to create private key: %v", err) + } + + return priv.PubKey() +} + +func randCompressedPubKey(t *testing.T) [33]byte { + t.Helper() + + pubKey := randPubKey(t) + + var compressedPubKey [33]byte + copy(compressedPubKey[:], pubKey.SerializeCompressed()) + + return compressedPubKey +} + +func randAnnounceSignatures() *lnwire.AnnounceSignatures { + return &lnwire.AnnounceSignatures{ + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + } +} + +func randChannelUpdate() *lnwire.ChannelUpdate { + return &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + } +} + +// TestMessageStoreMessages ensures that messages can be properly queried from +// the store. +func TestMessageStoreMessages(t *testing.T) { + t.Parallel() + + // We'll start by creating our test message store. + msgStore, cleanUp := createTestMessageStore(t) + defer cleanUp() + + // We'll then create some test messages for two test peers, and none for + // an additional test peer. + channelUpdate1 := randChannelUpdate() + announceSignatures1 := randAnnounceSignatures() + peer1 := randCompressedPubKey(t) + if err := msgStore.AddMessage(channelUpdate1, peer1); err != nil { + t.Fatalf("unable to add message: %v", err) + } + if err := msgStore.AddMessage(announceSignatures1, peer1); err != nil { + t.Fatalf("unable to add message: %v", err) + } + expectedPeerMsgs1 := map[uint64]lnwire.MessageType{ + channelUpdate1.ShortChannelID.ToUint64(): channelUpdate1.MsgType(), + announceSignatures1.ShortChannelID.ToUint64(): announceSignatures1.MsgType(), + } + + channelUpdate2 := randChannelUpdate() + peer2 := randCompressedPubKey(t) + if err := msgStore.AddMessage(channelUpdate2, peer2); err != nil { + t.Fatalf("unable to add message: %v", err) + } + expectedPeerMsgs2 := map[uint64]lnwire.MessageType{ + channelUpdate2.ShortChannelID.ToUint64(): channelUpdate2.MsgType(), + } + + peer3 := randCompressedPubKey(t) + expectedPeerMsgs3 := map[uint64]lnwire.MessageType{} + + // assertPeerMsgs is a helper closure that we'll use to ensure we + // retrieve the correct set of messages for a given peer. + assertPeerMsgs := func(peerMsgs []lnwire.Message, + expected map[uint64]lnwire.MessageType) { + + t.Helper() + + if len(peerMsgs) != len(expected) { + t.Fatalf("expected %d pending messages, got %d", + len(expected), len(peerMsgs)) + } + for _, msg := range peerMsgs { + var shortChanID uint64 + switch msg := msg.(type) { + case *lnwire.AnnounceSignatures: + shortChanID = msg.ShortChannelID.ToUint64() + case *lnwire.ChannelUpdate: + shortChanID = msg.ShortChannelID.ToUint64() + default: + t.Fatalf("found unexpected message type %T", msg) + } + + msgType, ok := expected[shortChanID] + if !ok { + t.Fatalf("retrieved message with unexpected ID "+ + "%d from store", shortChanID) + } + if msgType != msg.MsgType() { + t.Fatalf("expected message of type %v, got %v", + msg.MsgType(), msgType) + } + } + } + + // Then, we'll query the store for the set of messages for each peer and + // ensure it matches what we expect. + peers := [][33]byte{peer1, peer2, peer3} + expectedPeerMsgs := []map[uint64]lnwire.MessageType{ + expectedPeerMsgs1, expectedPeerMsgs2, expectedPeerMsgs3, + } + for i, peer := range peers { + peerMsgs, err := msgStore.MessagesForPeer(peer) + if err != nil { + t.Fatalf("unable to retrieve messages: %v", err) + } + assertPeerMsgs(peerMsgs, expectedPeerMsgs[i]) + } + + // Finally, we'll query the store for all of its messages of every peer. + // Again, each peer should have a set of messages that match what we + // expect. + // + // We'll construct the expected response. Only the first two peers will + // have messages. + totalPeerMsgs := make(map[[33]byte]map[uint64]lnwire.MessageType, 2) + for i := 0; i < 2; i++ { + totalPeerMsgs[peers[i]] = expectedPeerMsgs[i] + } + + msgs, err := msgStore.Messages() + if err != nil { + t.Fatalf("unable to retrieve all peers with pending messages: "+ + "%v", err) + } + if len(msgs) != len(totalPeerMsgs) { + t.Fatalf("expected %d peers with messages, got %d", + len(totalPeerMsgs), len(msgs)) + } + for peer, peerMsgs := range msgs { + expected, ok := totalPeerMsgs[peer] + if !ok { + t.Fatalf("expected to find pending messages for peer %x", + peer) + } + + assertPeerMsgs(peerMsgs, expected) + } + + peerPubKeys, err := msgStore.Peers() + if err != nil { + t.Fatalf("unable to retrieve all peers with pending messages: "+ + "%v", err) + } + if len(peerPubKeys) != len(totalPeerMsgs) { + t.Fatalf("expected %d peers with messages, got %d", + len(totalPeerMsgs), len(peerPubKeys)) + } + for peerPubKey := range peerPubKeys { + if _, ok := totalPeerMsgs[peerPubKey]; !ok { + t.Fatalf("expected to find peer %x", peerPubKey) + } + } +} + +// TestMessageStoreUnsupportedMessage ensures that we are not able to add a +// message which is unsupported, and if a message is found to be unsupported by +// the current version of the store, that it is properly filtered out from the +// response. +func TestMessageStoreUnsupportedMessage(t *testing.T) { + t.Parallel() + + // We'll start by creating our test message store. + msgStore, cleanUp := createTestMessageStore(t) + defer cleanUp() + + // Create a message that is known to not be supported by the store. + peer := randCompressedPubKey(t) + unsupportedMsg := &lnwire.Error{} + + // Attempting to add it to the store should result in + // ErrUnsupportedMessage. + err := msgStore.AddMessage(unsupportedMsg, peer) + if err != ErrUnsupportedMessage { + t.Fatalf("expected ErrUnsupportedMessage, got %v", err) + } + + // We'll now pretend that the message is actually supported in a future + // version of the store, so it's able to be added successfully. To + // replicate this, we'll add the message manually rather than through + // the existing AddMessage method. + msgKey := peer[:] + var rawMsg bytes.Buffer + if _, err := lnwire.WriteMessage(&rawMsg, unsupportedMsg, 0); err != nil { + t.Fatalf("unable to serialize message: %v", err) + } + err = msgStore.db.Update(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + return messageStore.Put(msgKey, rawMsg.Bytes()) + }) + if err != nil { + t.Fatalf("unable to add unsupported message to store: %v", err) + } + + // Finally, we'll check that the store can properly filter out messages + // that are currently unknown to it. We'll make sure this is done for + // both Messages and MessagesForPeer. + totalMsgs, err := msgStore.Messages() + if err != nil { + t.Fatalf("unable to retrieve messages: %v", err) + } + if len(totalMsgs) != 0 { + t.Fatalf("expected to filter out unsupported message") + } + peerMsgs, err := msgStore.MessagesForPeer(peer) + if err != nil { + t.Fatalf("unable to retrieve peer messages: %v", err) + } + if len(peerMsgs) != 0 { + t.Fatalf("expected to filter out unsupported message") + } +} + +// TestMessageStoreDeleteMessage ensures that we can properly delete messages +// from the store. +func TestMessageStoreDeleteMessage(t *testing.T) { + t.Parallel() + + msgStore, cleanUp := createTestMessageStore(t) + defer cleanUp() + + // assertMsg is a helper closure we'll use to ensure a message + // does/doesn't exist within the store. + assertMsg := func(msg lnwire.Message, peer [33]byte, exists bool) { + t.Helper() + + storeMsgs, err := msgStore.MessagesForPeer(peer) + if err != nil { + t.Fatalf("unable to retrieve messages: %v", err) + } + + found := false + for _, storeMsg := range storeMsgs { + if reflect.DeepEqual(msg, storeMsg) { + found = true + } + } + + if found != exists { + str := "find" + if !exists { + str = "not find" + } + t.Fatalf("expected to %v message %v", str, + spew.Sdump(msg)) + } + } + + // An AnnounceSignatures message should exist within the store after + // adding it, and should no longer exists after deleting it. + peer := randCompressedPubKey(t) + annSig := randAnnounceSignatures() + if err := msgStore.AddMessage(annSig, peer); err != nil { + t.Fatalf("unable to add message: %v", err) + } + assertMsg(annSig, peer, true) + if err := msgStore.DeleteMessage(annSig, peer); err != nil { + t.Fatalf("unable to delete message: %v", err) + } + assertMsg(annSig, peer, false) + + // The store allows overwriting ChannelUpdates, since there can be + // multiple versions, so we'll test things slightly different. + // + // The ChannelUpdate message should exist within the store after adding + // it. + chanUpdate := randChannelUpdate() + if err := msgStore.AddMessage(chanUpdate, peer); err != nil { + t.Fatalf("unable to add message: %v", err) + } + assertMsg(chanUpdate, peer, true) + + // Now, we'll create a new version for the same ChannelUpdate message. + // Adding this one to the store will overwrite the previous one, so only + // the new one should exist. + newChanUpdate := randChannelUpdate() + newChanUpdate.ShortChannelID = chanUpdate.ShortChannelID + newChanUpdate.Timestamp = chanUpdate.Timestamp + 1 + if err := msgStore.AddMessage(newChanUpdate, peer); err != nil { + t.Fatalf("unable to add message: %v", err) + } + assertMsg(chanUpdate, peer, false) + assertMsg(newChanUpdate, peer, true) + + // Deleting the older message should act as a NOP and should NOT delete + // the newer version as the older no longer exists. + if err := msgStore.DeleteMessage(chanUpdate, peer); err != nil { + t.Fatalf("unable to delete message: %v", err) + } + assertMsg(chanUpdate, peer, false) + assertMsg(newChanUpdate, peer, true) + + // The newer version should no longer exist within the store after + // deleting it. + if err := msgStore.DeleteMessage(newChanUpdate, peer); err != nil { + t.Fatalf("unable to delete message: %v", err) + } + assertMsg(newChanUpdate, peer, false) +}