Fix DefaultRouter type restrained to only MutexGuard

Type of DerefMut for DefaultRouter was specialized to only MutexGuard.
It should be generic around RefMut and MutexGuard. This commit fixes that
This commit is contained in:
henghonglee 2023-06-29 10:41:38 +08:00
parent 86fd9e7fbc
commit 54bcb6eb02
5 changed files with 91 additions and 55 deletions

View File

@ -885,7 +885,22 @@ mod tests {
fn disconnect_socket(&mut self) {}
}
type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
type ChannelManager =
channelmanager::ChannelManager<
Arc<ChainMonitor>,
Arc<test_utils::TestBroadcaster>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<test_utils::TestFeeEstimator>,
Arc<DefaultRouter<
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
Arc<test_utils::TestLogger>,
Arc<Mutex<TestScorer>>,
(),
TestScorer>
>,
Arc<test_utils::TestLogger>>;
type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;

View File

@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
/// of [`KeysManager`] and [`DefaultRouter`].
///
/// This is not exported to bindings users as Arcs don't make sense in bindings
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
ChannelManager<
&'a M,
&'b T,
&'c KeysManager,
&'c KeysManager,
&'c KeysManager,
&'d F,
&'e DefaultRouter<
&'f NetworkGraph<&'g L>,
&'g L,
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
ProbabilisticScoringFeeParameters,
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
>,
&'g L
>;
macro_rules! define_test_pub_trait { ($vis: vis) => {
/// A trivial trait which describes any [`ChannelManager`] used in testing.

View File

@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;
use crate::io;
use crate::prelude::*;
use crate::sync::{Mutex, MutexGuard};
use crate::sync::{Mutex};
use alloc::collections::BinaryHeap;
use core::{cmp, fmt};
use core::ops::Deref;
use core::ops::{Deref, DerefMut};
/// A [`Router`] implemented using [`find_route`].
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
network_graph: G,
logger: L,
@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
/// Creates a new router.
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
}
}
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
fn find_route(
&self,
@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
};
find_route(
payer, params, &self.network_graph, first_hops, &*self.logger,
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
&self.score_params,
&random_seed_bytes
)
@ -104,15 +104,15 @@ pub trait Router {
/// [`find_route`].
///
/// [`Score`]: crate::routing::scoring::Score
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
scorer: S,
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
scorer: &'a mut S,
// Maps a channel's short channel id and its direction to the liquidity used up.
inflight_htlcs: &'a InFlightHtlcs,
}
impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
ScorerAccountingForInFlightHtlcs {
scorer,
inflight_htlcs
@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
}
#[cfg(c_bindings)]
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
}
impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
type ScoreParams = S::ScoreParams;
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(

View File

@ -157,8 +157,11 @@ define_score!();
///
/// [`find_route`]: crate::routing::router::find_route
pub trait LockableScore<'a> {
/// The [`Score`] type.
type Score: 'a + Score;
/// The locked [`Score`] type.
type Locked: 'a + Score;
type Locked: DerefMut<Target = Self::Score> + Sized;
/// Returns the locked scorer.
fn lock(&'a self) -> Self::Locked;
@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
/// This is not exported to bindings users
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
type Score = T;
type Locked = MutexGuard<'a, T>;
fn lock(&'a self) -> MutexGuard<'a, T> {
fn lock(&'a self) -> Self::Locked {
Mutex::lock(self).unwrap()
}
}
impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
type Score = T;
type Locked = RefMut<'a, T>;
fn lock(&'a self) -> RefMut<'a, T> {
fn lock(&'a self) -> Self::Locked {
self.borrow_mut()
}
}
#[cfg(c_bindings)]
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
pub struct MultiThreadedLockableScore<S: Score> {
score: Mutex<S>,
}
#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
type ScoreParams = <T as Score>::ScoreParams;
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
}
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.payment_path_failed(path, short_channel_id)
}
fn payment_path_successful(&mut self, path: &Path) {
self.0.payment_path_successful(path)
}
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.probe_failed(path, short_channel_id)
}
fn probe_successful(&mut self, path: &Path) {
self.0.probe_successful(path)
}
}
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
pub struct MultiThreadedLockableScore<T: Score> {
score: Mutex<T>,
}
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
type Score = T;
type Locked = MultiThreadedScoreLock<'a, T>;
fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
fn lock(&'a self) -> Self::Locked {
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
}
}
@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
}
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
}
}
#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);
#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
}
#[cfg(c_bindings)]
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
}
#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
#[cfg(c_bindings)]
/// This is not exported to bindings users
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {

View File

@ -51,6 +51,7 @@ use regex;
use crate::io;
use crate::prelude::*;
use core::cell::RefCell;
use core::ops::DerefMut;
use core::time::Duration;
use crate::sync::{Mutex, Arc};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
assert_eq!(find_route_query, *params);
if let Ok(ref route) = find_route_res {
let locked_scorer = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
let mut binding = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
for path in &route.paths {
let mut aggregate_msat = 0u64;
for (idx, hop) in path.hops.iter().rev().enumerate() {
@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
return find_route_res;
}
let logger = TestLogger::new();
let scorer = self.scorer.lock().unwrap();
find_route(
payer, params, &self.network_graph, first_hops, &logger,
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
&[42; 32]
)
}