From 5f94ebbd7df1d658847af1435165422124223826 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 22 Jul 2021 15:43:12 +0200 Subject: [PATCH 1/2] lnrpc: use request context in WebSocket proxy The request context was not passed along to the gRPC endpoint, which caused streaming calls to remain active on the gRPC side even after the WS side had already hung up. We also issue an explicit close on the forwarding writer to signal when the WS side was closed. --- lnrpc/websocket_proxy.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lnrpc/websocket_proxy.go b/lnrpc/websocket_proxy.go index 8c5001daa..9d4d2e9a7 100644 --- a/lnrpc/websocket_proxy.go +++ b/lnrpc/websocket_proxy.go @@ -147,12 +147,12 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, } }() - ctx, cancelFn := context.WithCancel(context.Background()) + ctx, cancelFn := context.WithCancel(r.Context()) defer cancelFn() requestForwarder := newRequestForwardingReader() request, err := http.NewRequestWithContext( - r.Context(), r.Method, r.URL.String(), requestForwarder, + ctx, r.Method, r.URL.String(), requestForwarder, ) if err != nil { p.logger.Errorf("WS: error preparing request:", err) @@ -181,6 +181,7 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, go func() { <-ctx.Done() responseForwarder.Close() + requestForwarder.CloseWriter() }() go func() { From 4b7452a35e3d1ccda0f078767bf208394f3f76fc Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 31 Aug 2021 13:23:35 +0200 Subject: [PATCH 2/2] lnrpc+itest: fix write deadline issue with WS ping Fixes #5680. To make sure we're always reading from the WebSocket connection, we need to always have an ongoing (but blocking) conn.ReadMessage() call going on. To achieve this, we do the read in a separate goroutine and write to a buffered channel. That way we can always read the next message while the current one is being forwarded. 
This allows incoming ping messages to be received and processed, which in turn causes the deadlines to be extended correctly. --- docs/release-notes/release-notes-0.14.0.md | 6 + lnrpc/websocket_proxy.go | 60 +++++++++- lntest/itest/lnd_rest_api_test.go | 121 ++++++++++++++++++++- 3 files changed, 179 insertions(+), 8 deletions(-) diff --git a/docs/release-notes/release-notes-0.14.0.md b/docs/release-notes/release-notes-0.14.0.md index 311c92dae..3a352ad08 100644 --- a/docs/release-notes/release-notes-0.14.0.md +++ b/docs/release-notes/release-notes-0.14.0.md @@ -324,6 +324,12 @@ you. * [Fix crash with empty AMP or MPP record in invoice](https://github.com/lightningnetwork/lnd/pull/5743). +* The underlying gRPC connection of a WebSocket is now [properly closed when the + WebSocket end of a connection is + closed](https://github.com/lightningnetwork/lnd/pull/5683). A bug with the + write deadline that caused connections to suddenly break was also fixed in the + same PR. + ## Documentation The [code contribution guidelines have been updated to mention the new diff --git a/lnrpc/websocket_proxy.go b/lnrpc/websocket_proxy.go index 9d4d2e9a7..803cf6c36 100644 --- a/lnrpc/websocket_proxy.go +++ b/lnrpc/websocket_proxy.go @@ -33,6 +33,11 @@ const ( // additional header field and its value. We use the plus symbol because // the default delimiters aren't allowed in the protocol names. WebSocketProtocolDelimiter = "+" + + // PingContent is the content of the ping message we send out. This is + // an arbitrary non-empty message that has no deeper meaning but should + // be sent back by the client in the pong message. + PingContent = "are you there?" ) var ( @@ -189,9 +194,19 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, p.backend.ServeHTTP(responseForwarder, request) }() - // Read loop: Take messages from websocket and write to http request. + // Read loop: Take messages from websocket and write them to the payload + // channel. 
This needs to be its own goroutine because for non-client + // streaming RPCs, the requestForwarder.Write() in the second goroutine + // will block until the request has fully completed. But for the ping/ + // pong handler to work, we need to have an active call to + // conn.ReadMessage() going on. So we make sure we have such an active + // call by starting a second read as soon as the first one has + // completed. + payloadChannel := make(chan []byte, 1) go func() { defer cancelFn() + defer close(payloadChannel) + for { select { case <-ctx.Done(): @@ -210,6 +225,34 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, err) return } + + select { + case payloadChannel <- payload: + case <-ctx.Done(): + return + } + } + }() + + // Forward loop: Take messages from the incoming payload channel and + // write them to the http request. + go func() { + defer cancelFn() + for { + var payload []byte + select { + case <-ctx.Done(): + return + case newPayload, more := <-payloadChannel: + if !more { + p.logger.Infof("WS: incoming payload " + + "chan closed") + return + } + + payload = newPayload + } + _, err = requestForwarder.Write(payload) if err != nil { p.logger.Errorf("WS: error writing message "+ @@ -239,12 +282,17 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, // Whenever a pong message comes in, we extend the deadline // until the next read is expected by the interval plus pong - // wait time. + // wait time. Since we can never _reach_ any of the deadlines, + // we also have to advance the deadline for the next expected + // write to happen, in case the next thing we actually write is + // the next ping. 
conn.SetPongHandler(func(appData string) error { nextDeadline := time.Now().Add( p.pingInterval + p.pongWait, ) _ = conn.SetReadDeadline(nextDeadline) + _ = conn.SetWriteDeadline(nextDeadline) + return nil }) go func() { @@ -264,10 +312,10 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, writeDeadline := time.Now().Add( p.pongWait, ) - _ = conn.SetWriteDeadline(writeDeadline) - - err := conn.WriteMessage( - websocket.PingMessage, nil, + err := conn.WriteControl( + websocket.PingMessage, + []byte(PingContent), + writeDeadline, ) if err != nil { p.logger.Warnf("WS: could not "+ diff --git a/lntest/itest/lnd_rest_api_test.go b/lntest/itest/lnd_rest_api_test.go index 96b6dcd89..33e5601a1 100644 --- a/lntest/itest/lnd_rest_api_test.go +++ b/lntest/itest/lnd_rest_api_test.go @@ -43,13 +43,16 @@ var ( } urlEnc = base64.URLEncoding webSocketDialer = &websocket.Dialer{ - HandshakeTimeout: 45 * time.Second, + HandshakeTimeout: time.Second, TLSClientConfig: insecureTransport.TLSClientConfig, } resultPattern = regexp.MustCompile("{\"result\":(.*)}") closeMsg = websocket.FormatCloseMessage( websocket.CloseNormalClosure, "done", ) + + pingInterval = time.Millisecond * 200 + pongWait = time.Millisecond * 50 ) // testRestAPI tests that the most important features of the REST API work @@ -201,11 +204,18 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) { }, { name: "websocket bi-directional subscription", run: wsTestCaseBiDirectionalSubscription, + }, { + name: "websocket ping and pong timeout", + run: wsTestPingPongTimeout, }} // Make sure Alice allows all CORS origins. Bob will keep the default. + // We also make sure the ping/pong messages are sent very often, so we + // can test them without waiting half a minute. 
net.Alice.Cfg.ExtraArgs = append( net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"", + fmt.Sprintf("--ws-ping-interval=%s", pingInterval), + fmt.Sprintf("--ws-pong-wait=%s", pongWait), ) err := net.RestartNode(net.Alice, nil) if err != nil { @@ -432,7 +442,9 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest, _ = conn.Close() }() - msgChan := make(chan *lnrpc.ChannelAcceptResponse) + // Buffer the message channel to make sure we're always blocking on + // conn.ReadMessage() to allow the ping/pong mechanism to work. + msgChan := make(chan *lnrpc.ChannelAcceptResponse, 1) errChan := make(chan error) done := make(chan struct{}) timeout := time.After(defaultTimeout) @@ -522,6 +534,111 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest, close(done) } +func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { + initialRequest := &lnrpc.InvoiceSubscription{ + AddIndex: 1, SettleIndex: 1, + } + url := "/v1/invoices/subscribe" + + // This time we send the macaroon in the special header + // Sec-Websocket-Protocol which is the only header field available to + // browsers when opening a WebSocket. + mac, err := net.Alice.ReadMacaroon( + net.Alice.AdminMacPath(), defaultTimeout, + ) + require.NoError(ht.t, err, "read admin mac") + macBytes, err := mac.MarshalBinary() + require.NoError(ht.t, err, "marshal admin mac") + + customHeader := make(http.Header) + customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( + "Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes), + )) + conn, err := openWebSocket( + net.Alice, url, "GET", initialRequest, customHeader, + ) + require.Nil(ht.t, err, "websocket") + defer func() { + err := conn.WriteMessage(websocket.CloseMessage, closeMsg) + require.NoError(ht.t, err) + _ = conn.Close() + }() + + // We want to be able to read invoices for a long time, making sure we + // can continue to read even after we've gone through several ping/pong + // cycles. 
+ invoices := make(chan *lnrpc.Invoice, 1) + errors := make(chan error) + done := make(chan struct{}) + go func() { + for { + _, msg, err := conn.ReadMessage() + if err != nil { + errors <- err + return + } + + // The chunked/streamed responses come wrapped in either + // a {"result":{}} or {"error":{}} wrapper which we'll + // get rid of here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errors <- fmt.Errorf("invalid msg: %s", msgStr) + return + } + msgStr = resultPattern.ReplaceAllString(msgStr, "${1}") + + // Make sure we can parse the unwrapped message into the + // expected proto message. + protoMsg := &lnrpc.Invoice{} + err = jsonpb.UnmarshalString(msgStr, protoMsg) + if err != nil { + errors <- err + return + } + + invoices <- protoMsg + + // Make sure we exit the loop once we've sent through + // all expected test messages. + select { + case <-done: + return + default: + } + } + }() + + // Let's create five invoices and wait for them to arrive. We'll wait + // for at least one ping/pong cycle between each invoice. + ctxb := context.Background() + const numInvoices = 5 + const value = 123 + const memo = "websocket" + for i := 0; i < numInvoices; i++ { + _, err := net.Alice.AddInvoice(ctxb, &lnrpc.Invoice{ + Value: value, + Memo: memo, + }) + require.NoError(ht.t, err) + + select { + case streamMsg := <-invoices: + require.Equal(ht.t, int64(value), streamMsg.Value) + require.Equal(ht.t, memo, streamMsg.Memo) + + case err := <-errors: + require.Fail(ht.t, "Error reading invoice: %v", err) + } + + // Let's wait for at least a whole ping/pong cycle to happen, so + // we can be sure the read/write deadlines are set correctly. + // We double the pong wait just to add some extra margin. + time.Sleep(pingInterval + 2*pongWait) + } + close(done) +} + // invokeGET calls the given URL with the GET method and appropriate macaroon // header fields then tries to unmarshal the response into the given response // proto message.