diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 0743c1704..d7aecee24 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -864,6 +864,22 @@ func TestLightningWireProtocol(t *testing.T) { NewShortChanIDFromInt(uint64(r.Int63()))) } + v[0] = reflect.ValueOf(req) + }, + MsgPing: func(v []reflect.Value, r *rand.Rand) { + // We use a special message generator here to ensure we + // don't generate ping messages that are too large, + // which'll cause the test to fail. + // + // We'll allow the test to generate padding bytes up to + // the max message limit, factoring in the 2 bytes for + // the num pong bytes. + paddingBytes := make([]byte, r.Intn(MaxMsgBody-1)) + req := Ping{ + NumPongBytes: uint16(r.Intn(MaxPongBytes + 1)), + PaddingBytes: paddingBytes, + } + v[0] = reflect.ValueOf(req) }, } diff --git a/lnwire/ping.go b/lnwire/ping.go index 1e2877d0c..a21f2fa8b 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -38,9 +38,16 @@ var _ Message = (*Ping)(nil) // // This is part of the lnwire.Message interface. func (p *Ping) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, - &p.NumPongBytes, - &p.PaddingBytes) + err := ReadElements(r, &p.NumPongBytes, &p.PaddingBytes) + if err != nil { + return err + } + + if p.NumPongBytes > MaxPongBytes { + return ErrMaxPongBytesExceeded + } + + return nil } // Encode serializes the target Ping into the passed io.Writer observing the diff --git a/lnwire/pong.go b/lnwire/pong.go index 953eef5f5..3ab80d70f 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -2,9 +2,19 @@ package lnwire import ( "bytes" + "fmt" "io" ) +// MaxPongBytes is the maximum number of extra bytes a pong can be requested to +// send. The type of the message (19) takes 2 bytes, the length field takes up +// 2 bytes, leaving 65531 bytes. +const MaxPongBytes = 65531 + +// ErrMaxPongBytesExceeded indicates that the NumPongBytes field from the ping +// message has exceeded MaxPongBytes. +var ErrMaxPongBytesExceeded = fmt.Errorf("pong bytes exceeded") + // PongPayload is a set of opaque bytes sent in response to a ping message. type PongPayload []byte diff --git a/peer/brontide.go b/peer/brontide.go index e58b703dc..51d2dc325 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -327,6 +327,12 @@ type Config struct { // from the peer. HandleCustomMessage func(peer [33]byte, msg *lnwire.Custom) error + // PongBuf is a slice we'll reuse instead of allocating memory on the + // heap. Since only reads will occur and no writes, there is no need + // for any synchronization primitives. As a result, it's safe to share + // this across multiple Peer struct instances. + PongBuf []byte + // Quit is the server's quit channel. If this is closed, we halt operation. Quit chan struct{} } @@ -1394,10 +1400,8 @@ out: // Next, we'll send over the amount of specified pong // bytes. - // - // TODO(roasbeef): read out from pong scratch instead? - pongBytes := make([]byte, msg.NumPongBytes) - p.queueMsg(lnwire.NewPong(pongBytes), nil) + pong := lnwire.NewPong(p.cfg.PongBuf[0:msg.NumPongBytes]) + p.queueMsg(pong, nil) case *lnwire.OpenChannel, *lnwire.AcceptChannel, diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 825a767f0..a84e19581 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -999,6 +999,7 @@ func TestStaticRemoteDowngrade(t *testing.T) { Features: test.features, Conn: mockConn, WritePool: writePool, + PongBuf: make([]byte, lnwire.MaxPongBytes), }, } @@ -1103,6 +1104,7 @@ func TestPeerCustomMessage(t *testing.T) { } return nil }, + PongBuf: make([]byte, lnwire.MaxPongBytes), }) // Set up the init sequence. diff --git a/server.go b/server.go index d86e27a4f..1aceebc14 100644 --- a/server.go +++ b/server.go @@ -229,6 +229,12 @@ type server struct { // intended to replace it. scheduledPeerConnection map[string]func() + // pongBuf is a shared pong reply buffer we'll use across all active + // peer goroutines. We know the max size of a pong message + // (lnwire.MaxPongBytes), so we can allocate this ahead of time, and + // avoid allocations each time we need to send a pong message. + pongBuf []byte + cc *chainreg.ChainControl fundingMgr *funding.Manager @@ -582,6 +588,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, peerErrors: make(map[string]*queue.CircularBuffer), ignorePeerTermination: make(map[*peer.Brontide]struct{}), scheduledPeerConnection: make(map[string]func()), + pongBuf: make([]byte, lnwire.MaxPongBytes), peersByPub: make(map[string]*peer.Brontide), inboundPeers: make(map[string]*peer.Brontide), @@ -3491,6 +3498,8 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, DisconnectPeer: s.DisconnectPeer, GenNodeAnnouncement: s.genNodeAnnouncement, + PongBuf: s.pongBuf, + PrunePersistentPeerConnection: s.prunePersistentPeerConnection, FetchLastChanUpdate: s.fetchLastChanUpdate(),