Hold channel_state lock into fail_htlc_backwards_internal

This commit is contained in:
Matt Corallo 2018-03-23 16:57:22 -04:00
parent d8fc9ffde9
commit 729f02d2dd

View File

@ -26,11 +26,10 @@ use crypto::digest::Digest;
use crypto::symmetriccipher::SynchronousStreamCipher;
use crypto::chacha20::ChaCha20;
use std::sync::{Mutex,Arc};
use std::sync::{Mutex,MutexGuard,Arc};
use std::collections::HashMap;
use std::collections::hash_map;
use std::ptr;
use std::mem;
use std::{ptr, mem};
use std::time::{Instant,Duration};
/// Stores the info we will need to send when we want to forward an HTLC onwards
@ -651,11 +650,10 @@ impl ChannelManager {
/// Indicates that the preimage for payment_hash is unknown after a PaymentReceived event.
pub fn fail_htlc_backwards(&self, payment_hash: &[u8; 32]) -> bool {
self.fail_htlc_backwards_internal(payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 })
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 })
}
fn fail_htlc_backwards_internal(&self, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool {
let mut channel_state = self.channel_state.lock().unwrap();
fn fail_htlc_backwards_internal(&self, mut channel_state: MutexGuard<ChannelHolder>, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool {
let mut pending_htlc = {
match channel_state.claimable_htlcs.remove(payment_hash) {
Some(pending_htlc) => pending_htlc,
@ -674,6 +672,7 @@ impl ChannelManager {
PendingOutboundHTLC::CycledRoute { .. } => { panic!("WAT"); },
PendingOutboundHTLC::OutboundRoute { .. } => {
//TODO: DECRYPT route from OutboundRoute
mem::drop(channel_state);
let mut pending_events = self.pending_events.lock().unwrap();
pending_events.push(events::Event::PaymentFailed {
payment_hash: payment_hash.clone()
@ -707,6 +706,7 @@ impl ChannelManager {
}
};
mem::drop(channel_state);
let mut pending_events = self.pending_events.lock().unwrap();
pending_events.push(events::Event::SendFailHTLC {
node_id,
@ -1217,36 +1217,34 @@ impl ChannelMessageHandler for ChannelManager {
}
fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result<Option<(Vec<msgs::UpdateAddHTLC>, msgs::CommitmentSigned)>, HandleError> {
let res = {
let mut channel_state = self.channel_state.lock().unwrap();
match channel_state.by_id.get_mut(&msg.channel_id) {
Some(chan) => {
if chan.get_their_node_id() != *their_node_id {
return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
}
chan.update_fail_htlc(&msg)?
},
None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
}
};
self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::ErrorPacket { err: &msg.reason });
let mut channel_state = self.channel_state.lock().unwrap();
let res;
match channel_state.by_id.get_mut(&msg.channel_id) {
Some(chan) => {
if chan.get_their_node_id() != *their_node_id {
return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
}
res = chan.update_fail_htlc(&msg)?;
},
None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
}
self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::ErrorPacket { err: &msg.reason });
Ok(res.1)
}
fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) -> Result<Option<(Vec<msgs::UpdateAddHTLC>, msgs::CommitmentSigned)>, HandleError> {
let res = {
let mut channel_state = self.channel_state.lock().unwrap();
match channel_state.by_id.get_mut(&msg.channel_id) {
Some(chan) => {
if chan.get_their_node_id() != *their_node_id {
return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
}
chan.update_fail_malformed_htlc(&msg)?
},
None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
}
};
self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::Reason { failure_code: msg.failure_code });
let mut channel_state = self.channel_state.lock().unwrap();
let res;
match channel_state.by_id.get_mut(&msg.channel_id) {
Some(chan) => {
if chan.get_their_node_id() != *their_node_id {
return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
}
res = chan.update_fail_malformed_htlc(&msg)?;
},
None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
}
self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::Reason { failure_code: msg.failure_code });
Ok(res.1)
}