From 729f02d2dd542c72989ac1665d669a3602bb81c8 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Fri, 23 Mar 2018 16:57:22 -0400 Subject: [PATCH] Hold channel_state lock into fail_htlc_backwards_internal --- src/ln/channelmanager.rs | 62 +++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/src/ln/channelmanager.rs b/src/ln/channelmanager.rs index 519a6c20d..4e243935f 100644 --- a/src/ln/channelmanager.rs +++ b/src/ln/channelmanager.rs @@ -26,11 +26,10 @@ use crypto::digest::Digest; use crypto::symmetriccipher::SynchronousStreamCipher; use crypto::chacha20::ChaCha20; -use std::sync::{Mutex,Arc}; +use std::sync::{Mutex,MutexGuard,Arc}; use std::collections::HashMap; use std::collections::hash_map; -use std::ptr; -use std::mem; +use std::{ptr, mem}; use std::time::{Instant,Duration}; /// Stores the info we will need to send when we want to forward an HTLC onwards @@ -651,11 +650,10 @@ impl ChannelManager { /// Indicates that the preimage for payment_hash is unknown after a PaymentReceived event. pub fn fail_htlc_backwards(&self, payment_hash: &[u8; 32]) -> bool { - self.fail_htlc_backwards_internal(payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 }) + self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 }) } - fn fail_htlc_backwards_internal(&self, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool { - let mut channel_state = self.channel_state.lock().unwrap(); + fn fail_htlc_backwards_internal(&self, mut channel_state: MutexGuard, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool { let mut pending_htlc = { match channel_state.claimable_htlcs.remove(payment_hash) { Some(pending_htlc) => pending_htlc, @@ -674,6 +672,7 @@ impl ChannelManager { PendingOutboundHTLC::CycledRoute { .. } => { panic!("WAT"); }, PendingOutboundHTLC::OutboundRoute { .. } => { //TODO: DECRYPT route from OutboundRoute + mem::drop(channel_state); let mut pending_events = self.pending_events.lock().unwrap(); pending_events.push(events::Event::PaymentFailed { payment_hash: payment_hash.clone() @@ -707,6 +706,7 @@ impl ChannelManager { } }; + mem::drop(channel_state); let mut pending_events = self.pending_events.lock().unwrap(); pending_events.push(events::Event::SendFailHTLC { node_id, @@ -1217,36 +1217,34 @@ impl ChannelMessageHandler for ChannelManager { } fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result, msgs::CommitmentSigned)>, HandleError> { - let res = { - let mut channel_state = self.channel_state.lock().unwrap(); - match channel_state.by_id.get_mut(&msg.channel_id) { - Some(chan) => { - if chan.get_their_node_id() != *their_node_id { - return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None}) - } - chan.update_fail_htlc(&msg)? - }, - None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None}) - } - }; - self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::ErrorPacket { err: &msg.reason }); + let mut channel_state = self.channel_state.lock().unwrap(); + let res; + match channel_state.by_id.get_mut(&msg.channel_id) { + Some(chan) => { + if chan.get_their_node_id() != *their_node_id { + return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None}) + } + res = chan.update_fail_htlc(&msg)?; + }, + None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None}) + } + self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::ErrorPacket { err: &msg.reason }); Ok(res.1) } fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) -> Result, msgs::CommitmentSigned)>, HandleError> { - let res = { - let mut channel_state = self.channel_state.lock().unwrap(); - match channel_state.by_id.get_mut(&msg.channel_id) { - Some(chan) => { - if chan.get_their_node_id() != *their_node_id { - return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None}) - } - chan.update_fail_malformed_htlc(&msg)? - }, - None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None}) - } - }; - self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::Reason { failure_code: msg.failure_code }); + let mut channel_state = self.channel_state.lock().unwrap(); + let res; + match channel_state.by_id.get_mut(&msg.channel_id) { + Some(chan) => { + if chan.get_their_node_id() != *their_node_id { + return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None}) + } + res = chan.update_fail_malformed_htlc(&msg)?; + }, + None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None}) + } + self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::Reason { failure_code: msg.failure_code }); Ok(res.1) }