Implement a typed version of call to avoid useless matching

This commit is contained in:
elsirion 2022-05-13 19:43:36 +00:00 committed by Christian Decker
parent 7046252f96
commit 10917743fe
2 changed files with 74 additions and 1 deletions

View File

@ -24,6 +24,7 @@ pub use crate::{
notifications::Notification,
primitives::RpcError,
};
use crate::model::IntoRequest;
///
pub struct ClnRpc {
@ -105,6 +106,13 @@ impl ClnRpc {
})
}
}
pub async fn call_typed<R: IntoRequest>(&mut self, request: R) -> Result<R::Response, RpcError> {
Ok(self.call(request.into())
.await?
.try_into()
.expect("CLN will reply correctly"))
}
}
/// Used to skip optional arrays when serializing requests.
@ -142,4 +150,23 @@ mod test {
read_req
);
}
#[tokio::test]
async fn test_typed_call() {
let req = requests::GetinfoRequest {};
let (uds1, uds2) = UnixStream::pair().unwrap();
let mut cln = ClnRpc::from_stream(uds1).unwrap();
let mut read = FramedRead::new(uds2, JsonCodec::default());
tokio::task::spawn(async move {
let _: GetinfoResponse = cln.call_typed(req).await.unwrap();
});
let read_req = dbg!(read.next().await.unwrap().unwrap());
assert_eq!(
json!({"id": 1, "method": "getinfo", "params": {}, "jsonrpc": "2.0"}),
read_req
);
}
}

View File

@ -6,7 +6,7 @@ import sys
import re
from msggen.model import (ArrayField, CompositeField, EnumField,
PrimitiveField, Service)
PrimitiveField, Service, Method)
from msggen.gen.generator import IGenerator
logger = logging.getLogger(__name__)
@ -214,6 +214,7 @@ class RustGenerator(IGenerator):
use crate::primitives::*;
#[allow(unused_imports)]
use serde::{{Deserialize, Serialize}};
use super::{IntoRequest, Request};
""")
@ -221,9 +222,24 @@ class RustGenerator(IGenerator):
req = meth.request
_, decl = gen_composite(req)
self.write(decl, numindent=1)
self.generate_request_trait_impl(meth)
self.write("}\n\n")
def generate_request_trait_impl(self, method: Method):
self.write(dedent(f"""\
impl From<{method.request.typename}> for Request {{
fn from(r: {method.request.typename}) -> Self {{
Request::{method.name}(r)
}}
}}
impl IntoRequest for {method.request.typename} {{
type Response = super::responses::{method.response.typename};
}}
"""), numindent=1)
def generate_responses(self, service: Service):
self.write("""
pub mod responses {
@ -231,6 +247,7 @@ class RustGenerator(IGenerator):
use crate::primitives::*;
#[allow(unused_imports)]
use serde::{{Deserialize, Serialize}};
use super::{TryFromResponseError, Response};
""")
@ -238,9 +255,25 @@ class RustGenerator(IGenerator):
res = meth.response
_, decl = gen_composite(res)
self.write(decl, numindent=1)
self.generate_response_trait_impl(meth)
self.write("}\n\n")
def generate_response_trait_impl(self, method: Method):
self.write(dedent(f"""\
impl TryFrom<Response> for {method.response.typename} {{
type Error = super::TryFromResponseError;
fn try_from(response: Response) -> Result<Self, Self::Error> {{
match response {{
Response::{method.name}(response) => Ok(response),
_ => Err(TryFromResponseError)
}}
}}
}}
"""), numindent=1)
def generate_enums(self, service: Service):
"""The Request and Response enums serve as parsing primitives.
"""
@ -275,10 +308,23 @@ class RustGenerator(IGenerator):
""")
def generate_request_trait(self):
self.write("""
pub trait IntoRequest: Into<Request> {
type Response: TryFrom<Response, Error = TryFromResponseError>;
}
#[derive(Debug)]
pub struct TryFromResponseError;
""")
def generate(self, service: Service) -> None:
self.write(header)
self.generate_enums(service)
self.generate_request_trait()
self.generate_requests(service)
self.generate_responses(service)