lnrpc+itest: fix write deadline issue with WS ping

Fixes #5680.
To make sure we're always reading from the WebSocket connection, we need
to always have an ongoing (but blocking) conn.ReadMessage() call going
on. To achieve this, we do the read in a separate goroutine and write to
a buffered channel. That way we can always read the next message while
the current one is being forwarded. This allows incoming ping messages
to be received and processed which then leads to the deadlines to be
extended correctly.
This commit is contained in:
Oliver Gugger 2021-08-31 13:23:35 +02:00
parent 5f94ebbd7d
commit 4b7452a35e
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
3 changed files with 179 additions and 8 deletions

View File

@ -324,6 +324,12 @@ you.
* [Fix crash with empty AMP or MPP record in
invoice](https://github.com/lightningnetwork/lnd/pull/5743).
* The underlying gRPC connection of a WebSocket is now [properly closed when the
WebSocket end of a connection is
closed](https://github.com/lightningnetwork/lnd/pull/5683). A bug with the
write deadline that caused connections to suddenly break was also fixed in the
same PR.
## Documentation
The [code contribution guidelines have been updated to mention the new

View File

@ -33,6 +33,11 @@ const (
// additional header field and its value. We use the plus symbol because
// the default delimiters aren't allowed in the protocol names.
WebSocketProtocolDelimiter = "+"
// PingContent is the content of the ping message we send out. This is
// an arbitrary non-empty message that has no deeper meaning but should
// be sent back by the client in the pong message.
PingContent = "are you there?"
)
var (
@ -189,9 +194,19 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
p.backend.ServeHTTP(responseForwarder, request)
}()
// Read loop: Take messages from websocket and write to http request.
// Read loop: Take messages from websocket and write them to the payload
// channel. This needs to be its own goroutine because for non-client
// streaming RPCs, the requestForwarder.Write() in the second goroutine
// will block until the request has fully completed. But for the ping/
// pong handler to work, we need to have an active call to
// conn.ReadMessage() going on. So we make sure we have such an active
// call by starting a second read as soon as the first one has
// completed.
payloadChannel := make(chan []byte, 1)
go func() {
defer cancelFn()
defer close(payloadChannel)
for {
select {
case <-ctx.Done():
@ -210,6 +225,34 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
err)
return
}
select {
case payloadChannel <- payload:
case <-ctx.Done():
return
}
}
}()
// Forward loop: Take messages from the incoming payload channel and
// write them to the http request.
go func() {
defer cancelFn()
for {
var payload []byte
select {
case <-ctx.Done():
return
case newPayload, more := <-payloadChannel:
if !more {
p.logger.Infof("WS: incoming payload " +
"chan closed")
return
}
payload = newPayload
}
_, err = requestForwarder.Write(payload)
if err != nil {
p.logger.Errorf("WS: error writing message "+
@ -239,12 +282,17 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
// Whenever a pong message comes in, we extend the deadline
// until the next read is expected by the interval plus pong
// wait time.
// wait time. Since we can never _reach_ any of the deadlines,
// we also have to advance the deadline for the next expected
// write to happen, in case the next thing we actually write is
// the next ping.
conn.SetPongHandler(func(appData string) error {
nextDeadline := time.Now().Add(
p.pingInterval + p.pongWait,
)
_ = conn.SetReadDeadline(nextDeadline)
_ = conn.SetWriteDeadline(nextDeadline)
return nil
})
go func() {
@ -264,10 +312,10 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
writeDeadline := time.Now().Add(
p.pongWait,
)
_ = conn.SetWriteDeadline(writeDeadline)
err := conn.WriteMessage(
websocket.PingMessage, nil,
err := conn.WriteControl(
websocket.PingMessage,
[]byte(PingContent),
writeDeadline,
)
if err != nil {
p.logger.Warnf("WS: could not "+

View File

@ -43,13 +43,16 @@ var (
}
urlEnc = base64.URLEncoding
webSocketDialer = &websocket.Dialer{
HandshakeTimeout: 45 * time.Second,
HandshakeTimeout: time.Second,
TLSClientConfig: insecureTransport.TLSClientConfig,
}
resultPattern = regexp.MustCompile("{\"result\":(.*)}")
closeMsg = websocket.FormatCloseMessage(
websocket.CloseNormalClosure, "done",
)
pingInterval = time.Millisecond * 200
pongWait = time.Millisecond * 50
)
// testRestAPI tests that the most important features of the REST API work
@ -201,11 +204,18 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) {
}, {
name: "websocket bi-directional subscription",
run: wsTestCaseBiDirectionalSubscription,
}, {
name: "websocket ping and pong timeout",
run: wsTestPingPongTimeout,
}}
// Make sure Alice allows all CORS origins. Bob will keep the default.
// We also make sure the ping/pong messages are sent very often, so we
// can test them without waiting half a minute.
net.Alice.Cfg.ExtraArgs = append(
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
fmt.Sprintf("--ws-ping-interval=%s", pingInterval),
fmt.Sprintf("--ws-pong-wait=%s", pongWait),
)
err := net.RestartNode(net.Alice, nil)
if err != nil {
@ -432,7 +442,9 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
_ = conn.Close()
}()
msgChan := make(chan *lnrpc.ChannelAcceptResponse)
// Buffer the message channel to make sure we're always blocking on
// conn.ReadMessage() to allow the ping/pong mechanism to work.
msgChan := make(chan *lnrpc.ChannelAcceptResponse, 1)
errChan := make(chan error)
done := make(chan struct{})
timeout := time.After(defaultTimeout)
@ -522,6 +534,111 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
close(done)
}
func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) {
initialRequest := &lnrpc.InvoiceSubscription{
AddIndex: 1, SettleIndex: 1,
}
url := "/v1/invoices/subscribe"
// This time we send the macaroon in the special header
// Sec-Websocket-Protocol which is the only header field available to
// browsers when opening a WebSocket.
mac, err := net.Alice.ReadMacaroon(
net.Alice.AdminMacPath(), defaultTimeout,
)
require.NoError(ht.t, err, "read admin mac")
macBytes, err := mac.MarshalBinary()
require.NoError(ht.t, err, "marshal admin mac")
customHeader := make(http.Header)
customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf(
"Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes),
))
conn, err := openWebSocket(
net.Alice, url, "GET", initialRequest, customHeader,
)
require.Nil(ht.t, err, "websocket")
defer func() {
err := conn.WriteMessage(websocket.CloseMessage, closeMsg)
require.NoError(ht.t, err)
_ = conn.Close()
}()
// We want to be able to read invoices for a long time, making sure we
// can continue to read even after we've gone through several ping/pong
// cycles.
invoices := make(chan *lnrpc.Invoice, 1)
errors := make(chan error)
done := make(chan struct{})
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
errors <- err
return
}
// The chunked/streamed responses come wrapped in either
// a {"result":{}} or {"error":{}} wrapper which we'll
// get rid of here.
msgStr := string(msg)
if !strings.Contains(msgStr, "\"result\":") {
errors <- fmt.Errorf("invalid msg: %s", msgStr)
return
}
msgStr = resultPattern.ReplaceAllString(msgStr, "${1}")
// Make sure we can parse the unwrapped message into the
// expected proto message.
protoMsg := &lnrpc.Invoice{}
err = jsonpb.UnmarshalString(msgStr, protoMsg)
if err != nil {
errors <- err
return
}
invoices <- protoMsg
// Make sure we exit the loop once we've sent through
// all expected test messages.
select {
case <-done:
return
default:
}
}
}()
// Let's create five invoices and wait for them to arrive. We'll wait
// for at least one ping/pong cycle between each invoice.
ctxb := context.Background()
const numInvoices = 5
const value = 123
const memo = "websocket"
for i := 0; i < numInvoices; i++ {
_, err := net.Alice.AddInvoice(ctxb, &lnrpc.Invoice{
Value: value,
Memo: memo,
})
require.NoError(ht.t, err)
select {
case streamMsg := <-invoices:
require.Equal(ht.t, int64(value), streamMsg.Value)
require.Equal(ht.t, memo, streamMsg.Memo)
case err := <-errors:
require.Fail(ht.t, "Error reading invoice: %v", err)
}
// Let's wait for at least a whole ping/pong cycle to happen, so
// we can be sure the read/write deadlines are set correctly.
// We double the pong wait just to add some extra margin.
time.Sleep(pingInterval + 2*pongWait)
}
close(done)
}
// invokeGET calls the given URL with the GET method and appropriate macaroon
// header fields then tries to unmarshal the response into the given response
// proto message.