From 54a25146f4c532e8a7647adfddb77d1415335a37 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Mon, 8 Nov 2021 14:04:19 +0100 Subject: [PATCH] rpcperms: add unique request ID This commit adds a unique request ID that is the same for each gRPC request and response intercept message or each request/response message of a gRPC stream. --- .../lnd_rpc_middleware_interceptor_test.go | 15 ++++++++ rpcperms/interceptor.go | 35 ++++++++++++++----- rpcperms/middleware_handler.go | 21 ++++++----- 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/lntest/itest/lnd_rpc_middleware_interceptor_test.go b/lntest/itest/lnd_rpc_middleware_interceptor_test.go index b894edd90..fd9e5e95b 100644 --- a/lntest/itest/lnd_rpc_middleware_interceptor_test.go +++ b/lntest/itest/lnd_rpc_middleware_interceptor_test.go @@ -545,6 +545,12 @@ func (h *middlewareHarness) interceptUnary(methodURI string, res := respIntercept.GetResponse() require.NotNil(h.t, res) + // We expect the request ID to be the same for the request intercept + // and the response intercept messages. But the message IDs must be + // different/unique. + require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId) + require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId) + // We need to accept the response as well. h.sendAccept(respIntercept.MsgId, responseReplacement) @@ -593,6 +599,15 @@ func (h *middlewareHarness) interceptStream(methodURI string, res := respIntercept.GetResponse() require.NotNil(h.t, res) + // We expect the request ID to be the same for the auth intercept, + // request intercept and the response intercept messages. But the + // message IDs must be different/unique. + require.Equal(h.t, authIntercept.RequestId, respIntercept.RequestId) + require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId) + require.NotEqual(h.t, authIntercept.MsgId, reqIntercept.MsgId) + require.NotEqual(h.t, authIntercept.MsgId, respIntercept.MsgId) + require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId) + // We need to accept the response as well. h.sendAccept(respIntercept.MsgId, responseReplacement) diff --git a/rpcperms/interceptor.go b/rpcperms/interceptor.go index 97e752b17..7be2f0ea0 100644 --- a/rpcperms/interceptor.go +++ b/rpcperms/interceptor.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "github.com/btcsuite/btclog" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" @@ -134,6 +135,12 @@ var ( // | edited gRPC request to client // v type InterceptorChain struct { + // lastRequestID is the ID of the last gRPC request or stream that was + // intercepted by the middleware interceptor. + // + // NOTE: Must be used atomically! + lastRequestID uint64 + // Required by the grpc-gateway/v2 library for forward compatibility. lnrpc.UnimplementedStateServer @@ -790,7 +797,8 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn return nil, err } - err = r.acceptRequest(msg) + requestID := atomic.AddUint64(&r.lastRequestID, 1) + err = r.acceptRequest(requestID, msg) if err != nil { return nil, err } @@ -800,7 +808,9 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn return resp, respErr } - return r.interceptResponse(ctx, false, info.FullMethod, resp) + return r.interceptResponse( + ctx, requestID, false, info.FullMethod, resp, + ) } } @@ -845,13 +855,15 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer return err } - err = r.acceptRequest(msg) + requestID := atomic.AddUint64(&r.lastRequestID, 1) + err = r.acceptRequest(requestID, msg) if err != nil { return err } wrappedSS := &serverStreamWrapper{ ServerStream: ss, + requestID: requestID, fullMethod: info.FullMethod, interceptor: r, } @@ -900,7 +912,9 @@ func (r *InterceptorChain) middlewareRegistered() bool { // registered for it. This means either a middleware has requested read-only // access or the request actually has a macaroon which a caveat the middleware // registered for. -func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error { +func (r *InterceptorChain) acceptRequest(requestID uint64, + msg *InterceptionRequest) error { + r.RLock() defer r.RUnlock() @@ -915,7 +929,7 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error { continue } - resp, err := middleware.intercept(msg) + resp, err := middleware.intercept(requestID, msg) // Error during interception itself. if err != nil { @@ -936,7 +950,8 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error { // overwrite/replace the response, this needs to be handled differently than the // request/auth path above. func (r *InterceptorChain) interceptResponse(ctx context.Context, - isStream bool, fullMethod string, m interface{}) (interface{}, error) { + requestID uint64, isStream bool, fullMethod string, + m interface{}) (interface{}, error) { r.RLock() defer r.RUnlock() @@ -960,7 +975,7 @@ func (r *InterceptorChain) interceptResponse(ctx context.Context, continue } - resp, err := middleware.intercept(msg) + resp, err := middleware.intercept(requestID, msg) // Error during interception itself. if err != nil { @@ -988,6 +1003,8 @@ type serverStreamWrapper struct { // ServerStream is the stream that's being wrapped. grpc.ServerStream + requestID uint64 + fullMethod string interceptor *InterceptorChain @@ -997,7 +1014,7 @@ type serverStreamWrapper struct { // intercept streaming RPC responses. func (w *serverStreamWrapper) SendMsg(m interface{}) error { newMsg, err := w.interceptor.interceptResponse( - w.ServerStream.Context(), true, w.fullMethod, m, + w.ServerStream.Context(), w.requestID, true, w.fullMethod, m, ) if err != nil { return err @@ -1022,5 +1039,5 @@ func (w *serverStreamWrapper) RecvMsg(m interface{}) error { return err } - return w.interceptor.acceptRequest(msg) + return w.interceptor.acceptRequest(w.requestID, msg) } diff --git a/rpcperms/middleware_handler.go b/rpcperms/middleware_handler.go index 74f171cca..db122619d 100644 --- a/rpcperms/middleware_handler.go +++ b/rpcperms/middleware_handler.go @@ -107,14 +107,15 @@ func NewMiddlewareHandler(name, customCaveatName string, readOnly bool, // feedback on it and sending the feedback to the appropriate channel. All steps // are guarded by the configured timeout to make sure a middleware cannot slow // down requests too much. -func (h *MiddlewareHandler) intercept( +func (h *MiddlewareHandler) intercept(requestID uint64, req *InterceptionRequest) (*interceptResponse, error) { respChan := make(chan *interceptResponse, 1) newRequest := &interceptRequest{ - request: req, - response: respChan, + requestID: requestID, + request: req, + response: respChan, } // timeout is the time after which intercept requests expire. @@ -233,7 +234,9 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, req := newRequest.request interceptRequests[msgID] = newRequest - interceptReq, err := req.ToRPC(msgID) + interceptReq, err := req.ToRPC( + newRequest.requestID, msgID, + ) if err != nil { return err } @@ -447,10 +450,11 @@ func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte, } // ToRPC converts the interception request to its RPC counterpart. -func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest, - error) { +func (r *InterceptionRequest) ToRPC(requestID, + msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) { rpcRequest := &lnrpc.RPCMiddlewareRequest{ + RequestId: requestID, MsgId: msgID, RawMacaroon: r.RawMacaroon, CustomCaveatCondition: r.CustomCaveatCondition, @@ -495,8 +499,9 @@ func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest, // out to a middleware and the response that is eventually sent back by the // middleware. type interceptRequest struct { - request *InterceptionRequest - response chan *interceptResponse + requestID uint64 + request *InterceptionRequest + response chan *interceptResponse } // interceptResponse is the response a middleware sends back for each