rpcperms: intercept errors too

This commit is contained in:
Oliver Gugger 2022-07-06 21:17:00 +02:00
parent 502542da60
commit b1d8767a0c
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
2 changed files with 110 additions and 37 deletions

View File

@ -817,14 +817,27 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
return nil, err
}
resp, respErr := handler(ctx, req)
if respErr != nil {
return resp, respErr
// Call the handler, which executes the request against lnd.
lndResp, lndErr := handler(ctx, req)
if lndErr != nil {
// The call to lnd ended in an error and not a normal
// proto message response. Send the error to the
// interceptor as well to inform about the abnormal
// termination of the stream and to give the option to
// replace the error message with a custom one.
replacedErr, err := r.interceptMessage(
ctx, TypeResponse, requestID, false,
info.FullMethod, lndErr,
)
if err != nil {
return nil, err
}
return lndResp, replacedErr.(error)
}
return r.interceptMessage(
ctx, TypeResponse, requestID, false, info.FullMethod,
resp,
lndResp,
)
}
}
@ -883,7 +896,27 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
interceptor: r,
}
return handler(srv, wrappedSS)
// Call the stream handler, which will block as long as the
// stream is alive.
lndErr := handler(srv, wrappedSS)
if lndErr != nil {
// This is an error being returned from lnd. Send it to
// the interceptor as well to inform about the abnormal
// termination of the stream and to give the option to
// replace the error message with a custom one.
replacedErr, err := r.interceptMessage(
ss.Context(), TypeResponse, requestID,
true, info.FullMethod, lndErr,
)
if err != nil {
return err
}
return replacedErr.(error)
}
// Normal/successful termination of the stream.
return nil
}
}

View File

@ -263,24 +263,41 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
break
}
// For intercepted messages we also allow the
// content itself to be overwritten.
if t.ReplaceResponse {
response.replace = true
protoMsg, err := parseProto(
requestInfo.request.ProtoTypeName,
t.ReplacementSerialized,
// If there's nothing to replace, we're done,
// this request was just accepted.
if !t.ReplaceResponse {
break
}
// We are replacing the response, the question
// now just is: was it an error or a proper
// proto message?
response.replace = true
if requestInfo.request.IsError {
response.replacement = errors.New(
string(t.ReplacementSerialized),
)
if err != nil {
response.err = err
break
}
response.replacement = protoMsg
break
}
// Not an error but a proper proto message that
// needs to be replaced. For that we need to
// parse it from the raw bytes into the full RPC
// message.
protoMsg, err := parseProto(
requestInfo.request.ProtoTypeName,
t.ReplacementSerialized,
)
if err != nil {
response.err = err
break
}
response.replacement = protoMsg
default:
return fmt.Errorf("unknown middleware "+
"message: %v", msg)
@ -369,6 +386,10 @@ type InterceptionRequest struct {
// ProtoTypeName is the fully qualified name of the protobuf type of the
// request or response message that is serialized in the field above.
ProtoTypeName string
// IsError indicates that the message contained within this request is
// an error. Will only ever be true for response messages.
IsError bool
}
// NewMessageInterceptionRequest creates a new interception request for either
@ -382,24 +403,36 @@ func NewMessageInterceptionRequest(ctx context.Context,
return nil, err
}
rpcReq, ok := m.(proto.Message)
if !ok {
return nil, fmt.Errorf("msg is not proto message: %v", m)
}
rawRequest, err := proto.Marshal(rpcReq)
if err != nil {
return nil, fmt.Errorf("cannot marshal proto msg: %v", err)
req := &InterceptionRequest{
Type: authType,
StreamRPC: isStream,
Macaroon: mac,
RawMacaroon: rawMacaroon,
FullURI: fullMethod,
}
return &InterceptionRequest{
Type: authType,
StreamRPC: isStream,
Macaroon: mac,
RawMacaroon: rawMacaroon,
FullURI: fullMethod,
ProtoSerialized: rawRequest,
ProtoTypeName: string(proto.MessageName(rpcReq)),
}, nil
// The message is either a proto message or an error, we don't support
// any other types being intercepted.
switch t := m.(type) {
case proto.Message:
req.ProtoSerialized, err = proto.Marshal(t)
if err != nil {
return nil, fmt.Errorf("cannot marshal proto msg: %v",
err)
}
req.ProtoTypeName = string(proto.MessageName(t))
case error:
req.ProtoSerialized = []byte(t.Error())
req.ProtoTypeName = "error"
req.IsError = true
default:
return nil, fmt.Errorf("unsupported type for interception "+
"request: %v", m)
}
return req, nil
}
// NewStreamAuthInterceptionRequest creates a new interception request for a
@ -484,6 +517,7 @@ func (r *InterceptionRequest) ToRPC(requestID,
StreamRpc: r.StreamRPC,
TypeName: r.ProtoTypeName,
Serialized: r.ProtoSerialized,
IsError: r.IsError,
},
}
@ -549,8 +583,14 @@ func replaceProtoMsg(target interface{}, replacement interface{}) error {
return fmt.Errorf("replacement message is of wrong type")
}
proto.Reset(targetMsg)
proto.Merge(targetMsg, replacementMsg)
replacementBytes, err := proto.Marshal(replacementMsg)
if err != nil {
return fmt.Errorf("error marshaling replacement: %v", err)
}
err = proto.Unmarshal(replacementBytes, targetMsg)
if err != nil {
return fmt.Errorf("error unmarshaling replacement: %v", err)
}
return nil
}