cln-rpc: Make Pubkey and ShortChannelId proper types

This commit is contained in:
Christian Decker 2022-04-01 14:43:34 +10:30 committed by Rusty Russell
parent ef145c7900
commit 1613c44b0a
7 changed files with 159 additions and 25 deletions

View File

@ -1,3 +1,6 @@
// Huge json!() macros require lots of recursion
#![recursion_limit = "1024"]
mod convert;
pub mod pb;
mod server;

View File

@ -1,4 +1,5 @@
tonic::include_proto!("cln");
use std::str::FromStr;
use cln_rpc::primitives::{
Amount as JAmount, AmountOrAll as JAmountOrAll, AmountOrAny as JAmountOrAny,
@ -104,8 +105,8 @@ impl From<&AmountOrAny> for JAmountOrAny {
impl From<RouteHop> for cln_rpc::primitives::Routehop {
fn from(c: RouteHop) -> Self {
Self {
id: hex::encode(c.id),
scid: c.short_channel_id,
id: cln_rpc::primitives::Pubkey::from_slice(&c.id).unwrap(),
scid: cln_rpc::primitives::ShortChannelId::from_str(&c.short_channel_id).unwrap(),
feebase: c.feebase.as_ref().unwrap().into(),
feeprop: c.feeprop,
expirydelta: c.expirydelta as u16,

View File

@ -1,6 +1,7 @@
use crate::codec::JsonCodec;
use crate::codec::JsonRpc;
use anyhow::{Context, Error, Result};
use anyhow::{Context, Result};
pub use anyhow::Error;
use futures_util::sink::SinkExt;
use futures_util::StreamExt;
use log::{debug, trace};

View File

@ -1,6 +1,10 @@
use anyhow::Context;
use anyhow::{anyhow, Error, Result};
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};
use std::str::FromStr;
use std::string::ToString;
#[derive(Copy, Clone, Serialize, Deserialize, Debug)]
#[allow(non_camel_case_types)]
pub enum ChannelState {
@ -68,6 +72,118 @@ impl Amount {
}
}
#[derive(Clone, Debug)]
pub struct Pubkey([u8; 33]);
impl Serialize for Pubkey {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&hex::encode(&self.0))
}
}
impl<'de> Deserialize<'de> for Pubkey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let s: String = Deserialize::deserialize(deserializer)?;
Ok(Self::from_str(&s).map_err(|e| Error::custom(e.to_string()))?)
}
}
impl FromStr for Pubkey {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let raw =
hex::decode(&s).with_context(|| format!("{} is not a valid hex-encoded pubkey", s))?;
Ok(Pubkey(raw.try_into().map_err(|_| {
anyhow!("could not convert {} into pubkey", s)
})?))
}
}
impl ToString for Pubkey {
fn to_string(&self) -> String {
hex::encode(self.0)
}
}
impl Pubkey {
pub fn from_slice(data: &[u8]) -> Result<Pubkey, crate::Error> {
Ok(Pubkey(
data.try_into().with_context(|| "Not a valid pubkey")?,
))
}
pub fn to_vec(&self) -> Vec<u8> {
self.0.to_vec()
}
}
#[derive(Clone, Debug)]
pub struct ShortChannelId(u64);
impl Serialize for ShortChannelId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for ShortChannelId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let s: String = Deserialize::deserialize(deserializer)?;
Ok(Self::from_str(&s).map_err(|e| Error::custom(e.to_string()))?)
}
}
impl FromStr for ShortChannelId {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts: Result<Vec<u64>, _> = s.split('x').map(|p| p.parse()).collect();
let parts = parts.with_context(|| format!("Malformed short_channel_id: {}", s))?;
if parts.len() != 3 {
return Err(anyhow!(
"Malformed short_channel_id: element count mismatch"
));
}
Ok(ShortChannelId(
(parts[0] << 40) | (parts[1] << 16) | (parts[2] << 0),
))
}
}
impl ToString for ShortChannelId {
fn to_string(&self) -> String {
format!("{}x{}x{}", self.block(), self.txindex(), self.outnum())
}
}
impl ShortChannelId {
pub fn block(&self) -> u32 {
(self.0 >> 40) as u32 & 0xFFFFFF
}
pub fn txindex(&self) -> u32 {
(self.0 >> 16) as u32 & 0xFFFFFF
}
pub fn outnum(&self) -> u16 {
self.0 as u16 & 0xFFFF
}
}
pub type Secret = [u8; 32];
pub type Txid = [u8; 32];
pub type Hash = [u8; 32];
pub type NodeId = Pubkey;
#[derive(Clone, Debug, PartialEq)]
pub struct Outpoint {
pub txid: Vec<u8>,
@ -428,8 +544,8 @@ impl Serialize for OutputDesc {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Routehop {
pub id: String,
pub scid: String,
pub id: Pubkey,
pub scid: ShortChannelId,
pub feebase: Amount,
pub feeprop: u32,
pub expirydelta: u16,

View File

@ -10,8 +10,8 @@ typemap = {
'boolean': 'bool',
'hex': 'bytes',
'msat': 'Amount',
'msat|all': 'AmountOrAll',
'msat|any': 'AmountOrAny',
'msat_or_all': 'AmountOrAll',
'msat_or_any': 'AmountOrAny',
'number': 'sint64',
'pubkey': 'bytes',
'short_channel_id': 'string',
@ -275,8 +275,10 @@ class GrpcConverterGenerator:
'hex': f'hex::decode(i).unwrap()',
}.get(typ, f'i.into()')
self.write(f"{name}: c.{name}.iter().map(|i| {mapping}).collect(),\n", numindent=3)
if f.required:
self.write(f"{name}: c.{name}.iter().map(|i| {mapping}).collect(), // Rule #3 \n", numindent=3)
else:
self.write(f"{name}: c.{name}.as_ref().map(|arr| arr.iter().map(|i| {mapping}).collect()).unwrap_or(vec![]), // Rule #3 \n", numindent=3)
elif isinstance(f, EnumField):
if f.required:
self.write(f"{name}: c.{name} as i32,\n", numindent=3)
@ -295,12 +297,14 @@ class GrpcConverterGenerator:
'u16?': f'c.{name}.map(|v| v.into())',
'msat': f'Some(c.{name}.into())',
'msat?': f'c.{name}.map(|f| f.into())',
'pubkey': f'hex::decode(&c.{name}).unwrap()',
'pubkey?': f'c.{name}.as_ref().map(|v| hex::decode(&v).unwrap())',
'pubkey': f'c.{name}.to_vec()',
'pubkey?': f'c.{name}.as_ref().map(|v| v.to_vec())',
'hex': f'hex::decode(&c.{name}).unwrap()',
'hex?': f'c.{name}.as_ref().map(|v| hex::decode(&v).unwrap())',
'txid': f'hex::decode(&c.{name}).unwrap()',
'txid?': f'c.{name}.as_ref().map(|v| hex::decode(&v).unwrap())',
'short_channel_id': f'c.{name}.to_string()',
'short_channel_id?': f'c.{name}.as_ref().map(|v| v.to_string())',
}.get(
typ,
f'c.{name}.clone()' # default to just assignment
@ -335,6 +339,7 @@ class GrpcConverterGenerator:
#[allow(unused_imports)]
use cln_rpc::model::{responses,requests};
use crate::pb;
use std::str::FromStr;
""")
@ -380,7 +385,10 @@ class GrpcUnconverterGenerator(GrpcConverterGenerator):
'hex': f'hex::encode(s)',
'u32': f's.clone()',
}.get(typ, f's.into()')
self.write(f"{name}: c.{name}.iter().map(|s| {mapping}).collect(),\n", numindent=3)
if f.required:
self.write(f"{name}: c.{name}.iter().map(|s| {mapping}).collect(), // Rule #4\n", numindent=3)
else:
self.write(f"{name}: Some(c.{name}.iter().map(|s| {mapping}).collect()), // Rule #4\n", numindent=3)
elif isinstance(f, EnumField):
if f.required:
@ -400,17 +408,19 @@ class GrpcUnconverterGenerator(GrpcConverterGenerator):
'hex': f'hex::encode(&c.{name})',
'hex?': f'c.{name}.clone().map(|v| hex::encode(v))',
'txid?': f'c.{name}.clone().map(|v| hex::encode(v))',
'pubkey': f'hex::encode(&c.{name})',
'pubkey?': f'c.{name}.clone().map(|v| hex::encode(v))',
'pubkey': f'cln_rpc::primitives::Pubkey::from_slice(&c.{name}).unwrap()',
'pubkey?': f'c.{name}.as_ref().map(|v| cln_rpc::primitives::Pubkey::from_slice(v).unwrap())',
'msat': f'c.{name}.as_ref().unwrap().into()',
'msat?': f'c.{name}.as_ref().map(|a| a.into())',
'msat|all': f'c.{name}.as_ref().unwrap().into()',
'msat|all?': f'c.{name}.as_ref().map(|a| a.into())',
'msat|any': f'c.{name}.as_ref().unwrap().into()',
'msat|any?': f'c.{name}.as_ref().map(|a| a.into())',
'msat_or_all': f'c.{name}.as_ref().unwrap().into()',
'msat_or_all?': f'c.{name}.as_ref().map(|a| a.into())',
'msat_or_any': f'c.{name}.as_ref().unwrap().into()',
'msat_or_any?': f'c.{name}.as_ref().map(|a| a.into())',
'feerate': f'c.{name}.as_ref().unwrap().into()',
'feerate?': f'c.{name}.as_ref().map(|a| a.into())',
'RoutehintList?': f'c.{name}.clone().map(|rl| rl.into())',
'short_channel_id': f'cln_rpc::primitives::ShortChannelId::from_str(&c.{name}).unwrap()',
'short_channel_id?': f'c.{name}.as_ref().map(|v| cln_rpc::primitives::ShortChannelId::from_str(&v).unwrap())',
}.get(
typ,
f'c.{name}.clone()' # default to just assignment

View File

@ -265,8 +265,8 @@ class PrimitiveField(Field):
"pubkey",
"signature",
"msat",
"msat|any",
"msat|all",
"msat_or_any",
"msat_or_all",
"hex",
"short_channel_id",
"short_channel_id_dir",

View File

@ -33,11 +33,11 @@ typemap = {
'boolean': 'bool',
'hex': 'String',
'msat': 'Amount',
'msat|all': 'AmountOrAll',
'msat|any': 'AmountOrAny',
'msat_or_all': 'AmountOrAll',
'msat_or_any': 'AmountOrAny',
'number': 'i64',
'pubkey': 'String',
'short_channel_id': 'String',
'pubkey': 'Pubkey',
'short_channel_id': 'ShortChannelId',
'signature': 'String',
'string': 'String',
'txid': 'String',
@ -166,7 +166,10 @@ def gen_array(a):
itemtype = typemap.get(itemtype, itemtype)
alias = a.name.normalized()
defi = f" #[serde(alias = \"{alias}\")]\n pub {name}: {'Vec<'*a.dims}{itemtype}{'>'*a.dims},\n"
if a.required:
defi = f" #[serde(alias = \"{alias}\")]\n pub {name}: {'Vec<'*a.dims}{itemtype}{'>'*a.dims},\n"
else:
defi = f" #[serde(alias = \"{alias}\", skip_serializing_if = \"Option::is_none\")]\n pub {name}: Option<{'Vec<'*a.dims}{itemtype}{'>'*a.dims}>,\n"
return (defi, decl)