rpcperms: add replaceProtoMsg

Because of the way the gRPC Receive() method is designed, we need a way
to replace a proto message with the content of another one without
replacing the original instance itself (e.g. overwrite all values in the
existing struct instance).
This commit is contained in:
Oliver Gugger 2022-07-06 21:16:56 +02:00
parent 66258ee7b5
commit dc32ca61f8
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
2 changed files with 116 additions and 0 deletions

View File

@ -529,3 +529,29 @@ func parseProto(typeName string, serialized []byte) (proto.Message, error) {
return msg.Interface(), nil return msg.Interface(), nil
} }
// replaceProtoMsg replaces the given target message with the content of the
// replacement message.
func replaceProtoMsg(target interface{}, replacement interface{}) error {
targetMsg, ok := target.(proto.Message)
if !ok {
return fmt.Errorf("target is not a proto message: %v", target)
}
replacementMsg, ok := replacement.(proto.Message)
if !ok {
return fmt.Errorf("replacement is not a proto message: %v",
replacement)
}
if targetMsg.ProtoReflect().Type() !=
replacementMsg.ProtoReflect().Type() {
return fmt.Errorf("replacement message is of wrong type")
}
proto.Reset(targetMsg)
proto.Merge(targetMsg, replacementMsg)
return nil
}

View File

@ -0,0 +1,90 @@
package rpcperms
import (
"encoding/json"
"testing"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/stretchr/testify/require"
)
// TestReplaceProtoMsg makes sure the proto message replacement works as
// expected.
func TestReplaceProtoMsg(t *testing.T) {
testCases := []struct {
name string
original interface{}
replacement interface{}
expectedErr string
}{{
name: "simple content replacement",
original: &lnrpc.Invoice{
Memo: "This is a memo string",
Value: 123456,
},
replacement: &lnrpc.Invoice{
Memo: "This is the replaced string",
Value: 654321,
},
}, {
name: "replace with empty message",
original: &lnrpc.Invoice{
Memo: "This is a memo string",
Value: 123456,
},
replacement: &lnrpc.Invoice{},
}, {
name: "replace with fewer fields",
original: &lnrpc.Invoice{
Memo: "This is a memo string",
Value: 123456,
},
replacement: &lnrpc.Invoice{
Value: 654321,
},
}, {
name: "wrong replacement type",
original: &lnrpc.Invoice{
Memo: "This is a memo string",
Value: 123456,
},
replacement: &lnrpc.AddInvoiceResponse{},
expectedErr: "replacement message is of wrong type",
}, {
name: "wrong original type",
original: &interceptRequest{},
replacement: &lnrpc.Invoice{
Memo: "This is the replaced string",
Value: 654321,
},
expectedErr: "target is not a proto message",
}}
for _, tc := range testCases {
t.Run(tc.name, func(tt *testing.T) {
err := replaceProtoMsg(tc.original, tc.replacement)
if tc.expectedErr != "" {
require.Error(tt, err)
require.Contains(
tt, err.Error(), tc.expectedErr,
)
return
}
require.NoError(tt, err)
jsonEqual(tt, tc.replacement, tc.original)
})
}
}
func jsonEqual(t *testing.T, expected, actual interface{}) {
expectedJSON, err := json.Marshal(expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(actual)
require.NoError(t, err)
require.JSONEq(t, string(expectedJSON), string(actualJSON))
}