diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index bcf55cb97..01746bc43 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "reflect" - "io" "math" @@ -60,23 +58,92 @@ func messageToString(msg lnwire.Message) string { return spew.Sdump(msg) } +// expectedMessage struct hols the message which travels from one peer to +// another, and additional information like, should this message we skipped +// for handling. +type expectedMessage struct { + from string + to string + message lnwire.Message + skip bool +} + // createLogFunc is a helper function which returns the function which will be // used for logging message are received from another peer. func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor { - return func(m lnwire.Message) { - if getChanID(m) == channelID { + return func(m lnwire.Message) (bool, error) { + chanID, err := getChanID(m) + if err != nil { + return false, err + } + + if chanID == channelID { // Skip logging of extend revocation window messages. switch m := m.(type) { case *lnwire.RevokeAndAck: var zeroHash chainhash.Hash if bytes.Equal(zeroHash[:], m.Revocation[:]) { - return + return false, nil } } fmt.Printf("---------------------- \n %v received: "+ "%v", name, messageToString(m)) } + return false, nil + } +} + +// createInterceptorFunc creates the function by the given set of messages +// which, checks the order of the messages and skip the ones which were +// indicated to be intercepted. +func createInterceptorFunc(peer string, messages []expectedMessage, + chanID lnwire.ChannelID, debug bool) messageInterceptor { + + // Filter message which should be received with given peer name. + var expectToReceive []expectedMessage + for _, message := range messages { + if message.to == peer { + expectToReceive = append(expectToReceive, message) + } + } + + // Return function which checks the message order and skip the + // messages. + return func(m lnwire.Message) (bool, error) { + messageChanID, err := getChanID(m) + if err != nil { + return false, err + } + + if messageChanID == chanID { + if len(expectToReceive) == 0 { + return false, errors.Errorf("received unexpected message out "+ + "of range: %v", m.MsgType()) + } + + expectedMessage := expectToReceive[0] + expectToReceive = expectToReceive[1:] + + if expectedMessage.message.MsgType() != m.MsgType() { + return false, errors.Errorf("%v received wrong message: \n"+ + "real: %v\nexpected: %v", peer, m.MsgType(), + expectedMessage.message.MsgType()) + } + + if debug { + if expectedMessage.skip { + fmt.Printf("'%v' skiped the received message: %v \n", + peer, m.MsgType()) + } else { + fmt.Printf("'%v' received message: %v \n", peer, + m.MsgType()) + } + } + + return expectedMessage.skip, nil + } + return false, nil } } @@ -101,11 +168,11 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { debug := false if debug { // Log message that alice receives. - n.aliceServer.record(createLogFunc("alice", + n.aliceServer.intersect(createLogFunc("alice", n.aliceChannelLink.ChanID())) // Log message that bob receives. - n.bobServer.record(createLogFunc("bob", + n.bobServer.intersect(createLogFunc("bob", n.firstBobChannelLink.ChanID())) } @@ -168,11 +235,11 @@ func TestChannelLinkBidirectionalOneHopPayments(t *testing.T) { debug := false if debug { // Log message that alice receives. - n.aliceServer.record(createLogFunc("alice", + n.aliceServer.intersect(createLogFunc("alice", n.aliceChannelLink.ChanID())) // Log message that bob receives. - n.bobServer.record(createLogFunc("bob", + n.bobServer.intersect(createLogFunc("bob", n.firstBobChannelLink.ChanID())) } @@ -292,19 +359,19 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { debug := false if debug { // Log messages that alice receives from bob. - n.aliceServer.record(createLogFunc("[alice]<-bob<-carol: ", + n.aliceServer.intersect(createLogFunc("[alice]<-bob<-carol: ", n.aliceChannelLink.ChanID())) // Log messages that bob receives from alice. - n.bobServer.record(createLogFunc("alice->[bob]->carol: ", + n.bobServer.intersect(createLogFunc("alice->[bob]->carol: ", n.firstBobChannelLink.ChanID())) // Log messages that bob receives from carol. - n.bobServer.record(createLogFunc("alice<-[bob]<-carol: ", + n.bobServer.intersect(createLogFunc("alice<-[bob]<-carol: ", n.secondBobChannelLink.ChanID())) // Log messages that carol receives from bob. - n.carolServer.record(createLogFunc("alice->bob->[carol]", + n.carolServer.intersect(createLogFunc("alice->bob->[carol]", n.carolChannelLink.ChanID())) } @@ -1105,70 +1172,40 @@ func TestChannelLinkSingleHopMessageOrdering(t *testing.T) { testStartingHeight, ) - chanPoint := n.aliceChannelLink.ChanID() + chanID := n.aliceChannelLink.ChanID() - // The order in which Alice receives wire messages. - var aliceOrder []lnwire.Message - aliceOrder = append(aliceOrder, []lnwire.Message{ - &lnwire.RevokeAndAck{}, - &lnwire.CommitSig{}, - &lnwire.UpdateFufillHTLC{}, - &lnwire.CommitSig{}, - &lnwire.RevokeAndAck{}, - }...) + messages := []expectedMessage{ + {"alice", "bob", &lnwire.UpdateAddHTLC{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, - // The order in which Bob receives wire messages. - var bobOrder []lnwire.Message - bobOrder = append(bobOrder, []lnwire.Message{ - &lnwire.UpdateAddHTLC{}, - &lnwire.CommitSig{}, - &lnwire.RevokeAndAck{}, - &lnwire.RevokeAndAck{}, - &lnwire.CommitSig{}, - }...) + {"bob", "alice", &lnwire.UpdateFufillHTLC{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + } debug := false if debug { // Log message that alice receives. - n.aliceServer.record(createLogFunc("alice", + n.aliceServer.intersect(createLogFunc("alice", n.aliceChannelLink.ChanID())) // Log message that bob receives. - n.bobServer.record(createLogFunc("bob", + n.bobServer.intersect(createLogFunc("bob", n.firstBobChannelLink.ChanID())) } // Check that alice receives messages in right order. - n.aliceServer.record(func(m lnwire.Message) { - if getChanID(m) == chanPoint { - if len(aliceOrder) == 0 { - t.Fatal("redundant messages") - } - - if reflect.TypeOf(aliceOrder[0]) != reflect.TypeOf(m) { - t.Fatalf("alice received wrong message: \n"+ - "real: %v\n expected: %v", m.MsgType(), - aliceOrder[0].MsgType()) - } - aliceOrder = aliceOrder[1:] - } - }) + n.aliceServer.intersect(createInterceptorFunc("alice", messages, chanID, + false)) // Check that bob receives messages in right order. - n.bobServer.record(func(m lnwire.Message) { - if getChanID(m) == chanPoint { - if len(bobOrder) == 0 { - t.Fatal("redundant messages") - } - - if reflect.TypeOf(bobOrder[0]) != reflect.TypeOf(m) { - t.Fatalf("bob received wrong message: \n"+ - "real: %v\n expected: %v", m.MsgType(), - bobOrder[0].MsgType()) - } - bobOrder = bobOrder[1:] - } - }) + n.bobServer.intersect(createInterceptorFunc("bob", messages, chanID, + false)) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 2c37c480a..2850a1427 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -5,13 +5,14 @@ import ( "encoding/binary" "fmt" "sync" - "testing" "io" "sync/atomic" "bytes" + "testing" + "github.com/btcsuite/fastsha256" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" @@ -39,8 +40,8 @@ type mockServer struct { id [33]byte htlcSwitch *Switch - registry *mockInvoiceRegistry - recordFuncs []func(lnwire.Message) + registry *mockInvoiceRegistry + interceptorFuncs []messageInterceptor } var _ Peer = (*mockServer)(nil) @@ -51,14 +52,14 @@ func newMockServer(t *testing.T, name string) *mockServer { copy(id[:], h[:]) return &mockServer{ - t: t, - id: id, - name: name, - messages: make(chan lnwire.Message, 3000), - quit: make(chan bool), - registry: newMockRegistry(), - htlcSwitch: New(Config{}), - recordFuncs: make([]func(lnwire.Message), 0), + t: t, + id: id, + name: name, + messages: make(chan lnwire.Message, 3000), + quit: make(chan bool), + registry: newMockRegistry(), + htlcSwitch: New(Config{}), + interceptorFuncs: make([]messageInterceptor, 0), } } @@ -76,8 +77,20 @@ func (s *mockServer) Start() error { for { select { case msg := <-s.messages: - for _, f := range s.recordFuncs { - f(msg) + var shouldSkip bool + + for _, interceptor := range s.interceptorFuncs { + skip, err := interceptor(msg) + if err != nil { + s.errChan <- errors.Errorf("%v: error in the "+ + "interceptor: %v", s.name, err) + return + } + shouldSkip = shouldSkip || skip + } + + if shouldSkip { + continue } if err := s.readHandler(msg); err != nil { @@ -245,13 +258,13 @@ func (f *ForwardingInfo) decode(r io.Reader) error { } // messageInterceptor is function that handles the incoming peer messages and -// may decide should we handle it or not. -type messageInterceptor func(m lnwire.Message) +// may decide should the peer skip the message or not. +type messageInterceptor func(m lnwire.Message) (bool, error) // Record is used to set the function which will be triggered when new // lnwire message was received. -func (s *mockServer) record(f messageInterceptor) { - s.recordFuncs = append(s.recordFuncs, f) +func (s *mockServer) intersect(f messageInterceptor) { + s.interceptorFuncs = append(s.interceptorFuncs, f) } func (s *mockServer) SendMessage(message lnwire.Message) error { @@ -297,11 +310,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error { // the server when handler stacked (server unavailable) done := make(chan struct{}) go func() { - defer func() { - done <- struct{}{} - }() - link.HandleChannelUpdate(message) + done <- struct{}{} }() select { case <-done: diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index b2b8d31cd..35bd3fa7d 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -253,22 +253,26 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } // getChanID retrieves the channel point from nwire message. -func getChanID(msg lnwire.Message) lnwire.ChannelID { - var point lnwire.ChannelID +func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { + var chanID lnwire.ChannelID switch msg := msg.(type) { case *lnwire.UpdateAddHTLC: - point = msg.ChanID + chanID = msg.ChanID case *lnwire.UpdateFufillHTLC: - point = msg.ChanID + chanID = msg.ChanID case *lnwire.UpdateFailHTLC: - point = msg.ChanID + chanID = msg.ChanID case *lnwire.RevokeAndAck: - point = msg.ChanID + chanID = msg.ChanID case *lnwire.CommitSig: - point = msg.ChanID + chanID = msg.ChanID + case *lnwire.ChannelReestablish: + chanID = msg.ChanID + default: + return chanID, errors.New("unknown type") } - return point + return chanID, nil } // generatePayment generates the htlc add request by given path blob and