mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
rpcperms: add unique request ID
This commit adds a unique request ID that is the same for each gRPC request and response intercept message or each request/response message of a gRPC stream.
This commit is contained in:
parent
9a28a4a9ff
commit
54a25146f4
@ -545,6 +545,12 @@ func (h *middlewareHarness) interceptUnary(methodURI string,
|
||||
res := respIntercept.GetResponse()
|
||||
require.NotNil(h.t, res)
|
||||
|
||||
// We expect the request ID to be the same for the request intercept
|
||||
// and the response intercept messages. But the message IDs must be
|
||||
// different/unique.
|
||||
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
|
||||
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
|
||||
|
||||
// We need to accept the response as well.
|
||||
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
||||
|
||||
@ -593,6 +599,15 @@ func (h *middlewareHarness) interceptStream(methodURI string,
|
||||
res := respIntercept.GetResponse()
|
||||
require.NotNil(h.t, res)
|
||||
|
||||
// We expect the request ID to be the same for the auth intercept,
|
||||
// request intercept and the response intercept messages. But the
|
||||
// message IDs must be different/unique.
|
||||
require.Equal(h.t, authIntercept.RequestId, respIntercept.RequestId)
|
||||
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
|
||||
require.NotEqual(h.t, authIntercept.MsgId, reqIntercept.MsgId)
|
||||
require.NotEqual(h.t, authIntercept.MsgId, respIntercept.MsgId)
|
||||
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
|
||||
|
||||
// We need to accept the response as well.
|
||||
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/btcsuite/btclog"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
@ -134,6 +135,12 @@ var (
|
||||
// | edited gRPC request to client
|
||||
// v
|
||||
type InterceptorChain struct {
|
||||
// lastRequestID is the ID of the last gRPC request or stream that was
|
||||
// intercepted by the middleware interceptor.
|
||||
//
|
||||
// NOTE: Must be used atomically!
|
||||
lastRequestID uint64
|
||||
|
||||
// Required by the grpc-gateway/v2 library for forward compatibility.
|
||||
lnrpc.UnimplementedStateServer
|
||||
|
||||
@ -790,7 +797,8 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = r.acceptRequest(msg)
|
||||
requestID := atomic.AddUint64(&r.lastRequestID, 1)
|
||||
err = r.acceptRequest(requestID, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -800,7 +808,9 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
||||
return resp, respErr
|
||||
}
|
||||
|
||||
return r.interceptResponse(ctx, false, info.FullMethod, resp)
|
||||
return r.interceptResponse(
|
||||
ctx, requestID, false, info.FullMethod, resp,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -845,13 +855,15 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.acceptRequest(msg)
|
||||
requestID := atomic.AddUint64(&r.lastRequestID, 1)
|
||||
err = r.acceptRequest(requestID, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wrappedSS := &serverStreamWrapper{
|
||||
ServerStream: ss,
|
||||
requestID: requestID,
|
||||
fullMethod: info.FullMethod,
|
||||
interceptor: r,
|
||||
}
|
||||
@ -900,7 +912,9 @@ func (r *InterceptorChain) middlewareRegistered() bool {
|
||||
// registered for it. This means either a middleware has requested read-only
|
||||
// access or the request actually has a macaroon which a caveat the middleware
|
||||
// registered for.
|
||||
func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
||||
func (r *InterceptorChain) acceptRequest(requestID uint64,
|
||||
msg *InterceptionRequest) error {
|
||||
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
@ -915,7 +929,7 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := middleware.intercept(msg)
|
||||
resp, err := middleware.intercept(requestID, msg)
|
||||
|
||||
// Error during interception itself.
|
||||
if err != nil {
|
||||
@ -936,7 +950,8 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
||||
// overwrite/replace the response, this needs to be handled differently than the
|
||||
// request/auth path above.
|
||||
func (r *InterceptorChain) interceptResponse(ctx context.Context,
|
||||
isStream bool, fullMethod string, m interface{}) (interface{}, error) {
|
||||
requestID uint64, isStream bool, fullMethod string,
|
||||
m interface{}) (interface{}, error) {
|
||||
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
@ -960,7 +975,7 @@ func (r *InterceptorChain) interceptResponse(ctx context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := middleware.intercept(msg)
|
||||
resp, err := middleware.intercept(requestID, msg)
|
||||
|
||||
// Error during interception itself.
|
||||
if err != nil {
|
||||
@ -988,6 +1003,8 @@ type serverStreamWrapper struct {
|
||||
// ServerStream is the stream that's being wrapped.
|
||||
grpc.ServerStream
|
||||
|
||||
requestID uint64
|
||||
|
||||
fullMethod string
|
||||
|
||||
interceptor *InterceptorChain
|
||||
@ -997,7 +1014,7 @@ type serverStreamWrapper struct {
|
||||
// intercept streaming RPC responses.
|
||||
func (w *serverStreamWrapper) SendMsg(m interface{}) error {
|
||||
newMsg, err := w.interceptor.interceptResponse(
|
||||
w.ServerStream.Context(), true, w.fullMethod, m,
|
||||
w.ServerStream.Context(), w.requestID, true, w.fullMethod, m,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1022,5 +1039,5 @@ func (w *serverStreamWrapper) RecvMsg(m interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.interceptor.acceptRequest(msg)
|
||||
return w.interceptor.acceptRequest(w.requestID, msg)
|
||||
}
|
||||
|
@ -107,12 +107,13 @@ func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
|
||||
// feedback on it and sending the feedback to the appropriate channel. All steps
|
||||
// are guarded by the configured timeout to make sure a middleware cannot slow
|
||||
// down requests too much.
|
||||
func (h *MiddlewareHandler) intercept(
|
||||
func (h *MiddlewareHandler) intercept(requestID uint64,
|
||||
req *InterceptionRequest) (*interceptResponse, error) {
|
||||
|
||||
respChan := make(chan *interceptResponse, 1)
|
||||
|
||||
newRequest := &interceptRequest{
|
||||
requestID: requestID,
|
||||
request: req,
|
||||
response: respChan,
|
||||
}
|
||||
@ -233,7 +234,9 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
|
||||
req := newRequest.request
|
||||
interceptRequests[msgID] = newRequest
|
||||
|
||||
interceptReq, err := req.ToRPC(msgID)
|
||||
interceptReq, err := req.ToRPC(
|
||||
newRequest.requestID, msgID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -447,10 +450,11 @@ func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
|
||||
}
|
||||
|
||||
// ToRPC converts the interception request to its RPC counterpart.
|
||||
func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
|
||||
error) {
|
||||
func (r *InterceptionRequest) ToRPC(requestID,
|
||||
msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
|
||||
|
||||
rpcRequest := &lnrpc.RPCMiddlewareRequest{
|
||||
RequestId: requestID,
|
||||
MsgId: msgID,
|
||||
RawMacaroon: r.RawMacaroon,
|
||||
CustomCaveatCondition: r.CustomCaveatCondition,
|
||||
@ -495,6 +499,7 @@ func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
|
||||
// out to a middleware and the response that is eventually sent back by the
|
||||
// middleware.
|
||||
type interceptRequest struct {
|
||||
requestID uint64
|
||||
request *InterceptionRequest
|
||||
response chan *interceptResponse
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user