mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 01:43:16 +01:00
rpcperms: intercept errors too
This commit is contained in:
parent
502542da60
commit
b1d8767a0c
@ -817,14 +817,27 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, respErr := handler(ctx, req)
|
||||
if respErr != nil {
|
||||
return resp, respErr
|
||||
// Call the handler, which executes the request against lnd.
|
||||
lndResp, lndErr := handler(ctx, req)
|
||||
if lndErr != nil {
|
||||
// The call to lnd ended in an error and not a normal
|
||||
// proto message response. Send the error to the
|
||||
// interceptor as well to inform about the abnormal
|
||||
// termination of the stream and to give the option to
|
||||
// replace the error message with a custom one.
|
||||
replacedErr, err := r.interceptMessage(
|
||||
ctx, TypeResponse, requestID, false,
|
||||
info.FullMethod, lndErr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lndResp, replacedErr.(error)
|
||||
}
|
||||
|
||||
return r.interceptMessage(
|
||||
ctx, TypeResponse, requestID, false, info.FullMethod,
|
||||
resp,
|
||||
lndResp,
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -883,7 +896,27 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
|
||||
interceptor: r,
|
||||
}
|
||||
|
||||
return handler(srv, wrappedSS)
|
||||
// Call the stream handler, which will block as long as the
|
||||
// stream is alive.
|
||||
lndErr := handler(srv, wrappedSS)
|
||||
if lndErr != nil {
|
||||
// This is an error being returned from lnd. Send it to
|
||||
// the interceptor as well to inform about the abnormal
|
||||
// termination of the stream and to give the option to
|
||||
// replace the error message with a custom one.
|
||||
replacedErr, err := r.interceptMessage(
|
||||
ss.Context(), TypeResponse, requestID,
|
||||
true, info.FullMethod, lndErr,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return replacedErr.(error)
|
||||
}
|
||||
|
||||
// Normal/successful termination of the stream.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -263,24 +263,41 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
|
||||
break
|
||||
}
|
||||
|
||||
// For intercepted messages we also allow the
|
||||
// content itself to be overwritten.
|
||||
if t.ReplaceResponse {
|
||||
response.replace = true
|
||||
protoMsg, err := parseProto(
|
||||
requestInfo.request.ProtoTypeName,
|
||||
t.ReplacementSerialized,
|
||||
// If there's nothing to replace, we're done,
|
||||
// this request was just accepted.
|
||||
if !t.ReplaceResponse {
|
||||
break
|
||||
}
|
||||
|
||||
// We are replacing the response, the question
|
||||
// now just is: was it an error or a proper
|
||||
// proto message?
|
||||
response.replace = true
|
||||
if requestInfo.request.IsError {
|
||||
response.replacement = errors.New(
|
||||
string(t.ReplacementSerialized),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
response.err = err
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
response.replacement = protoMsg
|
||||
break
|
||||
}
|
||||
|
||||
// Not an error but a proper proto message that
|
||||
// needs to be replaced. For that we need to
|
||||
// parse it from the raw bytes into the full RPC
|
||||
// message.
|
||||
protoMsg, err := parseProto(
|
||||
requestInfo.request.ProtoTypeName,
|
||||
t.ReplacementSerialized,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
response.err = err
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
response.replacement = protoMsg
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown middleware "+
|
||||
"message: %v", msg)
|
||||
@ -369,6 +386,10 @@ type InterceptionRequest struct {
|
||||
// ProtoTypeName is the fully qualified name of the protobuf type of the
|
||||
// request or response message that is serialized in the field above.
|
||||
ProtoTypeName string
|
||||
|
||||
// IsError indicates that the message contained within this request is
|
||||
// an error. Will only ever be true for response messages.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
// NewMessageInterceptionRequest creates a new interception request for either
|
||||
@ -382,24 +403,36 @@ func NewMessageInterceptionRequest(ctx context.Context,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpcReq, ok := m.(proto.Message)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("msg is not proto message: %v", m)
|
||||
}
|
||||
rawRequest, err := proto.Marshal(rpcReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot marshal proto msg: %v", err)
|
||||
req := &InterceptionRequest{
|
||||
Type: authType,
|
||||
StreamRPC: isStream,
|
||||
Macaroon: mac,
|
||||
RawMacaroon: rawMacaroon,
|
||||
FullURI: fullMethod,
|
||||
}
|
||||
|
||||
return &InterceptionRequest{
|
||||
Type: authType,
|
||||
StreamRPC: isStream,
|
||||
Macaroon: mac,
|
||||
RawMacaroon: rawMacaroon,
|
||||
FullURI: fullMethod,
|
||||
ProtoSerialized: rawRequest,
|
||||
ProtoTypeName: string(proto.MessageName(rpcReq)),
|
||||
}, nil
|
||||
// The message is either a proto message or an error, we don't support
|
||||
// any other types being intercepted.
|
||||
switch t := m.(type) {
|
||||
case proto.Message:
|
||||
req.ProtoSerialized, err = proto.Marshal(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot marshal proto msg: %v",
|
||||
err)
|
||||
}
|
||||
req.ProtoTypeName = string(proto.MessageName(t))
|
||||
|
||||
case error:
|
||||
req.ProtoSerialized = []byte(t.Error())
|
||||
req.ProtoTypeName = "error"
|
||||
req.IsError = true
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type for interception "+
|
||||
"request: %v", m)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewStreamAuthInterceptionRequest creates a new interception request for a
|
||||
@ -484,6 +517,7 @@ func (r *InterceptionRequest) ToRPC(requestID,
|
||||
StreamRpc: r.StreamRPC,
|
||||
TypeName: r.ProtoTypeName,
|
||||
Serialized: r.ProtoSerialized,
|
||||
IsError: r.IsError,
|
||||
},
|
||||
}
|
||||
|
||||
@ -549,8 +583,14 @@ func replaceProtoMsg(target interface{}, replacement interface{}) error {
|
||||
return fmt.Errorf("replacement message is of wrong type")
|
||||
}
|
||||
|
||||
proto.Reset(targetMsg)
|
||||
proto.Merge(targetMsg, replacementMsg)
|
||||
replacementBytes, err := proto.Marshal(replacementMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling replacement: %v", err)
|
||||
}
|
||||
err = proto.Unmarshal(replacementBytes, targetMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling replacement: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user