mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
Merge pull request #5683 from guggero/websocket-write-deadline
lnrpc: Fix WebSocket write deadline not being extended
This commit is contained in:
commit
5e6532594c
@ -324,6 +324,12 @@ you.
|
||||
* [Fix crash with empty AMP or MPP record in
|
||||
invoice](https://github.com/lightningnetwork/lnd/pull/5743).
|
||||
|
||||
* The underlying gRPC connection of a WebSocket is now [properly closed when the
|
||||
WebSocket end of a connection is
|
||||
closed](https://github.com/lightningnetwork/lnd/pull/5683). A bug with the
|
||||
write deadline that caused connections to suddenly break was also fixed in the
|
||||
same PR.
|
||||
|
||||
## Documentation
|
||||
|
||||
The [code contribution guidelines have been updated to mention the new
|
||||
|
@ -33,6 +33,11 @@ const (
|
||||
// additional header field and its value. We use the plus symbol because
|
||||
// the default delimiters aren't allowed in the protocol names.
|
||||
WebSocketProtocolDelimiter = "+"
|
||||
|
||||
// PingContent is the content of the ping message we send out. This is
|
||||
// an arbitrary non-empty message that has no deeper meaning but should
|
||||
// be sent back by the client in the pong message.
|
||||
PingContent = "are you there?"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -147,12 +152,12 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
ctx, cancelFn := context.WithCancel(r.Context())
|
||||
defer cancelFn()
|
||||
|
||||
requestForwarder := newRequestForwardingReader()
|
||||
request, err := http.NewRequestWithContext(
|
||||
r.Context(), r.Method, r.URL.String(), requestForwarder,
|
||||
ctx, r.Method, r.URL.String(), requestForwarder,
|
||||
)
|
||||
if err != nil {
|
||||
p.logger.Errorf("WS: error preparing request:", err)
|
||||
@ -181,6 +186,7 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
responseForwarder.Close()
|
||||
requestForwarder.CloseWriter()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
@ -188,9 +194,19 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
p.backend.ServeHTTP(responseForwarder, request)
|
||||
}()
|
||||
|
||||
// Read loop: Take messages from websocket and write to http request.
|
||||
// Read loop: Take messages from websocket and write them to the payload
|
||||
// channel. This needs to be its own goroutine because for non-client
|
||||
// streaming RPCs, the requestForwarder.Write() in the second goroutine
|
||||
// will block until the request has fully completed. But for the ping/
|
||||
// pong handler to work, we need to have an active call to
|
||||
// conn.ReadMessage() going on. So we make sure we have such an active
|
||||
// call by starting a second read as soon as the first one has
|
||||
// completed.
|
||||
payloadChannel := make(chan []byte, 1)
|
||||
go func() {
|
||||
defer cancelFn()
|
||||
defer close(payloadChannel)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -209,6 +225,34 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case payloadChannel <- payload:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward loop: Take messages from the incoming payload channel and
|
||||
// write them to the http request.
|
||||
go func() {
|
||||
defer cancelFn()
|
||||
for {
|
||||
var payload []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newPayload, more := <-payloadChannel:
|
||||
if !more {
|
||||
p.logger.Infof("WS: incoming payload " +
|
||||
"chan closed")
|
||||
return
|
||||
}
|
||||
|
||||
payload = newPayload
|
||||
}
|
||||
|
||||
_, err = requestForwarder.Write(payload)
|
||||
if err != nil {
|
||||
p.logger.Errorf("WS: error writing message "+
|
||||
@ -238,12 +282,17 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
|
||||
// Whenever a pong message comes in, we extend the deadline
|
||||
// until the next read is expected by the interval plus pong
|
||||
// wait time.
|
||||
// wait time. Since we can never _reach_ any of the deadlines,
|
||||
// we also have to advance the deadline for the next expected
|
||||
// write to happen, in case the next thing we actually write is
|
||||
// the next ping.
|
||||
conn.SetPongHandler(func(appData string) error {
|
||||
nextDeadline := time.Now().Add(
|
||||
p.pingInterval + p.pongWait,
|
||||
)
|
||||
_ = conn.SetReadDeadline(nextDeadline)
|
||||
_ = conn.SetWriteDeadline(nextDeadline)
|
||||
|
||||
return nil
|
||||
})
|
||||
go func() {
|
||||
@ -263,10 +312,10 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
||||
writeDeadline := time.Now().Add(
|
||||
p.pongWait,
|
||||
)
|
||||
_ = conn.SetWriteDeadline(writeDeadline)
|
||||
|
||||
err := conn.WriteMessage(
|
||||
websocket.PingMessage, nil,
|
||||
err := conn.WriteControl(
|
||||
websocket.PingMessage,
|
||||
[]byte(PingContent),
|
||||
writeDeadline,
|
||||
)
|
||||
if err != nil {
|
||||
p.logger.Warnf("WS: could not "+
|
||||
|
@ -43,13 +43,16 @@ var (
|
||||
}
|
||||
urlEnc = base64.URLEncoding
|
||||
webSocketDialer = &websocket.Dialer{
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
HandshakeTimeout: time.Second,
|
||||
TLSClientConfig: insecureTransport.TLSClientConfig,
|
||||
}
|
||||
resultPattern = regexp.MustCompile("{\"result\":(.*)}")
|
||||
closeMsg = websocket.FormatCloseMessage(
|
||||
websocket.CloseNormalClosure, "done",
|
||||
)
|
||||
|
||||
pingInterval = time.Millisecond * 200
|
||||
pongWait = time.Millisecond * 50
|
||||
)
|
||||
|
||||
// testRestAPI tests that the most important features of the REST API work
|
||||
@ -201,11 +204,18 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) {
|
||||
}, {
|
||||
name: "websocket bi-directional subscription",
|
||||
run: wsTestCaseBiDirectionalSubscription,
|
||||
}, {
|
||||
name: "websocket ping and pong timeout",
|
||||
run: wsTestPingPongTimeout,
|
||||
}}
|
||||
|
||||
// Make sure Alice allows all CORS origins. Bob will keep the default.
|
||||
// We also make sure the ping/pong messages are sent very often, so we
|
||||
// can test them without waiting half a minute.
|
||||
net.Alice.Cfg.ExtraArgs = append(
|
||||
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
|
||||
fmt.Sprintf("--ws-ping-interval=%s", pingInterval),
|
||||
fmt.Sprintf("--ws-pong-wait=%s", pongWait),
|
||||
)
|
||||
err := net.RestartNode(net.Alice, nil)
|
||||
if err != nil {
|
||||
@ -432,7 +442,9 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
msgChan := make(chan *lnrpc.ChannelAcceptResponse)
|
||||
// Buffer the message channel to make sure we're always blocking on
|
||||
// conn.ReadMessage() to allow the ping/pong mechanism to work.
|
||||
msgChan := make(chan *lnrpc.ChannelAcceptResponse, 1)
|
||||
errChan := make(chan error)
|
||||
done := make(chan struct{})
|
||||
timeout := time.After(defaultTimeout)
|
||||
@ -522,6 +534,111 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
|
||||
close(done)
|
||||
}
|
||||
|
||||
func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) {
|
||||
initialRequest := &lnrpc.InvoiceSubscription{
|
||||
AddIndex: 1, SettleIndex: 1,
|
||||
}
|
||||
url := "/v1/invoices/subscribe"
|
||||
|
||||
// This time we send the macaroon in the special header
|
||||
// Sec-Websocket-Protocol which is the only header field available to
|
||||
// browsers when opening a WebSocket.
|
||||
mac, err := net.Alice.ReadMacaroon(
|
||||
net.Alice.AdminMacPath(), defaultTimeout,
|
||||
)
|
||||
require.NoError(ht.t, err, "read admin mac")
|
||||
macBytes, err := mac.MarshalBinary()
|
||||
require.NoError(ht.t, err, "marshal admin mac")
|
||||
|
||||
customHeader := make(http.Header)
|
||||
customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf(
|
||||
"Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes),
|
||||
))
|
||||
conn, err := openWebSocket(
|
||||
net.Alice, url, "GET", initialRequest, customHeader,
|
||||
)
|
||||
require.Nil(ht.t, err, "websocket")
|
||||
defer func() {
|
||||
err := conn.WriteMessage(websocket.CloseMessage, closeMsg)
|
||||
require.NoError(ht.t, err)
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// We want to be able to read invoices for a long time, making sure we
|
||||
// can continue to read even after we've gone through several ping/pong
|
||||
// cycles.
|
||||
invoices := make(chan *lnrpc.Invoice, 1)
|
||||
errors := make(chan error)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
// The chunked/streamed responses come wrapped in either
|
||||
// a {"result":{}} or {"error":{}} wrapper which we'll
|
||||
// get rid of here.
|
||||
msgStr := string(msg)
|
||||
if !strings.Contains(msgStr, "\"result\":") {
|
||||
errors <- fmt.Errorf("invalid msg: %s", msgStr)
|
||||
return
|
||||
}
|
||||
msgStr = resultPattern.ReplaceAllString(msgStr, "${1}")
|
||||
|
||||
// Make sure we can parse the unwrapped message into the
|
||||
// expected proto message.
|
||||
protoMsg := &lnrpc.Invoice{}
|
||||
err = jsonpb.UnmarshalString(msgStr, protoMsg)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
invoices <- protoMsg
|
||||
|
||||
// Make sure we exit the loop once we've sent through
|
||||
// all expected test messages.
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Let's create five invoices and wait for them to arrive. We'll wait
|
||||
// for at least one ping/pong cycle between each invoice.
|
||||
ctxb := context.Background()
|
||||
const numInvoices = 5
|
||||
const value = 123
|
||||
const memo = "websocket"
|
||||
for i := 0; i < numInvoices; i++ {
|
||||
_, err := net.Alice.AddInvoice(ctxb, &lnrpc.Invoice{
|
||||
Value: value,
|
||||
Memo: memo,
|
||||
})
|
||||
require.NoError(ht.t, err)
|
||||
|
||||
select {
|
||||
case streamMsg := <-invoices:
|
||||
require.Equal(ht.t, int64(value), streamMsg.Value)
|
||||
require.Equal(ht.t, memo, streamMsg.Memo)
|
||||
|
||||
case err := <-errors:
|
||||
require.Fail(ht.t, "Error reading invoice: %v", err)
|
||||
}
|
||||
|
||||
// Let's wait for at least a whole ping/pong cycle to happen, so
|
||||
// we can be sure the read/write deadlines are set correctly.
|
||||
// We double the pong wait just to add some extra margin.
|
||||
time.Sleep(pingInterval + 2*pongWait)
|
||||
}
|
||||
close(done)
|
||||
}
|
||||
|
||||
// invokeGET calls the given URL with the GET method and appropriate macaroon
|
||||
// header fields then tries to unmarshal the response into the given response
|
||||
// proto message.
|
||||
|
Loading…
Reference in New Issue
Block a user