rpcperms: add unique request ID

This commit adds a unique request ID that is the same for each gRPC
request and response intercept message or each request/response message
of a gRPC stream.
This commit is contained in:
Oliver Gugger 2021-11-08 14:04:19 +01:00
parent 9a28a4a9ff
commit 54a25146f4
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
3 changed files with 54 additions and 17 deletions

View File

@ -545,6 +545,12 @@ func (h *middlewareHarness) interceptUnary(methodURI string,
res := respIntercept.GetResponse()
require.NotNil(h.t, res)
// We expect the request ID to be the same for the request intercept
// and the response intercept messages. But the message IDs must be
// different/unique.
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
// We need to accept the response as well.
h.sendAccept(respIntercept.MsgId, responseReplacement)
@ -593,6 +599,15 @@ func (h *middlewareHarness) interceptStream(methodURI string,
res := respIntercept.GetResponse()
require.NotNil(h.t, res)
// We expect the request ID to be the same for the auth intercept,
// request intercept and the response intercept messages. But the
// message IDs must be different/unique.
require.Equal(h.t, authIntercept.RequestId, respIntercept.RequestId)
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
require.NotEqual(h.t, authIntercept.MsgId, reqIntercept.MsgId)
require.NotEqual(h.t, authIntercept.MsgId, respIntercept.MsgId)
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
// We need to accept the response as well.
h.sendAccept(respIntercept.MsgId, responseReplacement)

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btclog"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -134,6 +135,12 @@ var (
// | edited gRPC request to client
// v
type InterceptorChain struct {
// lastRequestID is the ID of the last gRPC request or stream that was
// intercepted by the middleware interceptor.
//
// NOTE: Must be used atomically!
lastRequestID uint64
// Required by the grpc-gateway/v2 library for forward compatibility.
lnrpc.UnimplementedStateServer
@ -790,7 +797,8 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
return nil, err
}
err = r.acceptRequest(msg)
requestID := atomic.AddUint64(&r.lastRequestID, 1)
err = r.acceptRequest(requestID, msg)
if err != nil {
return nil, err
}
@ -800,7 +808,9 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
return resp, respErr
}
return r.interceptResponse(ctx, false, info.FullMethod, resp)
return r.interceptResponse(
ctx, requestID, false, info.FullMethod, resp,
)
}
}
@ -845,13 +855,15 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
return err
}
err = r.acceptRequest(msg)
requestID := atomic.AddUint64(&r.lastRequestID, 1)
err = r.acceptRequest(requestID, msg)
if err != nil {
return err
}
wrappedSS := &serverStreamWrapper{
ServerStream: ss,
requestID: requestID,
fullMethod: info.FullMethod,
interceptor: r,
}
@ -900,7 +912,9 @@ func (r *InterceptorChain) middlewareRegistered() bool {
// registered for it. This means either a middleware has requested read-only
// access or the request actually has a macaroon which a caveat the middleware
// registered for.
func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
func (r *InterceptorChain) acceptRequest(requestID uint64,
msg *InterceptionRequest) error {
r.RLock()
defer r.RUnlock()
@ -915,7 +929,7 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
continue
}
resp, err := middleware.intercept(msg)
resp, err := middleware.intercept(requestID, msg)
// Error during interception itself.
if err != nil {
@ -936,7 +950,8 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
// overwrite/replace the response, this needs to be handled differently than the
// request/auth path above.
func (r *InterceptorChain) interceptResponse(ctx context.Context,
isStream bool, fullMethod string, m interface{}) (interface{}, error) {
requestID uint64, isStream bool, fullMethod string,
m interface{}) (interface{}, error) {
r.RLock()
defer r.RUnlock()
@ -960,7 +975,7 @@ func (r *InterceptorChain) interceptResponse(ctx context.Context,
continue
}
resp, err := middleware.intercept(msg)
resp, err := middleware.intercept(requestID, msg)
// Error during interception itself.
if err != nil {
@ -988,6 +1003,8 @@ type serverStreamWrapper struct {
// ServerStream is the stream that's being wrapped.
grpc.ServerStream
requestID uint64
fullMethod string
interceptor *InterceptorChain
@ -997,7 +1014,7 @@ type serverStreamWrapper struct {
// intercept streaming RPC responses.
func (w *serverStreamWrapper) SendMsg(m interface{}) error {
newMsg, err := w.interceptor.interceptResponse(
w.ServerStream.Context(), true, w.fullMethod, m,
w.ServerStream.Context(), w.requestID, true, w.fullMethod, m,
)
if err != nil {
return err
@ -1022,5 +1039,5 @@ func (w *serverStreamWrapper) RecvMsg(m interface{}) error {
return err
}
return w.interceptor.acceptRequest(msg)
return w.interceptor.acceptRequest(w.requestID, msg)
}

View File

@ -107,12 +107,13 @@ func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
// feedback on it and sending the feedback to the appropriate channel. All steps
// are guarded by the configured timeout to make sure a middleware cannot slow
// down requests too much.
func (h *MiddlewareHandler) intercept(
func (h *MiddlewareHandler) intercept(requestID uint64,
req *InterceptionRequest) (*interceptResponse, error) {
respChan := make(chan *interceptResponse, 1)
newRequest := &interceptRequest{
requestID: requestID,
request: req,
response: respChan,
}
@ -233,7 +234,9 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
req := newRequest.request
interceptRequests[msgID] = newRequest
interceptReq, err := req.ToRPC(msgID)
interceptReq, err := req.ToRPC(
newRequest.requestID, msgID,
)
if err != nil {
return err
}
@ -447,10 +450,11 @@ func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
}
// ToRPC converts the interception request to its RPC counterpart.
func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
error) {
func (r *InterceptionRequest) ToRPC(requestID,
msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
rpcRequest := &lnrpc.RPCMiddlewareRequest{
RequestId: requestID,
MsgId: msgID,
RawMacaroon: r.RawMacaroon,
CustomCaveatCondition: r.CustomCaveatCondition,
@ -495,6 +499,7 @@ func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
// out to a middleware and the response that is eventually sent back by the
// middleware.
type interceptRequest struct {
requestID uint64
request *InterceptionRequest
response chan *interceptResponse
}