lnwire+peer: clamp pong bytes, make ping handler more effcieint

This was not properly enforced and would be a spec violation on the
peer's end. Also re-use a pong buffer to save on heap allocations if
there are a lot of peers. The pong buffer is only read from, so this
is concurrent safe.
This commit is contained in:
Olaoluwa Osuntokun 2022-01-10 19:11:59 -08:00
parent ae16f2b631
commit 3481286ea0
No known key found for this signature in database
GPG key ID: 3BBD59E99B280306
6 changed files with 55 additions and 7 deletions

View file

@ -864,6 +864,22 @@ func TestLightningWireProtocol(t *testing.T) {
NewShortChanIDFromInt(uint64(r.Int63())))
}
v[0] = reflect.ValueOf(req)
},
MsgPing: func(v []reflect.Value, r *rand.Rand) {
// We use a special message generator here to ensure we
// don't generate ping messages that are too large,
// which'll cause the test to fail.
//
// We'll allow the test to generate padding bytes up to
// the max message limit, factoring in the 2 bytes for
// the num pong bytes.
paddingBytes := make([]byte, r.Intn(MaxMsgBody-1))
req := Ping{
NumPongBytes: uint16(r.Intn(MaxPongBytes + 1)),
PaddingBytes: paddingBytes,
}
v[0] = reflect.ValueOf(req)
},
}

View file

@ -38,9 +38,16 @@ var _ Message = (*Ping)(nil)
//
// This is part of the lnwire.Message interface.
func (p *Ping) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&p.NumPongBytes,
&p.PaddingBytes)
err := ReadElements(r, &p.NumPongBytes, &p.PaddingBytes)
if err != nil {
return err
}
if p.NumPongBytes > MaxPongBytes {
return ErrMaxPongBytesExceeded
}
return nil
}
// Encode serializes the target Ping into the passed io.Writer observing the

View file

@ -2,9 +2,19 @@ package lnwire
import (
"bytes"
"fmt"
"io"
)
// MaxPongBytes is the maximum number of extra bytes a pong can be requested to
// send. The type of the message (19) takes 2 bytes, the length field takes up
// 2 bytes, leaving 65531 bytes.
const MaxPongBytes = 65531
// ErrMaxPongBytesExceeded indicates that the NumPongBytes field from the ping
// message has exceeded MaxPongBytes.
var ErrMaxPongBytesExceeded = fmt.Errorf("pong bytes exceeded")
// PongPayload is a set of opaque bytes sent in response to a ping message.
type PongPayload []byte

View file

@ -327,6 +327,12 @@ type Config struct {
// from the peer.
HandleCustomMessage func(peer [33]byte, msg *lnwire.Custom) error
// PongBuf is a slice we'll reuse instead of allocating memory on the
// heap. Since only reads will occur and no writes, there is no need
// for any synchronization primitives. As a result, it's safe to share
// this across multiple Peer struct instances.
PongBuf []byte
// Quit is the server's quit channel. If this is closed, we halt operation.
Quit chan struct{}
}
@ -1394,10 +1400,8 @@ out:
// Next, we'll send over the amount of specified pong
// bytes.
//
// TODO(roasbeef): read out from pong scratch instead?
pongBytes := make([]byte, msg.NumPongBytes)
p.queueMsg(lnwire.NewPong(pongBytes), nil)
pong := lnwire.NewPong(p.cfg.PongBuf[0:msg.NumPongBytes])
p.queueMsg(pong, nil)
case *lnwire.OpenChannel,
*lnwire.AcceptChannel,

View file

@ -999,6 +999,7 @@ func TestStaticRemoteDowngrade(t *testing.T) {
Features: test.features,
Conn: mockConn,
WritePool: writePool,
PongBuf: make([]byte, lnwire.MaxPongBytes),
},
}
@ -1103,6 +1104,7 @@ func TestPeerCustomMessage(t *testing.T) {
}
return nil
},
PongBuf: make([]byte, lnwire.MaxPongBytes),
})
// Set up the init sequence.

View file

@ -229,6 +229,12 @@ type server struct {
// intended to replace it.
scheduledPeerConnection map[string]func()
// pongBuf is a shared pong reply buffer we'll use across all active
// peer goroutines. We know the max size of a pong message
// (lnwire.MaxPongBytes), so we can allocate this ahead of time, and
// avoid allocations each time we need to send a pong message.
pongBuf []byte
cc *chainreg.ChainControl
fundingMgr *funding.Manager
@ -582,6 +588,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
peerErrors: make(map[string]*queue.CircularBuffer),
ignorePeerTermination: make(map[*peer.Brontide]struct{}),
scheduledPeerConnection: make(map[string]func()),
pongBuf: make([]byte, lnwire.MaxPongBytes),
peersByPub: make(map[string]*peer.Brontide),
inboundPeers: make(map[string]*peer.Brontide),
@ -3491,6 +3498,8 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
DisconnectPeer: s.DisconnectPeer,
GenNodeAnnouncement: s.genNodeAnnouncement,
PongBuf: s.pongBuf,
PrunePersistentPeerConnection: s.prunePersistentPeerConnection,
FetchLastChanUpdate: s.fetchLastChanUpdate(),