diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 146105405..22f4bdcc1 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -42,6 +42,10 @@ path = "fuzz_targets/channel_target.rs" name = "full_stack_target" path = "fuzz_targets/full_stack_target.rs" +[[bin]] +name = "router_target" +path = "fuzz_targets/router_target.rs" + [[bin]] name = "chanmon_deser_target" path = "fuzz_targets/chanmon_deser_target.rs" diff --git a/fuzz/fuzz_targets/channel_target.rs b/fuzz/fuzz_targets/channel_target.rs index 27891e11a..1bcc3d708 100644 --- a/fuzz/fuzz_targets/channel_target.rs +++ b/fuzz/fuzz_targets/channel_target.rs @@ -10,7 +10,7 @@ use bitcoin::network::serialize::{serialize, BitcoinHash}; use lightning::ln::channel::{Channel, ChannelKeys}; use lightning::ln::channelmanager::{HTLCFailReason, PendingForwardHTLCInfo}; use lightning::ln::msgs; -use lightning::ln::msgs::MsgDecodable; +use lightning::ln::msgs::{MsgDecodable, ErrorAction}; use lightning::chain::chaininterface::{FeeEstimator, ConfirmationTarget}; use lightning::chain::transaction::OutPoint; use lightning::util::reset_rng_state; @@ -120,7 +120,8 @@ pub fn do_test(data: &[u8]) { msgs::DecodeError::BadSignature => return, msgs::DecodeError::BadText => return, msgs::DecodeError::ExtraAddressesPerType => return, - msgs::DecodeError::WrongLength => panic!("We picked the length..."), + msgs::DecodeError::BadLengthDescriptor => return, + msgs::DecodeError::ShortRead => panic!("We picked the length..."), } } } @@ -141,7 +142,8 @@ pub fn do_test(data: &[u8]) { msgs::DecodeError::BadSignature => return, msgs::DecodeError::BadText => return, msgs::DecodeError::ExtraAddressesPerType => return, - msgs::DecodeError::WrongLength => panic!("We picked the length..."), + msgs::DecodeError::BadLengthDescriptor => return, + msgs::DecodeError::ShortRead => panic!("We picked the length..."), } } } @@ -237,10 +239,25 @@ pub fn do_test(data: &[u8]) { let funding_locked = decode_msg!(msgs::FundingLocked, 32+33); return_err!(channel.funding_locked(&funding_locked)); + macro_rules! test_err { + ($expr: expr) => { + match $expr { + Ok(r) => Some(r), + Err(e) => match e.action { + None => return, + Some(ErrorAction::UpdateFailHTLC {..}) => None, + Some(ErrorAction::DisconnectPeer {..}) => return, + Some(ErrorAction::IgnoreError) => None, + Some(ErrorAction::SendErrorMessage {..}) => None, + }, + } + } + } + loop { match get_slice!(1)[0] { 0 => { - return_err!(channel.send_htlc(slice_to_be64(get_slice!(8)), [42; 32], slice_to_be32(get_slice!(4)), msgs::OnionPacket { + test_err!(channel.send_htlc(slice_to_be64(get_slice!(8)), [42; 32], slice_to_be32(get_slice!(4)), msgs::OnionPacket { version: get_slice!(1)[0], public_key: get_pubkey!(), hop_data: [0; 20*65], @@ -248,44 +265,45 @@ pub fn do_test(data: &[u8]) { })); }, 1 => { - return_err!(channel.send_commitment()); + test_err!(channel.send_commitment()); }, 2 => { let update_add_htlc = decode_msg!(msgs::UpdateAddHTLC, 32+8+8+32+4+4+33+20*65+32); - return_err!(channel.update_add_htlc(&update_add_htlc, PendingForwardHTLCInfo::dummy())); + test_err!(channel.update_add_htlc(&update_add_htlc, PendingForwardHTLCInfo::dummy())); }, 3 => { let update_fulfill_htlc = decode_msg!(msgs::UpdateFulfillHTLC, 32 + 8 + 32); - return_err!(channel.update_fulfill_htlc(&update_fulfill_htlc)); + test_err!(channel.update_fulfill_htlc(&update_fulfill_htlc)); }, 4 => { let update_fail_htlc = decode_msg_with_len16!(msgs::UpdateFailHTLC, 32 + 8, 1); - return_err!(channel.update_fail_htlc(&update_fail_htlc, HTLCFailReason::dummy())); + test_err!(channel.update_fail_htlc(&update_fail_htlc, HTLCFailReason::dummy())); }, 5 => { let update_fail_malformed_htlc = decode_msg!(msgs::UpdateFailMalformedHTLC, 32+8+32+2); - return_err!(channel.update_fail_malformed_htlc(&update_fail_malformed_htlc, HTLCFailReason::dummy())); + test_err!(channel.update_fail_malformed_htlc(&update_fail_malformed_htlc, HTLCFailReason::dummy())); }, 6 => { let commitment_signed = decode_msg_with_len16!(msgs::CommitmentSigned, 32+64, 64); - return_err!(channel.commitment_signed(&commitment_signed)); + test_err!(channel.commitment_signed(&commitment_signed)); }, 7 => { let revoke_and_ack = decode_msg!(msgs::RevokeAndACK, 32+32+33); - return_err!(channel.revoke_and_ack(&revoke_and_ack)); + test_err!(channel.revoke_and_ack(&revoke_and_ack)); }, 8 => { let update_fee = decode_msg!(msgs::UpdateFee, 32+4); - return_err!(channel.update_fee(&fee_est, &update_fee)); + test_err!(channel.update_fee(&fee_est, &update_fee)); }, 9 => { let shutdown = decode_msg_with_len16!(msgs::Shutdown, 32, 1); - return_err!(channel.shutdown(&fee_est, &shutdown)); + test_err!(channel.shutdown(&fee_est, &shutdown)); if channel.is_shutdown() { return; } }, 10 => { let closing_signed = decode_msg!(msgs::ClosingSigned, 32+8+64); - if return_err!(channel.closing_signed(&fee_est, &closing_signed)).1.is_some() { + let sign_res = test_err!(channel.closing_signed(&fee_est, &closing_signed)); + if sign_res.is_some() && sign_res.unwrap().1.is_some() { assert!(channel.is_shutdown()); return; } diff --git a/fuzz/fuzz_targets/router_target.rs b/fuzz/fuzz_targets/router_target.rs new file mode 100644 index 000000000..13733adb6 --- /dev/null +++ b/fuzz/fuzz_targets/router_target.rs @@ -0,0 +1,219 @@ +extern crate bitcoin; +extern crate lightning; +extern crate secp256k1; + +use lightning::ln::channelmanager::ChannelDetails; +use lightning::ln::msgs; +use lightning::ln::msgs::{MsgDecodable, RoutingMessageHandler}; +use lightning::ln::router::{Router, RouteHint}; +use lightning::util::reset_rng_state; + +use secp256k1::key::PublicKey; +use secp256k1::Secp256k1; + +#[inline] +pub fn slice_to_be16(v: &[u8]) -> u16 { + ((v[0] as u16) << 8*1) | + ((v[1] as u16) << 8*0) +} + +#[inline] +pub fn slice_to_be32(v: &[u8]) -> u32 { + ((v[0] as u32) << 8*3) | + ((v[1] as u32) << 8*2) | + ((v[2] as u32) << 8*1) | + ((v[3] as u32) << 8*0) +} + +#[inline] +pub fn slice_to_be64(v: &[u8]) -> u64 { + ((v[0] as u64) << 8*7) | + ((v[1] as u64) << 8*6) | + ((v[2] as u64) << 8*5) | + ((v[3] as u64) << 8*4) | + ((v[4] as u64) << 8*3) | + ((v[5] as u64) << 8*2) | + ((v[6] as u64) << 8*1) | + ((v[7] as u64) << 8*0) +} + +#[inline] +pub fn do_test(data: &[u8]) { + reset_rng_state(); + + let mut read_pos = 0; + macro_rules! get_slice_nonadvancing { + ($len: expr) => { + { + if data.len() < read_pos + $len as usize { + return; + } + &data[read_pos..read_pos + $len as usize] + } + } + } + macro_rules! get_slice { + ($len: expr) => { + { + let res = get_slice_nonadvancing!($len); + read_pos += $len; + res + } + } + } + + macro_rules! decode_msg { + ($MsgType: path, $len: expr) => { + match <($MsgType)>::decode(get_slice!($len)) { + Ok(msg) => msg, + Err(e) => match e { + msgs::DecodeError::UnknownRealmByte => return, + msgs::DecodeError::BadPublicKey => return, + msgs::DecodeError::BadSignature => return, + msgs::DecodeError::BadText => return, + msgs::DecodeError::ExtraAddressesPerType => return, + msgs::DecodeError::BadLengthDescriptor => return, + msgs::DecodeError::ShortRead => panic!("We picked the length..."), + } + } + } + } + + macro_rules! decode_msg_with_len16 { + ($MsgType: path, $begin_len: expr, $excess: expr) => { + { + let extra_len = slice_to_be16(&get_slice_nonadvancing!($begin_len as usize + 2)[$begin_len..$begin_len + 2]); + decode_msg!($MsgType, $begin_len as usize + 2 + (extra_len as usize) + $excess) + } + } + } + + let secp_ctx = Secp256k1::new(); + macro_rules! get_pubkey { + () => { + match PublicKey::from_slice(&secp_ctx, get_slice!(33)) { + Ok(key) => key, + Err(_) => return, + } + } + } + + let our_pubkey = get_pubkey!(); + let router = Router::new(our_pubkey.clone()); + + loop { + match get_slice!(1)[0] { + 0 => { + let start_len = slice_to_be16(&get_slice_nonadvancing!(64 + 2)[64..64 + 2]) as usize; + let addr_len = slice_to_be16(&get_slice_nonadvancing!(64+start_len+2 + 74)[64+start_len+2 + 72..64+start_len+2 + 74]); + if addr_len > (37+1)*4 { + return; + } + let _ = router.handle_node_announcement(&decode_msg_with_len16!(msgs::NodeAnnouncement, 64, 288)); + }, + 1 => { + let _ = router.handle_channel_announcement(&decode_msg_with_len16!(msgs::ChannelAnnouncement, 64*4, 32+8+33*4)); + }, + 2 => { + let _ = router.handle_channel_update(&decode_msg!(msgs::ChannelUpdate, 128)); + }, + 3 => { + match get_slice!(1)[0] { + 0 => { + router.handle_htlc_fail_channel_update(&msgs::HTLCFailChannelUpdate::ChannelUpdateMessage {msg: decode_msg!(msgs::ChannelUpdate, 128)}); + }, + 1 => { + let short_channel_id = slice_to_be64(get_slice!(8)); + router.handle_htlc_fail_channel_update(&msgs::HTLCFailChannelUpdate::ChannelClosed {short_channel_id}); + }, + _ => return, + } + }, + 4 => { + let target = get_pubkey!(); + let mut first_hops_vec = Vec::new(); + let first_hops = match get_slice!(1)[0] { + 0 => None, + 1 => { + let count = slice_to_be16(get_slice!(2)); + for _ in 0..count { + first_hops_vec.push(ChannelDetails { + channel_id: [0; 32], + short_channel_id: Some(slice_to_be64(get_slice!(8))), + remote_network_id: get_pubkey!(), + channel_value_satoshis: slice_to_be64(get_slice!(8)), + user_id: 0, + }); + } + Some(&first_hops_vec[..]) + }, + _ => return, + }; + let mut last_hops_vec = Vec::new(); + let last_hops = { + let count = slice_to_be16(get_slice!(2)); + for _ in 0..count { + last_hops_vec.push(RouteHint { + src_node_id: get_pubkey!(), + short_channel_id: slice_to_be64(get_slice!(8)), + fee_base_msat: slice_to_be64(get_slice!(8)), + fee_proportional_millionths: slice_to_be32(get_slice!(4)), + cltv_expiry_delta: slice_to_be16(get_slice!(2)), + htlc_minimum_msat: slice_to_be64(get_slice!(8)), + }); + } + &last_hops_vec[..] + }; + let _ = router.get_route(&target, first_hops, last_hops, slice_to_be64(get_slice!(8)), slice_to_be32(get_slice!(4))); + }, + _ => return, + } + } +} + +#[cfg(feature = "afl")] +extern crate afl; +#[cfg(feature = "afl")] +fn main() { + afl::read_stdio_bytes(|data| { + do_test(&data); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + do_test(data); + }); + } +} + +#[cfg(test)] +mod tests { + fn extend_vec_from_hex(hex: &str, out: &mut Vec) { + let mut b = 0; + for (idx, c) in hex.as_bytes().iter().enumerate() { + b <<= 4; + match *c { + b'A'...b'F' => b |= c - b'A' + 10, + b'a'...b'f' => b |= c - b'a' + 10, + b'0'...b'9' => b |= c - b'0', + _ => panic!("Bad hex"), + } + if (idx & 1) == 1 { + out.push(b); + b = 0; + } + } + } + + #[test] + fn duplicate_crash() { + let mut a = Vec::new(); + extend_vec_from_hex("00", &mut a); + super::do_test(&a); + } +} diff --git a/src/ln/msgs.rs b/src/ln/msgs.rs index 53aa23972..7f502530a 100644 --- a/src/ln/msgs.rs +++ b/src/ln/msgs.rs @@ -33,10 +33,13 @@ pub enum DecodeError { BadSignature, /// Value expected to be text wasn't decodable as text BadText, - /// Buffer not of right length (either too short or too long) - WrongLength, + /// Buffer too short + ShortRead, /// node_announcement included more than one address of a given type! ExtraAddressesPerType, + /// A length descriptor in the packet didn't describe the later data correctly + /// (currently only generated in node_announcement) + BadLengthDescriptor, } pub trait MsgDecodable: Sized { fn decode(v: &[u8]) -> Result; @@ -500,8 +503,9 @@ impl Error for DecodeError { DecodeError::BadPublicKey => "Invalid public key in packet", DecodeError::BadSignature => "Invalid signature in packet", DecodeError::BadText => "Invalid text in packet", - DecodeError::WrongLength => "Data was wrong length for packet", + DecodeError::ShortRead => "Packet extended beyond the provided bytes", DecodeError::ExtraAddressesPerType => "More than one address of a single type", + DecodeError::BadLengthDescriptor => "A length descriptor in the packet didn't describe the later data correctly", } } } @@ -537,9 +541,9 @@ macro_rules! secp_signature { impl MsgDecodable for LocalFeatures { fn decode(v: &[u8]) -> Result { - if v.len() < 2 { return Err(DecodeError::WrongLength); } + if v.len() < 2 { return Err(DecodeError::ShortRead); } let len = byte_utils::slice_to_be16(&v[0..2]) as usize; - if v.len() < len + 2 { return Err(DecodeError::WrongLength); } + if v.len() < len + 2 { return Err(DecodeError::ShortRead); } let mut flags = Vec::with_capacity(len); flags.extend_from_slice(&v[2..2 + len]); Ok(Self { @@ -559,9 +563,9 @@ impl MsgEncodable for LocalFeatures { impl MsgDecodable for GlobalFeatures { fn decode(v: &[u8]) -> Result { - if v.len() < 2 { return Err(DecodeError::WrongLength); } + if v.len() < 2 { return Err(DecodeError::ShortRead); } let len = byte_utils::slice_to_be16(&v[0..2]) as usize; - if v.len() < len + 2 { return Err(DecodeError::WrongLength); } + if v.len() < len + 2 { return Err(DecodeError::ShortRead); } let mut flags = Vec::with_capacity(len); flags.extend_from_slice(&v[2..2 + len]); Ok(Self { @@ -583,7 +587,7 @@ impl MsgDecodable for Init { fn decode(v: &[u8]) -> Result { let global_features = GlobalFeatures::decode(v)?; if v.len() < global_features.flags.len() + 4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let local_features = LocalFeatures::decode(&v[global_features.flags.len() + 2..])?; Ok(Self { @@ -604,12 +608,12 @@ impl MsgEncodable for Init { impl MsgDecodable for Ping { fn decode(v: &[u8]) -> Result { if v.len() < 4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ponglen = byte_utils::slice_to_be16(&v[0..2]); let byteslen = byte_utils::slice_to_be16(&v[2..4]); if v.len() < 4 + byteslen as usize { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } Ok(Self { ponglen, @@ -629,11 +633,11 @@ impl MsgEncodable for Ping { impl MsgDecodable for Pong { fn decode(v: &[u8]) -> Result { if v.len() < 2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let byteslen = byte_utils::slice_to_be16(&v[0..2]); if v.len() < 2 + byteslen as usize { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } Ok(Self { byteslen @@ -652,7 +656,7 @@ impl MsgEncodable for Pong { impl MsgDecodable for OpenChannel { fn decode(v: &[u8]) -> Result { if v.len() < 2*32+6*8+4+2*2+6*33+1 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ctx = Secp256k1::without_caps(); @@ -660,11 +664,9 @@ impl MsgDecodable for OpenChannel { if v.len() >= 321 { let len = byte_utils::slice_to_be16(&v[319..321]) as usize; if v.len() < 321+len { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } shutdown_scriptpubkey = Some(Script::from(v[321..321+len].to_vec())); - } else if v.len() != 2*32+6*8+4+2*2+6*33+1 { // Message cant have 1 extra byte - return Err(DecodeError::WrongLength); } Ok(OpenChannel { @@ -725,7 +727,7 @@ impl MsgEncodable for OpenChannel { impl MsgDecodable for AcceptChannel { fn decode(v: &[u8]) -> Result { if v.len() < 32+4*8+4+2*2+6*33 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ctx = Secp256k1::without_caps(); @@ -733,11 +735,9 @@ impl MsgDecodable for AcceptChannel { if v.len() >= 272 { let len = byte_utils::slice_to_be16(&v[270..272]) as usize; if v.len() < 272+len { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } shutdown_scriptpubkey = Some(Script::from(v[272..272+len].to_vec())); - } else if v.len() != 32+4*8+4+2*2+6*33 { // Message cant have 1 extra byte - return Err(DecodeError::WrongLength); } let mut temporary_channel_id = [0; 32]; @@ -792,7 +792,7 @@ impl MsgEncodable for AcceptChannel { impl MsgDecodable for FundingCreated { fn decode(v: &[u8]) -> Result { if v.len() < 32+32+2+64 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ctx = Secp256k1::without_caps(); let mut temporary_channel_id = [0; 32]; @@ -820,7 +820,7 @@ impl MsgEncodable for FundingCreated { impl MsgDecodable for FundingSigned { fn decode(v: &[u8]) -> Result { if v.len() < 32+64 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ctx = Secp256k1::without_caps(); let mut channel_id = [0; 32]; @@ -843,7 +843,7 @@ impl MsgEncodable for FundingSigned { impl MsgDecodable for FundingLocked { fn decode(v: &[u8]) -> Result { if v.len() < 32+33 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let ctx = Secp256k1::without_caps(); let mut channel_id = [0; 32]; @@ -866,11 +866,11 @@ impl MsgEncodable for FundingLocked { impl MsgDecodable for Shutdown { fn decode(v: &[u8]) -> Result { if v.len() < 32 + 2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let scriptlen = byte_utils::slice_to_be16(&v[32..34]) as usize; if v.len() < 32 + 2 + scriptlen { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -893,7 +893,7 @@ impl MsgEncodable for Shutdown { impl MsgDecodable for ClosingSigned { fn decode(v: &[u8]) -> Result { if v.len() < 32 + 8 + 64 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let secp_ctx = Secp256k1::without_caps(); let mut channel_id = [0; 32]; @@ -919,7 +919,7 @@ impl MsgEncodable for ClosingSigned { impl MsgDecodable for UpdateAddHTLC { fn decode(v: &[u8]) -> Result { if v.len() < 32+8+8+32+4+1+33+20*65+32 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -951,7 +951,7 @@ impl MsgEncodable for UpdateAddHTLC { impl MsgDecodable for UpdateFulfillHTLC { fn decode(v: &[u8]) -> Result { if v.len() < 32+8+32 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -977,7 +977,7 @@ impl MsgEncodable for UpdateFulfillHTLC { impl MsgDecodable for UpdateFailHTLC { fn decode(v: &[u8]) -> Result { if v.len() < 32+8 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -1002,7 +1002,7 @@ impl MsgEncodable for UpdateFailHTLC { impl MsgDecodable for UpdateFailMalformedHTLC { fn decode(v: &[u8]) -> Result { if v.len() < 32+8+32+2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -1030,14 +1030,14 @@ impl MsgEncodable for UpdateFailMalformedHTLC { impl MsgDecodable for CommitmentSigned { fn decode(v: &[u8]) -> Result { if v.len() < 32+64+2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); let htlcs = byte_utils::slice_to_be16(&v[96..98]) as usize; if v.len() < 32+64+2+htlcs*64 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut htlc_signatures = Vec::with_capacity(htlcs); let secp_ctx = Secp256k1::without_caps(); @@ -1068,7 +1068,7 @@ impl MsgEncodable for CommitmentSigned { impl MsgDecodable for RevokeAndACK { fn decode(v: &[u8]) -> Result { if v.len() < 32+32+33 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -1095,7 +1095,7 @@ impl MsgEncodable for RevokeAndACK { impl MsgDecodable for UpdateFee { fn decode(v: &[u8]) -> Result { if v.len() < 32+4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut channel_id = [0; 32]; channel_id[..].copy_from_slice(&v[0..32]); @@ -1117,12 +1117,12 @@ impl MsgEncodable for UpdateFee { impl MsgDecodable for ChannelReestablish { fn decode(v: &[u8]) -> Result { if v.len() < 32+2*8+33 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let your_last_per_commitment_secret = if v.len() > 32+2*8+33 { if v.len() < 32+2*8+33 + 32 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut inner_array = [0; 32]; inner_array.copy_from_slice(&v[48..48+32]); @@ -1165,7 +1165,7 @@ impl MsgEncodable for ChannelReestablish { impl MsgDecodable for AnnouncementSignatures { fn decode(v: &[u8]) -> Result { if v.len() < 32+8+64*2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let secp_ctx = Secp256k1::without_caps(); let mut channel_id = [0; 32]; @@ -1194,7 +1194,7 @@ impl MsgDecodable for UnsignedNodeAnnouncement { fn decode(v: &[u8]) -> Result { let features = GlobalFeatures::decode(&v[..])?; if v.len() < features.encoded_len() + 4 + 33 + 3 + 32 + 2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let start = features.encoded_len(); @@ -1206,22 +1206,23 @@ impl MsgDecodable for UnsignedNodeAnnouncement { let addrlen = byte_utils::slice_to_be16(&v[start + 72..start + 74]) as usize; if v.len() < start + 74 + addrlen { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } + let addr_read_limit = start + 74 + addrlen; let mut addresses = Vec::with_capacity(4); let mut read_pos = start + 74; loop { - if v.len() <= read_pos { break; } + if addr_read_limit <= read_pos { break; } match v[read_pos] { 0 => { read_pos += 1; }, 1 => { - if v.len() < read_pos + 1 + 6 { - return Err(DecodeError::WrongLength); - } if addresses.len() > 0 { return Err(DecodeError::ExtraAddressesPerType); } + if addr_read_limit < read_pos + 1 + 6 { + return Err(DecodeError::BadLengthDescriptor); + } let mut addr = [0; 4]; addr.copy_from_slice(&v[read_pos + 1..read_pos + 5]); addresses.push(NetAddress::IPv4 { @@ -1231,12 +1232,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement { read_pos += 1 + 6; }, 2 => { - if v.len() < read_pos + 1 + 18 { - return Err(DecodeError::WrongLength); - } if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) { return Err(DecodeError::ExtraAddressesPerType); } + if addr_read_limit < read_pos + 1 + 18 { + return Err(DecodeError::BadLengthDescriptor); + } let mut addr = [0; 16]; addr.copy_from_slice(&v[read_pos + 1..read_pos + 17]); addresses.push(NetAddress::IPv6 { @@ -1246,12 +1247,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement { read_pos += 1 + 18; }, 3 => { - if v.len() < read_pos + 1 + 12 { - return Err(DecodeError::WrongLength); - } if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) { return Err(DecodeError::ExtraAddressesPerType); } + if addr_read_limit < read_pos + 1 + 12 { + return Err(DecodeError::BadLengthDescriptor); + } let mut addr = [0; 10]; addr.copy_from_slice(&v[read_pos + 1..read_pos + 11]); addresses.push(NetAddress::OnionV2 { @@ -1261,12 +1262,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement { read_pos += 1 + 12; }, 4 => { - if v.len() < read_pos + 1 + 37 { - return Err(DecodeError::WrongLength); - } if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) { return Err(DecodeError::ExtraAddressesPerType); } + if addr_read_limit < read_pos + 1 + 37 { + return Err(DecodeError::BadLengthDescriptor); + } let mut ed25519_pubkey = [0; 32]; ed25519_pubkey.copy_from_slice(&v[read_pos + 1..read_pos + 33]); addresses.push(NetAddress::OnionV3 { @@ -1340,7 +1341,7 @@ impl MsgEncodable for UnsignedNodeAnnouncement { impl MsgDecodable for NodeAnnouncement { fn decode(v: &[u8]) -> Result { if v.len() < 64 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let secp_ctx = Secp256k1::without_caps(); Ok(Self { @@ -1364,7 +1365,7 @@ impl MsgDecodable for UnsignedChannelAnnouncement { fn decode(v: &[u8]) -> Result { let features = GlobalFeatures::decode(&v[..])?; if v.len() < features.encoded_len() + 32 + 8 + 33*4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let start = features.encoded_len(); let secp_ctx = Secp256k1::without_caps(); @@ -1397,7 +1398,7 @@ impl MsgEncodable for UnsignedChannelAnnouncement { impl MsgDecodable for ChannelAnnouncement { fn decode(v: &[u8]) -> Result { if v.len() < 64*4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let secp_ctx = Secp256k1::without_caps(); Ok(Self { @@ -1426,7 +1427,7 @@ impl MsgEncodable for ChannelAnnouncement { impl MsgDecodable for UnsignedChannelUpdate { fn decode(v: &[u8]) -> Result { if v.len() < 32+8+4+2+2+8+4+4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } Ok(Self { chain_hash: deserialize(&v[0..32]).unwrap(), @@ -1458,7 +1459,7 @@ impl MsgEncodable for UnsignedChannelUpdate { impl MsgDecodable for ChannelUpdate { fn decode(v: &[u8]) -> Result { if v.len() < 128 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let secp_ctx = Secp256k1::without_caps(); Ok(Self { @@ -1479,7 +1480,7 @@ impl MsgEncodable for ChannelUpdate { impl MsgDecodable for OnionRealm0HopData { fn decode(v: &[u8]) -> Result { if v.len() < 32 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } Ok(OnionRealm0HopData { short_channel_id: byte_utils::slice_to_be64(&v[0..8]), @@ -1502,7 +1503,7 @@ impl MsgEncodable for OnionRealm0HopData { impl MsgDecodable for OnionHopData { fn decode(v: &[u8]) -> Result { if v.len() < 65 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let realm = v[0]; if realm != 0 { @@ -1530,7 +1531,7 @@ impl MsgEncodable for OnionHopData { impl MsgDecodable for OnionPacket { fn decode(v: &[u8]) -> Result { if v.len() < 1+33+20*65+32 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut hop_data = [0; 20*65]; hop_data.copy_from_slice(&v[34..1334]); @@ -1559,15 +1560,15 @@ impl MsgEncodable for OnionPacket { impl MsgDecodable for DecodedOnionErrorPacket { fn decode(v: &[u8]) -> Result { if v.len() < 32 + 4 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let failuremsg_len = byte_utils::slice_to_be16(&v[32..34]) as usize; if v.len() < 32 + 4 + failuremsg_len { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let padding_len = byte_utils::slice_to_be16(&v[34 + failuremsg_len..]) as usize; if v.len() < 32 + 4 + failuremsg_len + padding_len { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let mut hmac = [0; 32]; @@ -1594,11 +1595,11 @@ impl MsgEncodable for DecodedOnionErrorPacket { impl MsgDecodable for OnionErrorPacket { fn decode(v: &[u8]) -> Result { if v.len() < 2 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let len = byte_utils::slice_to_be16(&v[0..2]) as usize; if v.len() < 2 + len { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } Ok(Self { data: v[2..len+2].to_vec(), @@ -1626,11 +1627,11 @@ impl MsgEncodable for ErrorMessage { impl MsgDecodable for ErrorMessage { fn decode(v: &[u8]) -> Result { if v.len() < 34 { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let len = byte_utils::slice_to_be16(&v[32..34]); if v.len() < 34 + len as usize { - return Err(DecodeError::WrongLength); + return Err(DecodeError::ShortRead); } let data = match String::from_utf8(v[34..34 + len as usize].to_vec()) { Ok(s) => s, diff --git a/src/ln/router.rs b/src/ln/router.rs index ad7ae6751..f30eb7912 100644 --- a/src/ln/router.rs +++ b/src/ln/router.rs @@ -428,39 +428,42 @@ impl Router { ( $chan_id: expr, $dest_node_id: expr, $directional_info: expr, $starting_fee_msat: expr ) => { //TODO: Explore simply adding fee to hit htlc_minimum_msat if $starting_fee_msat as u64 + final_value_msat > $directional_info.htlc_minimum_msat { - let new_fee = $directional_info.fee_base_msat as u64 + ($starting_fee_msat + final_value_msat) * ($directional_info.fee_proportional_millionths as u64) / 1000000; - let mut total_fee = $starting_fee_msat as u64; - let mut hm_entry = dist.entry(&$directional_info.src_node_id); - let old_entry = hm_entry.or_insert_with(|| { - let node = network.nodes.get(&$directional_info.src_node_id).unwrap(); - (u64::max_value(), - node.lowest_inbound_channel_fee_base_msat as u64, - node.lowest_inbound_channel_fee_proportional_millionths as u64, - RouteHop { - pubkey: PublicKey::new(), - short_channel_id: 0, - fee_msat: 0, - cltv_expiry_delta: 0, - }) - }); - if $directional_info.src_node_id != network.our_node_id { - // Ignore new_fee for channel-from-us as we assume all channels-from-us - // will have the same effective-fee - total_fee += new_fee; - total_fee += old_entry.2 * (final_value_msat + total_fee) / 1000000 + old_entry.1; - } - let new_graph_node = RouteGraphNode { - pubkey: $directional_info.src_node_id, - lowest_fee_to_peer_through_node: total_fee, - }; - if old_entry.0 > total_fee { - targets.push(new_graph_node); - old_entry.0 = total_fee; - old_entry.3 = RouteHop { - pubkey: $dest_node_id.clone(), - short_channel_id: $chan_id.clone(), - fee_msat: new_fee, // This field is ignored on the last-hop anyway - cltv_expiry_delta: $directional_info.cltv_expiry_delta as u32, + let proportional_fee_millions = ($starting_fee_msat + final_value_msat).checked_mul($directional_info.fee_proportional_millionths as u64); + if let Some(proportional_fee) = proportional_fee_millions { + let new_fee = $directional_info.fee_base_msat as u64 + proportional_fee / 1000000; + let mut total_fee = $starting_fee_msat as u64; + let mut hm_entry = dist.entry(&$directional_info.src_node_id); + let old_entry = hm_entry.or_insert_with(|| { + let node = network.nodes.get(&$directional_info.src_node_id).unwrap(); + (u64::max_value(), + node.lowest_inbound_channel_fee_base_msat as u64, + node.lowest_inbound_channel_fee_proportional_millionths as u64, + RouteHop { + pubkey: PublicKey::new(), + short_channel_id: 0, + fee_msat: 0, + cltv_expiry_delta: 0, + }) + }); + if $directional_info.src_node_id != network.our_node_id { + // Ignore new_fee for channel-from-us as we assume all channels-from-us + // will have the same effective-fee + total_fee += new_fee; + total_fee += old_entry.2 * (final_value_msat + total_fee) / 1000000 + old_entry.1; + } + let new_graph_node = RouteGraphNode { + pubkey: $directional_info.src_node_id, + lowest_fee_to_peer_through_node: total_fee, + }; + if old_entry.0 > total_fee { + targets.push(new_graph_node); + old_entry.0 = total_fee; + old_entry.3 = RouteHop { + pubkey: $dest_node_id.clone(), + short_channel_id: $chan_id.clone(), + fee_msat: new_fee, // This field is ignored on the last-hop anyway + cltv_expiry_delta: $directional_info.cltv_expiry_delta as u32, + } } } }