mirror of
https://github.com/lightningdevkit/rust-lightning.git
synced 2025-01-19 05:43:55 +01:00
Fix DefaultRouter type restrained to only MutexGuard
Type of DerefMut for DefaultRouter was specialized to only MutexGuard. It should be generic around RefMut and MutexGuard. This commit fixes that
This commit is contained in:
parent
86fd9e7fbc
commit
54bcb6eb02
@ -885,7 +885,22 @@ mod tests {
|
||||
fn disconnect_socket(&mut self) {}
|
||||
}
|
||||
|
||||
type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
|
||||
type ChannelManager =
|
||||
channelmanager::ChannelManager<
|
||||
Arc<ChainMonitor>,
|
||||
Arc<test_utils::TestBroadcaster>,
|
||||
Arc<KeysManager>,
|
||||
Arc<KeysManager>,
|
||||
Arc<KeysManager>,
|
||||
Arc<test_utils::TestFeeEstimator>,
|
||||
Arc<DefaultRouter<
|
||||
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
|
||||
Arc<test_utils::TestLogger>,
|
||||
Arc<Mutex<TestScorer>>,
|
||||
(),
|
||||
TestScorer>
|
||||
>,
|
||||
Arc<test_utils::TestLogger>>;
|
||||
|
||||
type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;
|
||||
|
||||
|
@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
|
||||
/// of [`KeysManager`] and [`DefaultRouter`].
|
||||
///
|
||||
/// This is not exported to bindings users as Arcs don't make sense in bindings
|
||||
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
|
||||
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
|
||||
ChannelManager<
|
||||
&'a M,
|
||||
&'b T,
|
||||
&'c KeysManager,
|
||||
&'c KeysManager,
|
||||
&'c KeysManager,
|
||||
&'d F,
|
||||
&'e DefaultRouter<
|
||||
&'f NetworkGraph<&'g L>,
|
||||
&'g L,
|
||||
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
|
||||
ProbabilisticScoringFeeParameters,
|
||||
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
|
||||
>,
|
||||
&'g L
|
||||
>;
|
||||
|
||||
macro_rules! define_test_pub_trait { ($vis: vis) => {
|
||||
/// A trivial trait which describes any [`ChannelManager`] used in testing.
|
||||
|
@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;
|
||||
|
||||
use crate::io;
|
||||
use crate::prelude::*;
|
||||
use crate::sync::{Mutex, MutexGuard};
|
||||
use crate::sync::{Mutex};
|
||||
use alloc::collections::BinaryHeap;
|
||||
use core::{cmp, fmt};
|
||||
use core::ops::Deref;
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
||||
/// A [`Router`] implemented using [`find_route`].
|
||||
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
|
||||
L::Target: Logger,
|
||||
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
|
||||
S::Target: for <'a> LockableScore<'a, Score = Sc>,
|
||||
{
|
||||
network_graph: G,
|
||||
logger: L,
|
||||
@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,
|
||||
|
||||
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
|
||||
L::Target: Logger,
|
||||
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
|
||||
S::Target: for <'a> LockableScore<'a, Score = Sc>,
|
||||
{
|
||||
/// Creates a new router.
|
||||
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
|
||||
@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
|
||||
}
|
||||
}
|
||||
|
||||
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
|
||||
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
|
||||
L::Target: Logger,
|
||||
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
|
||||
S::Target: for <'a> LockableScore<'a, Score = Sc>,
|
||||
{
|
||||
fn find_route(
|
||||
&self,
|
||||
@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
|
||||
};
|
||||
find_route(
|
||||
payer, params, &self.network_graph, first_hops, &*self.logger,
|
||||
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
|
||||
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
|
||||
&self.score_params,
|
||||
&random_seed_bytes
|
||||
)
|
||||
@ -104,15 +104,15 @@ pub trait Router {
|
||||
/// [`find_route`].
|
||||
///
|
||||
/// [`Score`]: crate::routing::scoring::Score
|
||||
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
|
||||
scorer: S,
|
||||
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
|
||||
scorer: &'a mut S,
|
||||
// Maps a channel's short channel id and its direction to the liquidity used up.
|
||||
inflight_htlcs: &'a InFlightHtlcs,
|
||||
}
|
||||
|
||||
impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
|
||||
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
|
||||
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
|
||||
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
|
||||
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
|
||||
ScorerAccountingForInFlightHtlcs {
|
||||
scorer,
|
||||
inflight_htlcs
|
||||
@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
|
||||
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
|
||||
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
|
||||
}
|
||||
|
||||
impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
|
||||
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
|
||||
type ScoreParams = S::ScoreParams;
|
||||
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
|
||||
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
|
||||
|
@ -157,8 +157,11 @@ define_score!();
|
||||
///
|
||||
/// [`find_route`]: crate::routing::router::find_route
|
||||
pub trait LockableScore<'a> {
|
||||
/// The [`Score`] type.
|
||||
type Score: 'a + Score;
|
||||
|
||||
/// The locked [`Score`] type.
|
||||
type Locked: 'a + Score;
|
||||
type Locked: DerefMut<Target = Self::Score> + Sized;
|
||||
|
||||
/// Returns the locked scorer.
|
||||
fn lock(&'a self) -> Self::Locked;
|
||||
@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
|
||||
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
|
||||
/// This is not exported to bindings users
|
||||
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
|
||||
type Score = T;
|
||||
type Locked = MutexGuard<'a, T>;
|
||||
|
||||
fn lock(&'a self) -> MutexGuard<'a, T> {
|
||||
fn lock(&'a self) -> Self::Locked {
|
||||
Mutex::lock(self).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
|
||||
type Score = T;
|
||||
type Locked = RefMut<'a, T>;
|
||||
|
||||
fn lock(&'a self) -> RefMut<'a, T> {
|
||||
fn lock(&'a self) -> Self::Locked {
|
||||
self.borrow_mut()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
|
||||
pub struct MultiThreadedLockableScore<S: Score> {
|
||||
score: Mutex<S>,
|
||||
}
|
||||
#[cfg(c_bindings)]
|
||||
/// A locked `MultiThreadedLockableScore`.
|
||||
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
|
||||
type ScoreParams = <T as Score>::ScoreParams;
|
||||
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
|
||||
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
|
||||
}
|
||||
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
|
||||
self.0.payment_path_failed(path, short_channel_id)
|
||||
}
|
||||
fn payment_path_successful(&mut self, path: &Path) {
|
||||
self.0.payment_path_successful(path)
|
||||
}
|
||||
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
|
||||
self.0.probe_failed(path, short_channel_id)
|
||||
}
|
||||
fn probe_successful(&mut self, path: &Path) {
|
||||
self.0.probe_successful(path)
|
||||
}
|
||||
}
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
|
||||
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
|
||||
self.0.write(writer)
|
||||
}
|
||||
pub struct MultiThreadedLockableScore<T: Score> {
|
||||
score: Mutex<T>,
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
|
||||
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
|
||||
type Score = T;
|
||||
type Locked = MultiThreadedScoreLock<'a, T>;
|
||||
|
||||
fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
|
||||
fn lock(&'a self) -> Self::Locked {
|
||||
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
|
||||
}
|
||||
}
|
||||
@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
|
||||
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<T: Score> MultiThreadedLockableScore<T> {
|
||||
@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
/// A locked `MultiThreadedLockableScore`.
|
||||
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
|
||||
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
|
||||
self.0.write(writer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(c_bindings)]
|
||||
/// This is not exported to bindings users
|
||||
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
|
||||
|
@ -51,6 +51,7 @@ use regex;
|
||||
use crate::io;
|
||||
use crate::prelude::*;
|
||||
use core::cell::RefCell;
|
||||
use core::ops::DerefMut;
|
||||
use core::time::Duration;
|
||||
use crate::sync::{Mutex, Arc};
|
||||
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
|
||||
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
|
||||
assert_eq!(find_route_query, *params);
|
||||
if let Ok(ref route) = find_route_res {
|
||||
let locked_scorer = self.scorer.lock().unwrap();
|
||||
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
|
||||
let mut binding = self.scorer.lock().unwrap();
|
||||
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
|
||||
for path in &route.paths {
|
||||
let mut aggregate_msat = 0u64;
|
||||
for (idx, hop) in path.hops.iter().rev().enumerate() {
|
||||
@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
|
||||
return find_route_res;
|
||||
}
|
||||
let logger = TestLogger::new();
|
||||
let scorer = self.scorer.lock().unwrap();
|
||||
find_route(
|
||||
payer, params, &self.network_graph, first_hops, &logger,
|
||||
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
|
||||
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
|
||||
&[42; 32]
|
||||
)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user