package rpcperms import ( "context" "encoding/hex" "errors" "fmt" "sync" "sync/atomic" "time" "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/macaroons" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "gopkg.in/macaroon.v2" ) var ( // ErrShuttingDown is the error that's returned when the server is // shutting down and a request cannot be served anymore. ErrShuttingDown = errors.New("server shutting down") // ErrTimeoutReached is the error that's returned if any of the // middleware's tasks is not completed in the given time. ErrTimeoutReached = errors.New("intercept timeout reached") // errClientQuit is the error that's returned if the client closes the // middleware communication stream before a request was fully handled. errClientQuit = errors.New("interceptor RPC client quit") ) // MiddlewareHandler is a type that communicates with a middleware over the // established bi-directional RPC stream. It sends messages to the middleware // whenever the custom business logic implemented there should give feedback to // a request or response that's happening on the main gRPC server. type MiddlewareHandler struct { // lastMsgID is the ID of the last intercept message that was forwarded // to the middleware. // // NOTE: Must be used atomically! lastMsgID uint64 middlewareName string readOnly bool customCaveatName string receive func() (*lnrpc.RPCMiddlewareResponse, error) send func(request *lnrpc.RPCMiddlewareRequest) error interceptRequests chan *interceptRequest timeout time.Duration // params are our current chain params. params *chaincfg.Params // done is closed when the rpc client terminates. done chan struct{} // quit is closed when lnd is shutting down. quit chan struct{} wg sync.WaitGroup } // NewMiddlewareHandler creates a new handler for the middleware with the given // name and custom caveat name. func NewMiddlewareHandler(name, customCaveatName string, readOnly bool, receive func() (*lnrpc.RPCMiddlewareResponse, error), send func(request *lnrpc.RPCMiddlewareRequest) error, timeout time.Duration, params *chaincfg.Params, quit chan struct{}) *MiddlewareHandler { // We explicitly want to log this as a warning since intercepting any // gRPC messages can also be used for malicious purposes and the user // should be made aware of the risks. log.Warnf("A new gRPC middleware with the name '%s' was registered "+ " with custom_macaroon_caveat='%s', read_only=%v. Make sure "+ "you trust the middleware author since that code will be able "+ "to intercept and possibly modify and gRPC messages sent/"+ "received to/from a client that has a macaroon with that "+ "custom caveat.", name, customCaveatName, readOnly) return &MiddlewareHandler{ middlewareName: name, customCaveatName: customCaveatName, readOnly: readOnly, receive: receive, send: send, interceptRequests: make(chan *interceptRequest), timeout: timeout, params: params, done: make(chan struct{}), quit: quit, } } // intercept handles the full interception lifecycle of a single middleware // event (stream authentication, request interception or response interception). // The lifecycle consists of sending a message to the middleware, receiving a // feedback on it and sending the feedback to the appropriate channel. All steps // are guarded by the configured timeout to make sure a middleware cannot slow // down requests too much. func (h *MiddlewareHandler) intercept( req *InterceptionRequest) (*interceptResponse, error) { respChan := make(chan *interceptResponse, 1) newRequest := &interceptRequest{ request: req, response: respChan, } // timeout is the time after which intercept requests expire. timeout := time.After(h.timeout) // Send the request to the interceptRequests channel for the main // goroutine to be picked up. select { case h.interceptRequests <- newRequest: case <-timeout: log.Errorf("MiddlewareHandler returned error - reached "+ "timeout of %v for request interception", h.timeout) return nil, ErrTimeoutReached case <-h.done: return nil, errClientQuit case <-h.quit: return nil, ErrShuttingDown } // Receive the response and return it. If no response has been received // in AcceptorTimeout, then return false. select { case resp := <-respChan: return resp, nil case <-timeout: log.Errorf("MiddlewareHandler returned error - reached "+ "timeout of %v for response interception", h.timeout) return nil, ErrTimeoutReached case <-h.done: return nil, errClientQuit case <-h.quit: return nil, ErrShuttingDown } } // Run is the main loop for the middleware handler. This function will block // until it receives the signal that lnd is shutting down, or the rpc stream is // cancelled by the client. func (h *MiddlewareHandler) Run() error { // Wait for our goroutines to exit before we return. defer h.wg.Wait() defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName) // Create a channel that responses from middlewares are sent into. responses := make(chan *lnrpc.RPCMiddlewareResponse) // errChan is used by the receive loop to signal any errors that occur // during reading from the stream. This is primarily used to shutdown // the send loop in the case of an RPC client disconnecting. errChan := make(chan error, 1) // Start a goroutine to receive responses from the interceptor. We // expect the receive function to block, so it must be run in a // goroutine (otherwise we could not send more than one intercept // request to the client). h.wg.Add(1) go func() { h.receiveResponses(errChan, responses) h.wg.Done() }() return h.sendInterceptRequests(errChan, responses) } // receiveResponses receives responses for our intercept requests and dispatches // them into the responses channel provided, sending any errors that occur into // the error channel provided. func (h *MiddlewareHandler) receiveResponses(errChan chan error, responses chan *lnrpc.RPCMiddlewareResponse) { for { resp, err := h.receive() if err != nil { errChan <- err return } select { case responses <- resp: case <-h.done: return case <-h.quit: return } } } // sendInterceptRequests handles intercept requests sent to us by our Accept() // function, dispatching them to our acceptor stream and coordinating return of // responses to their callers. func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, responses chan *lnrpc.RPCMiddlewareResponse) error { // Close the done channel to indicate that the interceptor is no longer // listening and any in-progress requests should be terminated. defer close(h.done) interceptRequests := make(map[uint64]*interceptRequest) for { select { // Consume requests passed to us from our Accept() function and // send them into our stream. case newRequest := <-h.interceptRequests: msgID := atomic.AddUint64(&h.lastMsgID, 1) req := newRequest.request interceptRequests[msgID] = newRequest interceptReq, err := req.ToRPC(msgID) if err != nil { return err } if err := h.send(interceptReq); err != nil { return err } // Process newly received responses from our interceptor, // looking the original request up in our map of requests and // dispatching the response. case resp := <-responses: requestInfo, ok := interceptRequests[resp.RefMsgId] if !ok { continue } response := &interceptResponse{} switch msg := resp.GetMiddlewareMessage().(type) { case *lnrpc.RPCMiddlewareResponse_Feedback: t := msg.Feedback if t.Error != "" { response.err = fmt.Errorf("%s", t.Error) break } // For intercepted responses we also allow the // content itself to be overwritten. if requestInfo.request.Type == TypeResponse && t.ReplaceResponse { response.replace = true protoMsg, err := parseProto( requestInfo.request.ProtoTypeName, t.ReplacementSerialized, ) if err != nil { response.err = err break } response.replacement = protoMsg } default: return fmt.Errorf("unknown middleware "+ "message: %v", msg) } select { case requestInfo.response <- response: case <-h.quit: } delete(interceptRequests, resp.RefMsgId) // If we failed to receive from our middleware, we exit. case err := <-errChan: log.Errorf("Received an error: %v, shutting down", err) return err // Exit if we are shutting down. case <-h.quit: return ErrShuttingDown } } } // InterceptType defines the different types of intercept messages a middleware // can receive. type InterceptType uint8 const ( // TypeStreamAuth is the type of intercept message that is sent when a // client or streaming RPC is initialized. A message with this type will // be sent out during stream initialization so a middleware can // accept/deny the whole stream instead of only single messages on the // stream. TypeStreamAuth InterceptType = 1 // TypeRequest is the type of intercept message that is sent when an RPC // request message is sent to lnd. For client-streaming RPCs a new // message of this type is sent for each individual RPC request sent to // the stream. TypeRequest InterceptType = 2 // TypeResponse is the type of intercept message that is sent when an // RPC response message is sent from lnd to a client. For // server-streaming RPCs a new message of this type is sent for each // individual RPC response sent to the stream. Middleware has the option // to modify a response message before it is sent out to the client. TypeResponse InterceptType = 3 ) // InterceptionRequest is a struct holding all information that is sent to a // middleware whenever there is something to intercept (auth, request, // response). type InterceptionRequest struct { // Type is the type of the interception message. Type InterceptType // StreamRPC is set to true if the invoked RPC method is client or // server streaming. StreamRPC bool // Macaroon holds the macaroon that the client sent to lnd. Macaroon *macaroon.Macaroon // RawMacaroon holds the raw binary serialized macaroon that the client // sent to lnd. RawMacaroon []byte // CustomCaveatName is the name of the custom caveat that the middleware // was intercepting for. CustomCaveatName string // CustomCaveatCondition is the condition of the custom caveat that the // middleware was intercepting for. This can be empty for custom caveats // that only have a name (marker caveats). CustomCaveatCondition string // FullURI is the full RPC method URI that was invoked. FullURI string // ProtoSerialized is the full request or response object in the // protobuf binary serialization format. ProtoSerialized []byte // ProtoTypeName is the fully qualified name of the protobuf type of the // request or response message that is serialized in the field above. ProtoTypeName string } // NewMessageInterceptionRequest creates a new interception request for either // a request or response message. func NewMessageInterceptionRequest(ctx context.Context, authType InterceptType, isStream bool, fullMethod string, m interface{}) (*InterceptionRequest, error) { mac, rawMacaroon, err := macaroonFromContext(ctx) if err != nil { return nil, err } rpcReq, ok := m.(proto.Message) if !ok { return nil, fmt.Errorf("msg is not proto message: %v", m) } rawRequest, err := proto.Marshal(rpcReq) if err != nil { return nil, fmt.Errorf("cannot marshal proto msg: %v", err) } return &InterceptionRequest{ Type: authType, StreamRPC: isStream, Macaroon: mac, RawMacaroon: rawMacaroon, FullURI: fullMethod, ProtoSerialized: rawRequest, ProtoTypeName: string(proto.MessageName(rpcReq)), }, nil } // NewStreamAuthInterceptionRequest creates a new interception request for a // stream authentication message. func NewStreamAuthInterceptionRequest(ctx context.Context, fullMethod string) (*InterceptionRequest, error) { mac, rawMacaroon, err := macaroonFromContext(ctx) if err != nil { return nil, err } return &InterceptionRequest{ Type: TypeStreamAuth, StreamRPC: true, Macaroon: mac, RawMacaroon: rawMacaroon, FullURI: fullMethod, }, nil } // macaroonFromContext tries to extract the macaroon from the incoming context. // If there is no macaroon, a nil error is returned since some RPCs might not // require a macaroon. But in case there is something in the macaroon header // field that cannot be parsed, a non-nil error is returned. func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte, error) { macHex, err := macaroons.RawMacaroonFromContext(ctx) if err != nil { // If there is no macaroon, we continue anyway as it might be an // RPC that doesn't require a macaroon. return nil, nil, nil } macBytes, err := hex.DecodeString(macHex) if err != nil { return nil, nil, err } mac := &macaroon.Macaroon{} if err := mac.UnmarshalBinary(macBytes); err != nil { return nil, nil, err } return mac, macBytes, nil } // ToRPC converts the interception request to its RPC counterpart. func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) { rpcRequest := &lnrpc.RPCMiddlewareRequest{ MsgId: msgID, RawMacaroon: r.RawMacaroon, CustomCaveatCondition: r.CustomCaveatCondition, } switch r.Type { case TypeStreamAuth: rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{ StreamAuth: &lnrpc.StreamAuth{ MethodFullUri: r.FullURI, }, } case TypeRequest: rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{ Request: &lnrpc.RPCMessage{ MethodFullUri: r.FullURI, StreamRpc: r.StreamRPC, TypeName: r.ProtoTypeName, Serialized: r.ProtoSerialized, }, } case TypeResponse: rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{ Response: &lnrpc.RPCMessage{ MethodFullUri: r.FullURI, StreamRpc: r.StreamRPC, TypeName: r.ProtoTypeName, Serialized: r.ProtoSerialized, }, } default: return nil, fmt.Errorf("unknown intercept type %v", r.Type) } return rpcRequest, nil } // interceptRequest is a struct that keeps track of an interception request sent // out to a middleware and the response that is eventually sent back by the // middleware. type interceptRequest struct { request *InterceptionRequest response chan *interceptResponse } // interceptResponse is the response a middleware sends back for each // intercepted message. type interceptResponse struct { err error replace bool replacement interface{} } // parseProto parses a proto serialized message of the given type into its // native version. func parseProto(typeName string, serialized []byte) (proto.Message, error) { messageType, err := protoregistry.GlobalTypes.FindMessageByName( protoreflect.FullName(typeName), ) if err != nil { return nil, err } msg := messageType.New() err = proto.Unmarshal(serialized, msg.Interface()) if err != nil { return nil, err } return msg.Interface(), nil }