mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
Merge pull request #5683 from guggero/websocket-write-deadline
lnrpc: Fix WebSocket write deadline not being extended
This commit is contained in:
commit
5e6532594c
@ -324,6 +324,12 @@ you.
|
|||||||
* [Fix crash with empty AMP or MPP record in
|
* [Fix crash with empty AMP or MPP record in
|
||||||
invoice](https://github.com/lightningnetwork/lnd/pull/5743).
|
invoice](https://github.com/lightningnetwork/lnd/pull/5743).
|
||||||
|
|
||||||
|
* The underlying gRPC connection of a WebSocket is now [properly closed when the
|
||||||
|
WebSocket end of a connection is
|
||||||
|
closed](https://github.com/lightningnetwork/lnd/pull/5683). A bug with the
|
||||||
|
write deadline that caused connections to suddenly break was also fixed in the
|
||||||
|
same PR.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
The [code contribution guidelines have been updated to mention the new
|
The [code contribution guidelines have been updated to mention the new
|
||||||
|
@ -33,6 +33,11 @@ const (
|
|||||||
// additional header field and its value. We use the plus symbol because
|
// additional header field and its value. We use the plus symbol because
|
||||||
// the default delimiters aren't allowed in the protocol names.
|
// the default delimiters aren't allowed in the protocol names.
|
||||||
WebSocketProtocolDelimiter = "+"
|
WebSocketProtocolDelimiter = "+"
|
||||||
|
|
||||||
|
// PingContent is the content of the ping message we send out. This is
|
||||||
|
// an arbitrary non-empty message that has no deeper meaning but should
|
||||||
|
// be sent back by the client in the pong message.
|
||||||
|
PingContent = "are you there?"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -147,12 +152,12 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
ctx, cancelFn := context.WithCancel(r.Context())
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
|
|
||||||
requestForwarder := newRequestForwardingReader()
|
requestForwarder := newRequestForwardingReader()
|
||||||
request, err := http.NewRequestWithContext(
|
request, err := http.NewRequestWithContext(
|
||||||
r.Context(), r.Method, r.URL.String(), requestForwarder,
|
ctx, r.Method, r.URL.String(), requestForwarder,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logger.Errorf("WS: error preparing request:", err)
|
p.logger.Errorf("WS: error preparing request:", err)
|
||||||
@ -181,6 +186,7 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
responseForwarder.Close()
|
responseForwarder.Close()
|
||||||
|
requestForwarder.CloseWriter()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -188,9 +194,19 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
p.backend.ServeHTTP(responseForwarder, request)
|
p.backend.ServeHTTP(responseForwarder, request)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Read loop: Take messages from websocket and write to http request.
|
// Read loop: Take messages from websocket and write them to the payload
|
||||||
|
// channel. This needs to be its own goroutine because for non-client
|
||||||
|
// streaming RPCs, the requestForwarder.Write() in the second goroutine
|
||||||
|
// will block until the request has fully completed. But for the ping/
|
||||||
|
// pong handler to work, we need to have an active call to
|
||||||
|
// conn.ReadMessage() going on. So we make sure we have such an active
|
||||||
|
// call by starting a second read as soon as the first one has
|
||||||
|
// completed.
|
||||||
|
payloadChannel := make(chan []byte, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
|
defer close(payloadChannel)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -209,6 +225,34 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
err)
|
err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case payloadChannel <- payload:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Forward loop: Take messages from the incoming payload channel and
|
||||||
|
// write them to the http request.
|
||||||
|
go func() {
|
||||||
|
defer cancelFn()
|
||||||
|
for {
|
||||||
|
var payload []byte
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case newPayload, more := <-payloadChannel:
|
||||||
|
if !more {
|
||||||
|
p.logger.Infof("WS: incoming payload " +
|
||||||
|
"chan closed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = newPayload
|
||||||
|
}
|
||||||
|
|
||||||
_, err = requestForwarder.Write(payload)
|
_, err = requestForwarder.Write(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logger.Errorf("WS: error writing message "+
|
p.logger.Errorf("WS: error writing message "+
|
||||||
@ -238,12 +282,17 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
|
|
||||||
// Whenever a pong message comes in, we extend the deadline
|
// Whenever a pong message comes in, we extend the deadline
|
||||||
// until the next read is expected by the interval plus pong
|
// until the next read is expected by the interval plus pong
|
||||||
// wait time.
|
// wait time. Since we can never _reach_ any of the deadlines,
|
||||||
|
// we also have to advance the deadline for the next expected
|
||||||
|
// write to happen, in case the next thing we actually write is
|
||||||
|
// the next ping.
|
||||||
conn.SetPongHandler(func(appData string) error {
|
conn.SetPongHandler(func(appData string) error {
|
||||||
nextDeadline := time.Now().Add(
|
nextDeadline := time.Now().Add(
|
||||||
p.pingInterval + p.pongWait,
|
p.pingInterval + p.pongWait,
|
||||||
)
|
)
|
||||||
_ = conn.SetReadDeadline(nextDeadline)
|
_ = conn.SetReadDeadline(nextDeadline)
|
||||||
|
_ = conn.SetWriteDeadline(nextDeadline)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
go func() {
|
go func() {
|
||||||
@ -263,10 +312,10 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
|
|||||||
writeDeadline := time.Now().Add(
|
writeDeadline := time.Now().Add(
|
||||||
p.pongWait,
|
p.pongWait,
|
||||||
)
|
)
|
||||||
_ = conn.SetWriteDeadline(writeDeadline)
|
err := conn.WriteControl(
|
||||||
|
websocket.PingMessage,
|
||||||
err := conn.WriteMessage(
|
[]byte(PingContent),
|
||||||
websocket.PingMessage, nil,
|
writeDeadline,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logger.Warnf("WS: could not "+
|
p.logger.Warnf("WS: could not "+
|
||||||
|
@ -43,13 +43,16 @@ var (
|
|||||||
}
|
}
|
||||||
urlEnc = base64.URLEncoding
|
urlEnc = base64.URLEncoding
|
||||||
webSocketDialer = &websocket.Dialer{
|
webSocketDialer = &websocket.Dialer{
|
||||||
HandshakeTimeout: 45 * time.Second,
|
HandshakeTimeout: time.Second,
|
||||||
TLSClientConfig: insecureTransport.TLSClientConfig,
|
TLSClientConfig: insecureTransport.TLSClientConfig,
|
||||||
}
|
}
|
||||||
resultPattern = regexp.MustCompile("{\"result\":(.*)}")
|
resultPattern = regexp.MustCompile("{\"result\":(.*)}")
|
||||||
closeMsg = websocket.FormatCloseMessage(
|
closeMsg = websocket.FormatCloseMessage(
|
||||||
websocket.CloseNormalClosure, "done",
|
websocket.CloseNormalClosure, "done",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pingInterval = time.Millisecond * 200
|
||||||
|
pongWait = time.Millisecond * 50
|
||||||
)
|
)
|
||||||
|
|
||||||
// testRestAPI tests that the most important features of the REST API work
|
// testRestAPI tests that the most important features of the REST API work
|
||||||
@ -201,11 +204,18 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) {
|
|||||||
}, {
|
}, {
|
||||||
name: "websocket bi-directional subscription",
|
name: "websocket bi-directional subscription",
|
||||||
run: wsTestCaseBiDirectionalSubscription,
|
run: wsTestCaseBiDirectionalSubscription,
|
||||||
|
}, {
|
||||||
|
name: "websocket ping and pong timeout",
|
||||||
|
run: wsTestPingPongTimeout,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// Make sure Alice allows all CORS origins. Bob will keep the default.
|
// Make sure Alice allows all CORS origins. Bob will keep the default.
|
||||||
|
// We also make sure the ping/pong messages are sent very often, so we
|
||||||
|
// can test them without waiting half a minute.
|
||||||
net.Alice.Cfg.ExtraArgs = append(
|
net.Alice.Cfg.ExtraArgs = append(
|
||||||
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
|
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
|
||||||
|
fmt.Sprintf("--ws-ping-interval=%s", pingInterval),
|
||||||
|
fmt.Sprintf("--ws-pong-wait=%s", pongWait),
|
||||||
)
|
)
|
||||||
err := net.RestartNode(net.Alice, nil)
|
err := net.RestartNode(net.Alice, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -432,7 +442,9 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
|
|||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
msgChan := make(chan *lnrpc.ChannelAcceptResponse)
|
// Buffer the message channel to make sure we're always blocking on
|
||||||
|
// conn.ReadMessage() to allow the ping/pong mechanism to work.
|
||||||
|
msgChan := make(chan *lnrpc.ChannelAcceptResponse, 1)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
timeout := time.After(defaultTimeout)
|
timeout := time.After(defaultTimeout)
|
||||||
@ -522,6 +534,111 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest,
|
|||||||
close(done)
|
close(done)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) {
|
||||||
|
initialRequest := &lnrpc.InvoiceSubscription{
|
||||||
|
AddIndex: 1, SettleIndex: 1,
|
||||||
|
}
|
||||||
|
url := "/v1/invoices/subscribe"
|
||||||
|
|
||||||
|
// This time we send the macaroon in the special header
|
||||||
|
// Sec-Websocket-Protocol which is the only header field available to
|
||||||
|
// browsers when opening a WebSocket.
|
||||||
|
mac, err := net.Alice.ReadMacaroon(
|
||||||
|
net.Alice.AdminMacPath(), defaultTimeout,
|
||||||
|
)
|
||||||
|
require.NoError(ht.t, err, "read admin mac")
|
||||||
|
macBytes, err := mac.MarshalBinary()
|
||||||
|
require.NoError(ht.t, err, "marshal admin mac")
|
||||||
|
|
||||||
|
customHeader := make(http.Header)
|
||||||
|
customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf(
|
||||||
|
"Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes),
|
||||||
|
))
|
||||||
|
conn, err := openWebSocket(
|
||||||
|
net.Alice, url, "GET", initialRequest, customHeader,
|
||||||
|
)
|
||||||
|
require.Nil(ht.t, err, "websocket")
|
||||||
|
defer func() {
|
||||||
|
err := conn.WriteMessage(websocket.CloseMessage, closeMsg)
|
||||||
|
require.NoError(ht.t, err)
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// We want to be able to read invoices for a long time, making sure we
|
||||||
|
// can continue to read even after we've gone through several ping/pong
|
||||||
|
// cycles.
|
||||||
|
invoices := make(chan *lnrpc.Invoice, 1)
|
||||||
|
errors := make(chan error)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// The chunked/streamed responses come wrapped in either
|
||||||
|
// a {"result":{}} or {"error":{}} wrapper which we'll
|
||||||
|
// get rid of here.
|
||||||
|
msgStr := string(msg)
|
||||||
|
if !strings.Contains(msgStr, "\"result\":") {
|
||||||
|
errors <- fmt.Errorf("invalid msg: %s", msgStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msgStr = resultPattern.ReplaceAllString(msgStr, "${1}")
|
||||||
|
|
||||||
|
// Make sure we can parse the unwrapped message into the
|
||||||
|
// expected proto message.
|
||||||
|
protoMsg := &lnrpc.Invoice{}
|
||||||
|
err = jsonpb.UnmarshalString(msgStr, protoMsg)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
invoices <- protoMsg
|
||||||
|
|
||||||
|
// Make sure we exit the loop once we've sent through
|
||||||
|
// all expected test messages.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Let's create five invoices and wait for them to arrive. We'll wait
|
||||||
|
// for at least one ping/pong cycle between each invoice.
|
||||||
|
ctxb := context.Background()
|
||||||
|
const numInvoices = 5
|
||||||
|
const value = 123
|
||||||
|
const memo = "websocket"
|
||||||
|
for i := 0; i < numInvoices; i++ {
|
||||||
|
_, err := net.Alice.AddInvoice(ctxb, &lnrpc.Invoice{
|
||||||
|
Value: value,
|
||||||
|
Memo: memo,
|
||||||
|
})
|
||||||
|
require.NoError(ht.t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case streamMsg := <-invoices:
|
||||||
|
require.Equal(ht.t, int64(value), streamMsg.Value)
|
||||||
|
require.Equal(ht.t, memo, streamMsg.Memo)
|
||||||
|
|
||||||
|
case err := <-errors:
|
||||||
|
require.Fail(ht.t, "Error reading invoice: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's wait for at least a whole ping/pong cycle to happen, so
|
||||||
|
// we can be sure the read/write deadlines are set correctly.
|
||||||
|
// We double the pong wait just to add some extra margin.
|
||||||
|
time.Sleep(pingInterval + 2*pongWait)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
|
||||||
// invokeGET calls the given URL with the GET method and appropriate macaroon
|
// invokeGET calls the given URL with the GET method and appropriate macaroon
|
||||||
// header fields then tries to unmarshal the response into the given response
|
// header fields then tries to unmarshal the response into the given response
|
||||||
// proto message.
|
// proto message.
|
||||||
|
Loading…
Reference in New Issue
Block a user