Add WithContext and Tests

This commit is contained in:
henghonglee 2023-09-05 02:37:39 +08:00 committed by Jeffrey Czyz
parent a42aeb5667
commit 973636bd2a
No known key found for this signature in database
GPG key ID: 3A4E08275D5E96D2
3 changed files with 90 additions and 4 deletions

View file

@ -42,7 +42,7 @@ use crate::sign::{EntropySource, ChannelSigner, SignerProvider, NodeSigner, Reci
use crate::events::ClosureReason; use crate::events::ClosureReason;
use crate::routing::gossip::NodeId; use crate::routing::gossip::NodeId;
use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer};
use crate::util::logger::Logger; use crate::util::logger::{Logger, WithContext};
use crate::util::errors::APIError; use crate::util::errors::APIError;
use crate::util::config::{UserConfig, ChannelConfig, LegacyChannelConfig, ChannelHandshakeConfig, ChannelHandshakeLimits, MaxDustHTLCExposure}; use crate::util::config::{UserConfig, ChannelConfig, LegacyChannelConfig, ChannelHandshakeConfig, ChannelHandshakeLimits, MaxDustHTLCExposure};
use crate::util::scid_utils::scid_from_parts; use crate::util::scid_utils::scid_from_parts;
@ -6463,6 +6463,7 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
F::Target: FeeEstimator, F::Target: FeeEstimator,
L::Target: Logger, L::Target: Logger,
{ {
let logger = WithContext::from(logger, Some(counterparty_node_id), Some(msg.temporary_channel_id));
let announced_channel = if (msg.channel_flags & 1) == 1 { true } else { false }; let announced_channel = if (msg.channel_flags & 1) == 1 { true } else { false };
// First check the channel type is known, failing before we do anything else if we don't // First check the channel type is known, failing before we do anything else if we don't
@ -6529,7 +6530,7 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
if msg.htlc_minimum_msat >= full_channel_value_msat { if msg.htlc_minimum_msat >= full_channel_value_msat {
return Err(ChannelError::Close(format!("Minimum htlc value ({}) was larger than full channel value ({})", msg.htlc_minimum_msat, full_channel_value_msat))); return Err(ChannelError::Close(format!("Minimum htlc value ({}) was larger than full channel value ({})", msg.htlc_minimum_msat, full_channel_value_msat)));
} }
Channel::<SP>::check_remote_fee(&channel_type, fee_estimator, msg.feerate_per_kw, None, logger)?; Channel::<SP>::check_remote_fee(&channel_type, fee_estimator, msg.feerate_per_kw, None, &&logger)?;
let max_counterparty_selected_contest_delay = u16::min(config.channel_handshake_limits.their_to_self_delay, MAX_LOCAL_BREAKDOWN_TIMEOUT); let max_counterparty_selected_contest_delay = u16::min(config.channel_handshake_limits.their_to_self_delay, MAX_LOCAL_BREAKDOWN_TIMEOUT);
if msg.to_self_delay > max_counterparty_selected_contest_delay { if msg.to_self_delay > max_counterparty_selected_contest_delay {

View file

@ -18,6 +18,7 @@ use bitcoin::secp256k1::PublicKey;
use core::cmp; use core::cmp;
use core::fmt; use core::fmt;
use core::ops::Deref;
use crate::ln::ChannelId; use crate::ln::ChannelId;
#[cfg(c_bindings)] #[cfg(c_bindings)]
@ -152,6 +153,39 @@ pub trait Logger {
fn log(&self, record: Record); fn log(&self, record: Record);
} }
/// Adds relevant context to a [`Record`] before passing it to the wrapped [`Logger`].
pub struct WithContext<'a, L: Deref> where L::Target: Logger {
/// The logger to delegate to after adding context to the record.
logger: &'a L,
/// The node id of the peer pertaining to the logged record.
peer_id: Option<PublicKey>,
/// The channel id of the channel pertaining to the logged record.
channel_id: Option<ChannelId>,
}
impl<'a, L: Deref> Logger for WithContext<'a, L> where L::Target: Logger {
fn log(&self, mut record: Record) {
if self.peer_id.is_some() {
record.peer_id = self.peer_id
};
if self.channel_id.is_some() {
record.channel_id = self.channel_id;
}
self.logger.log(record)
}
}
impl<'a, L: Deref> WithContext<'a, L> where L::Target: Logger {
/// Wraps the given logger, providing additional context to any logged records.
pub fn from(logger: &'a L, peer_id: Option<PublicKey>, channel_id: Option<ChannelId>) -> Self {
WithContext {
logger,
peer_id,
channel_id,
}
}
}
/// Wrapper for logging a [`PublicKey`] in hex format. /// Wrapper for logging a [`PublicKey`] in hex format.
/// ///
/// This is not exported to bindings users as fmt can't be used in C /// This is not exported to bindings users as fmt can't be used in C
@ -202,7 +236,9 @@ impl<T: fmt::Display, I: core::iter::Iterator<Item = T> + Clone> fmt::Display fo
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::util::logger::{Logger, Level}; use bitcoin::secp256k1::{PublicKey, SecretKey, Secp256k1};
use crate::ln::ChannelId;
use crate::util::logger::{Logger, Level, WithContext};
use crate::util::test_utils::TestLogger; use crate::util::test_utils::TestLogger;
use crate::sync::Arc; use crate::sync::Arc;
@ -243,6 +279,41 @@ mod tests {
wrapper.call_macros(); wrapper.call_macros();
} }
#[test]
fn test_logging_with_context() {
let logger = &TestLogger::new();
let secp_ctx = Secp256k1::new();
let pk = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
let context_logger = WithContext::from(&logger, Some(pk), Some(ChannelId([0; 32])));
log_error!(context_logger, "This is an error");
log_warn!(context_logger, "This is an error");
log_debug!(context_logger, "This is an error");
log_trace!(context_logger, "This is an error");
log_gossip!(context_logger, "This is an error");
log_info!(context_logger, "This is an error");
logger.assert_log_context_contains(
"lightning::util::logger::tests", Some(pk), Some(ChannelId([0;32])), 6
);
}
#[test]
fn test_logging_with_multiple_wrapped_context() {
let logger = &TestLogger::new();
let secp_ctx = Secp256k1::new();
let pk = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
let context_logger = &WithContext::from(&logger, None, Some(ChannelId([0; 32])));
let full_context_logger = WithContext::from(&context_logger, Some(pk), None);
log_error!(full_context_logger, "This is an error");
log_warn!(full_context_logger, "This is an error");
log_debug!(full_context_logger, "This is an error");
log_trace!(full_context_logger, "This is an error");
log_gossip!(full_context_logger, "This is an error");
log_info!(full_context_logger, "This is an error");
logger.assert_log_context_contains(
"lightning::util::logger::tests", Some(pk), Some(ChannelId([0;32])), 6
);
}
#[test] #[test]
fn test_log_ordering() { fn test_log_ordering() {
assert!(Level::Error > Level::Warn); assert!(Level::Error > Level::Warn);

View file

@ -931,6 +931,7 @@ pub struct TestLogger {
level: Level, level: Level,
pub(crate) id: String, pub(crate) id: String,
pub lines: Mutex<HashMap<(String, String), usize>>, pub lines: Mutex<HashMap<(String, String), usize>>,
pub context: Mutex<HashMap<(String, Option<PublicKey>, Option<ChannelId>), usize>>,
} }
impl TestLogger { impl TestLogger {
@ -941,7 +942,8 @@ impl TestLogger {
TestLogger { TestLogger {
level: Level::Trace, level: Level::Trace,
id, id,
lines: Mutex::new(HashMap::new()) lines: Mutex::new(HashMap::new()),
context: Mutex::new(HashMap::new()),
} }
} }
pub fn enable(&mut self, level: Level) { pub fn enable(&mut self, level: Level) {
@ -976,11 +978,23 @@ impl TestLogger {
}).map(|(_, c) | { c }).sum(); }).map(|(_, c) | { c }).sum();
assert_eq!(l, count) assert_eq!(l, count)
} }
pub fn assert_log_context_contains(
&self, module: &str, peer_id: Option<PublicKey>, channel_id: Option<ChannelId>, count: usize
) {
let context_entries = self.context.lock().unwrap();
let l: usize = context_entries.iter()
.filter(|&(&(ref m, ref p, ref c), _)| m == module && *p == peer_id && *c == channel_id)
.map(|(_, c) | c)
.sum();
assert_eq!(l, count)
}
} }
impl Logger for TestLogger { impl Logger for TestLogger {
fn log(&self, record: Record) { fn log(&self, record: Record) {
*self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1; *self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1;
*self.context.lock().unwrap().entry((record.module_path.to_string(), record.peer_id, record.channel_id)).or_insert(0) += 1;
if record.level >= self.level { if record.level >= self.level {
#[cfg(all(not(ldk_bench), feature = "std"))] { #[cfg(all(not(ldk_bench), feature = "std"))] {
let pfx = format!("{} {} [{}:{}]", self.id, record.level.to_string(), record.module_path, record.line); let pfx = format!("{} {} [{}:{}]", self.id, record.level.to_string(), record.module_path, record.line);