Merge pull request #5683 from guggero/websocket-write-deadline

lnrpc: Fix WebSocket write deadline not being extended
This commit is contained in:
Olaoluwa Osuntokun 2021-09-20 17:07:41 -07:00 committed by GitHub
commit 5e6532594c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 182 additions and 10 deletions

View File

@ -324,6 +324,12 @@ you.
* [Fix crash with empty AMP or MPP record in * [Fix crash with empty AMP or MPP record in
invoice](https://github.com/lightningnetwork/lnd/pull/5743). invoice](https://github.com/lightningnetwork/lnd/pull/5743).
* The underlying gRPC connection of a WebSocket is now [properly closed when the
WebSocket end of a connection is
closed](https://github.com/lightningnetwork/lnd/pull/5683). A bug with the
write deadline that caused connections to suddenly break was also fixed in the
same PR.
## Documentation ## Documentation
The [code contribution guidelines have been updated to mention the new The [code contribution guidelines have been updated to mention the new

View File

@ -33,6 +33,11 @@ const (
// additional header field and its value. We use the plus symbol because // additional header field and its value. We use the plus symbol because
// the default delimiters aren't allowed in the protocol names. // the default delimiters aren't allowed in the protocol names.
WebSocketProtocolDelimiter = "+" WebSocketProtocolDelimiter = "+"
// PingContent is the content of the ping message we send out. This is
// an arbitrary non-empty message that has no deeper meaning but should
// be sent back by the client in the pong message.
PingContent = "are you there?"
) )
var ( var (
@ -147,12 +152,12 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
} }
}() }()
ctx, cancelFn := context.WithCancel(context.Background()) ctx, cancelFn := context.WithCancel(r.Context())
defer cancelFn() defer cancelFn()
requestForwarder := newRequestForwardingReader() requestForwarder := newRequestForwardingReader()
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
r.Context(), r.Method, r.URL.String(), requestForwarder, ctx, r.Method, r.URL.String(), requestForwarder,
) )
if err != nil { if err != nil {
p.logger.Errorf("WS: error preparing request:", err) p.logger.Errorf("WS: error preparing request:", err)
@ -181,6 +186,7 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
go func() { go func() {
<-ctx.Done() <-ctx.Done()
responseForwarder.Close() responseForwarder.Close()
requestForwarder.CloseWriter()
}() }()
go func() { go func() {
@ -188,9 +194,19 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
p.backend.ServeHTTP(responseForwarder, request) p.backend.ServeHTTP(responseForwarder, request)
}() }()
// Read loop: Take messages from websocket and write to http request. // Read loop: Take messages from websocket and write them to the payload
// channel. This needs to be its own goroutine because for non-client
// streaming RPCs, the requestForwarder.Write() in the second goroutine
// will block until the request has fully completed. But for the ping/
// pong handler to work, we need to have an active call to
// conn.ReadMessage() going on. So we make sure we have such an active
// call by starting a second read as soon as the first one has
// completed.
payloadChannel := make(chan []byte, 1)
go func() { go func() {
defer cancelFn() defer cancelFn()
defer close(payloadChannel)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -209,6 +225,34 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
err) err)
return return
} }
select {
case payloadChannel <- payload:
case <-ctx.Done():
return
}
}
}()
// Forward loop: Take messages from the incoming payload channel and
// write them to the http request.
go func() {
defer cancelFn()
for {
var payload []byte
select {
case <-ctx.Done():
return
case newPayload, more := <-payloadChannel:
if !more {
p.logger.Infof("WS: incoming payload " +
"chan closed")
return
}
payload = newPayload
}
_, err = requestForwarder.Write(payload) _, err = requestForwarder.Write(payload)
if err != nil { if err != nil {
p.logger.Errorf("WS: error writing message "+ p.logger.Errorf("WS: error writing message "+
@ -238,12 +282,17 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
// Whenever a pong message comes in, we extend the deadline // Whenever a pong message comes in, we extend the deadline
// until the next read is expected by the interval plus pong // until the next read is expected by the interval plus pong
// wait time. // wait time. Since we can never _reach_ any of the deadlines,
// we also have to advance the deadline for the next expected
// write to happen, in case the next thing we actually write is
// the next ping.
conn.SetPongHandler(func(appData string) error { conn.SetPongHandler(func(appData string) error {
nextDeadline := time.Now().Add( nextDeadline := time.Now().Add(
p.pingInterval + p.pongWait, p.pingInterval + p.pongWait,
) )
_ = conn.SetReadDeadline(nextDeadline) _ = conn.SetReadDeadline(nextDeadline)
_ = conn.SetWriteDeadline(nextDeadline)
return nil return nil
}) })
go func() { go func() {
@ -263,10 +312,10 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
writeDeadline := time.Now().Add( writeDeadline := time.Now().Add(
p.pongWait, p.pongWait,
) )
_ = conn.SetWriteDeadline(writeDeadline) err := conn.WriteControl(
websocket.PingMessage,
err := conn.WriteMessage( []byte(PingContent),
websocket.PingMessage, nil, writeDeadline,
) )
if err != nil { if err != nil {
p.logger.Warnf("WS: could not "+ p.logger.Warnf("WS: could not "+

View File

@ -43,13 +43,16 @@ var (
} }
urlEnc = base64.URLEncoding urlEnc = base64.URLEncoding
webSocketDialer = &websocket.Dialer{ webSocketDialer = &websocket.Dialer{
HandshakeTimeout: 45 * time.Second, HandshakeTimeout: time.Second,
TLSClientConfig: insecureTransport.TLSClientConfig, TLSClientConfig: insecureTransport.TLSClientConfig,
} }
resultPattern = regexp.MustCompile("{\"result\":(.*)}") resultPattern = regexp.MustCompile("{\"result\":(.*)}")
closeMsg = websocket.FormatCloseMessage( closeMsg = websocket.FormatCloseMessage(
websocket.CloseNormalClosure, "done", websocket.CloseNormalClosure, "done",
) )
pingInterval = time.Millisecond * 200
pongWait = time.Millisecond * 50
) )
// testRestAPI tests that the most important features of the REST API work // testRestAPI tests that the most important features of the REST API work
@ -201,11 +204,18 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) {
}, { }, {
name: "websocket bi-directional subscription", name: "websocket bi-directional subscription",
run: wsTestCaseBiDirectionalSubscription, run: wsTestCaseBiDirectionalSubscription,
}, {
name: "websocket ping and pong timeout",
run: wsTestPingPongTimeout,
}} }}
// Make sure Alice allows all CORS origins. Bob will keep the default. // Make sure Alice allows all CORS origins. Bob will keep the default.
// We also make sure the ping/pong messages are sent very often, so we
// can test them without waiting half a minute.
net.Alice.Cfg.ExtraArgs = append( net.Alice.Cfg.ExtraArgs = append(
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"", net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
fmt.Sprintf("--ws-ping-interval=%s", pingInterval),
fmt.Sprintf("--ws-pong-wait=%s", pongWait),
) )
err := net.RestartNode(net.Alice, nil) err := net.RestartNode(net.Alice, nil)
if err != nil { if err != nil {
@ -432,7 +442,9 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
_ = conn.Close() _ = conn.Close()
}() }()
msgChan := make(chan *lnrpc.ChannelAcceptResponse) // Buffer the message channel to make sure we're always blocking on
// conn.ReadMessage() to allow the ping/pong mechanism to work.
msgChan := make(chan *lnrpc.ChannelAcceptResponse, 1)
errChan := make(chan error) errChan := make(chan error)
done := make(chan struct{}) done := make(chan struct{})
timeout := time.After(defaultTimeout) timeout := time.After(defaultTimeout)
@ -522,6 +534,111 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
close(done) close(done)
} }
// wsTestPingPongTimeout opens an invoice subscription over a WebSocket and
// makes sure messages can still be read from it after several ping/pong
// cycles have passed, i.e. that the read/write deadlines of the connection
// keep being extended while the connection is in use.
func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) {
	initialRequest := &lnrpc.InvoiceSubscription{
		AddIndex: 1, SettleIndex: 1,
	}
	url := "/v1/invoices/subscribe"

	// This time we send the macaroon in the special header
	// Sec-Websocket-Protocol which is the only header field available to
	// browsers when opening a WebSocket.
	mac, err := net.Alice.ReadMacaroon(
		net.Alice.AdminMacPath(), defaultTimeout,
	)
	require.NoError(ht.t, err, "read admin mac")
	macBytes, err := mac.MarshalBinary()
	require.NoError(ht.t, err, "marshal admin mac")

	customHeader := make(http.Header)
	customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf(
		"Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes),
	))
	conn, err := openWebSocket(
		net.Alice, url, "GET", initialRequest, customHeader,
	)
	require.NoError(ht.t, err, "websocket")
	defer func() {
		// Always close the connection, even if sending the close
		// message failed, and only then assert on the write error.
		err := conn.WriteMessage(websocket.CloseMessage, closeMsg)
		_ = conn.Close()
		require.NoError(ht.t, err)
	}()

	// We want to be able to read invoices for a long time, making sure we
	// can continue to read even after we've gone through several ping/pong
	// cycles.
	invoices := make(chan *lnrpc.Invoice, 1)
	errChan := make(chan error)
	done := make(chan struct{})

	// Closing done on every exit path (including an early abort caused by
	// a failed assertion) releases the reader goroutine below. This defer
	// runs before the connection is closed because it is registered later.
	defer close(done)

	// sendErr hands an error to the test body unless the test has
	// already finished, in which case nobody is listening anymore and the
	// error (typically the read failing on the closed connection) is
	// dropped.
	sendErr := func(err error) {
		select {
		case errChan <- err:
		case <-done:
		}
	}

	go func() {
		for {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				sendErr(err)
				return
			}

			// The chunked/streamed responses come wrapped in either
			// a {"result":{}} or {"error":{}} wrapper which we'll
			// get rid of here.
			msgStr := string(msg)
			if !strings.Contains(msgStr, "\"result\":") {
				sendErr(fmt.Errorf("invalid msg: %s", msgStr))
				return
			}
			msgStr = resultPattern.ReplaceAllString(msgStr, "${1}")

			// Make sure we can parse the unwrapped message into the
			// expected proto message.
			protoMsg := &lnrpc.Invoice{}
			err = jsonpb.UnmarshalString(msgStr, protoMsg)
			if err != nil {
				sendErr(err)
				return
			}

			// Deliver the invoice, or exit the loop once the test
			// has consumed all expected messages.
			select {
			case invoices <- protoMsg:
			case <-done:
				return
			}
		}
	}()

	// Let's create five invoices and wait for each of them to arrive. We
	// wait for at least one ping/pong cycle between two invoices.
	ctxb := context.Background()
	const numInvoices = 5
	const value = 123
	const memo = "websocket"
	for i := 0; i < numInvoices; i++ {
		_, err := net.Alice.AddInvoice(ctxb, &lnrpc.Invoice{
			Value: value,
			Memo:  memo,
		})
		require.NoError(ht.t, err)

		select {
		case streamMsg := <-invoices:
			require.Equal(ht.t, int64(value), streamMsg.Value)
			require.Equal(ht.t, memo, streamMsg.Memo)

		case err := <-errChan:
			require.Failf(
				ht.t, "error reading invoice",
				"websocket read failed: %v", err,
			)

		// Fail instead of hanging forever if the invoice never shows
		// up on the stream.
		case <-time.After(defaultTimeout):
			require.Fail(ht.t, "timeout waiting for invoice")
		}

		// Let's wait for at least a whole ping/pong cycle to happen, so
		// we can be sure the read/write deadlines are set correctly.
		// We double the pong wait just to add some extra margin.
		time.Sleep(pingInterval + 2*pongWait)
	}
}
// invokeGET calls the given URL with the GET method and appropriate macaroon // invokeGET calls the given URL with the GET method and appropriate macaroon
// header fields then tries to unmarshal the response into the given response // header fields then tries to unmarshal the response into the given response
// proto message. // proto message.