mirror of
https://github.com/lightningdevkit/rust-lightning.git
synced 2025-01-19 05:43:55 +01:00
Add intermediate ConstructedTransaction
This commit is contained in:
parent
a04dde7664
commit
59a8bd5d65
@ -15,7 +15,8 @@ use bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
|
||||
use bitcoin::consensus::Encodable;
|
||||
use bitcoin::policy::MAX_STANDARD_TX_WEIGHT;
|
||||
use bitcoin::{
|
||||
absolute::LockTime as AbsoluteLockTime, OutPoint, ScriptBuf, Sequence, Transaction, TxIn, TxOut,
|
||||
absolute::LockTime as AbsoluteLockTime, OutPoint, ScriptBuf, Sequence, Transaction, TxIn,
|
||||
TxOut, Weight,
|
||||
};
|
||||
|
||||
use crate::chain::chaininterface::fee_for_weight;
|
||||
@ -77,7 +78,7 @@ impl SerialIdExt for SerialId {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AbortReason {
|
||||
pub(crate) enum AbortReason {
|
||||
InvalidStateTransition,
|
||||
UnexpectedCounterpartyMessage,
|
||||
ReceivedTooManyTxAddInputs,
|
||||
@ -97,53 +98,183 @@ pub enum AbortReason {
|
||||
InvalidTx,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TxInputWithPrevOutput {
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct InteractiveTxInput {
|
||||
serial_id: SerialId,
|
||||
input: TxIn,
|
||||
prev_output: TxOut,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct InteractiveTxOutput {
|
||||
serial_id: SerialId,
|
||||
tx_out: TxOut,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct ConstructedTransaction {
|
||||
holder_is_initiator: bool,
|
||||
|
||||
inputs: Vec<InteractiveTxInput>,
|
||||
outputs: Vec<InteractiveTxOutput>,
|
||||
|
||||
local_inputs_value_satoshis: u64,
|
||||
local_outputs_value_satoshis: u64,
|
||||
|
||||
remote_inputs_value_satoshis: u64,
|
||||
remote_outputs_value_satoshis: u64,
|
||||
|
||||
lock_time: AbsoluteLockTime,
|
||||
}
|
||||
|
||||
impl ConstructedTransaction {
|
||||
fn new(context: NegotiationContext) -> Self {
|
||||
let local_inputs_value_satoshis = context
|
||||
.inputs
|
||||
.iter()
|
||||
.filter(|(serial_id, _)| {
|
||||
!is_serial_id_valid_for_counterparty(context.holder_is_initiator, serial_id)
|
||||
})
|
||||
.fold(0u64, |value, (_, input)| value.saturating_add(input.prev_output.value));
|
||||
|
||||
let local_outputs_value_satoshis = context
|
||||
.outputs
|
||||
.iter()
|
||||
.filter(|(serial_id, _)| {
|
||||
!is_serial_id_valid_for_counterparty(context.holder_is_initiator, serial_id)
|
||||
})
|
||||
.fold(0u64, |value, (_, output)| value.saturating_add(output.tx_out.value));
|
||||
|
||||
Self {
|
||||
holder_is_initiator: context.holder_is_initiator,
|
||||
|
||||
local_inputs_value_satoshis,
|
||||
local_outputs_value_satoshis,
|
||||
|
||||
remote_inputs_value_satoshis: context.remote_inputs_value(),
|
||||
remote_outputs_value_satoshis: context.remote_outputs_value(),
|
||||
|
||||
inputs: context.inputs.into_values().collect(),
|
||||
outputs: context.outputs.into_values().collect(),
|
||||
|
||||
lock_time: context.tx_locktime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn weight(&self) -> Weight {
|
||||
let inputs_weight = self.inputs.iter().fold(
|
||||
Weight::from_wu(0),
|
||||
|weight, InteractiveTxInput { prev_output, .. }| {
|
||||
weight.checked_add(estimate_input_weight(prev_output)).unwrap_or(Weight::MAX)
|
||||
},
|
||||
);
|
||||
let outputs_weight = self.outputs.iter().fold(
|
||||
Weight::from_wu(0),
|
||||
|weight, InteractiveTxOutput { tx_out, .. }| {
|
||||
weight.checked_add(get_output_weight(&tx_out.script_pubkey)).unwrap_or(Weight::MAX)
|
||||
},
|
||||
);
|
||||
Weight::from_wu(TX_COMMON_FIELDS_WEIGHT)
|
||||
.checked_add(inputs_weight)
|
||||
.and_then(|weight| weight.checked_add(outputs_weight))
|
||||
.unwrap_or(Weight::MAX)
|
||||
}
|
||||
|
||||
pub fn into_unsigned_tx(self) -> Transaction {
|
||||
// Inputs and outputs must be sorted by serial_id
|
||||
let ConstructedTransaction { mut inputs, mut outputs, .. } = self;
|
||||
|
||||
inputs.sort_unstable_by_key(|InteractiveTxInput { serial_id, .. }| *serial_id);
|
||||
outputs.sort_unstable_by_key(|InteractiveTxOutput { serial_id, .. }| *serial_id);
|
||||
|
||||
let input: Vec<TxIn> =
|
||||
inputs.into_iter().map(|InteractiveTxInput { input, .. }| input).collect();
|
||||
let output: Vec<TxOut> =
|
||||
outputs.into_iter().map(|InteractiveTxOutput { tx_out, .. }| tx_out).collect();
|
||||
|
||||
Transaction { version: 2, lock_time: self.lock_time, input, output }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NegotiationContext {
|
||||
holder_is_initiator: bool,
|
||||
received_tx_add_input_count: u16,
|
||||
received_tx_add_output_count: u16,
|
||||
inputs: HashMap<SerialId, TxInputWithPrevOutput>,
|
||||
inputs: HashMap<SerialId, InteractiveTxInput>,
|
||||
prevtx_outpoints: HashSet<OutPoint>,
|
||||
outputs: HashMap<SerialId, TxOut>,
|
||||
outputs: HashMap<SerialId, InteractiveTxOutput>,
|
||||
tx_locktime: AbsoluteLockTime,
|
||||
feerate_sat_per_kw: u32,
|
||||
}
|
||||
|
||||
pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> u64 {
|
||||
(8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
|
||||
* WITNESS_SCALE_FACTOR as u64
|
||||
pub(crate) fn estimate_input_weight(prev_output: &TxOut) -> Weight {
|
||||
Weight::from_wu(if prev_output.script_pubkey.is_v0_p2wpkh() {
|
||||
P2WPKH_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else if prev_output.script_pubkey.is_v0_p2wsh() {
|
||||
P2WSH_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else if prev_output.script_pubkey.is_v1_p2tr() {
|
||||
P2TR_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else {
|
||||
UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> Weight {
|
||||
Weight::from_wu(
|
||||
(8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
|
||||
* WITNESS_SCALE_FACTOR as u64,
|
||||
)
|
||||
}
|
||||
|
||||
fn is_serial_id_valid_for_counterparty(holder_is_initiator: bool, serial_id: &SerialId) -> bool {
|
||||
// A received `SerialId`'s parity must match the role of the counterparty.
|
||||
holder_is_initiator == serial_id.is_for_non_initiator()
|
||||
}
|
||||
|
||||
impl NegotiationContext {
|
||||
fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
|
||||
// A received `SerialId`'s parity must match the role of the counterparty.
|
||||
self.holder_is_initiator == serial_id.is_for_non_initiator()
|
||||
is_serial_id_valid_for_counterparty(self.holder_is_initiator, serial_id)
|
||||
}
|
||||
|
||||
fn total_input_and_output_count(&self) -> usize {
|
||||
self.inputs.len().saturating_add(self.outputs.len())
|
||||
}
|
||||
|
||||
fn counterparty_inputs_contributed(
|
||||
&self,
|
||||
) -> impl Iterator<Item = &TxInputWithPrevOutput> + Clone {
|
||||
fn remote_inputs_value(&self) -> u64 {
|
||||
self.inputs
|
||||
.iter()
|
||||
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.map(|(_, input_with_prevout)| input_with_prevout)
|
||||
.filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.fold(0u64, |acc, (_, InteractiveTxInput { prev_output, .. })| {
|
||||
acc.saturating_add(prev_output.value)
|
||||
})
|
||||
}
|
||||
|
||||
fn counterparty_outputs_contributed(&self) -> impl Iterator<Item = &TxOut> + Clone {
|
||||
fn remote_outputs_value(&self) -> u64 {
|
||||
self.outputs
|
||||
.iter()
|
||||
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.map(|(_, output)| output)
|
||||
.filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.fold(0u64, |acc, (_, InteractiveTxOutput { tx_out, .. })| {
|
||||
acc.saturating_add(tx_out.value)
|
||||
})
|
||||
}
|
||||
|
||||
fn remote_inputs_weight(&self) -> Weight {
|
||||
Weight::from_wu(
|
||||
self.inputs
|
||||
.iter()
|
||||
.filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.fold(0u64, |weight, (_, InteractiveTxInput { prev_output, .. })| {
|
||||
weight.saturating_add(estimate_input_weight(prev_output).to_wu())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_outputs_weight(&self) -> Weight {
|
||||
Weight::from_wu(
|
||||
self.outputs
|
||||
.iter()
|
||||
.filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
|
||||
.fold(0u64, |weight, (_, InteractiveTxOutput { tx_out, .. })| {
|
||||
weight.saturating_add(get_output_weight(&tx_out.script_pubkey).to_wu())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
|
||||
@ -213,7 +344,8 @@ impl NegotiationContext {
|
||||
},
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
|
||||
entry.insert(TxInputWithPrevOutput {
|
||||
entry.insert(InteractiveTxInput {
|
||||
serial_id: msg.serial_id,
|
||||
input: TxIn {
|
||||
previous_output: prev_outpoint,
|
||||
sequence: Sequence(msg.sequence),
|
||||
@ -269,7 +401,7 @@ impl NegotiationContext {
|
||||
// bitcoin supply.
|
||||
let mut outputs_value: u64 = 0;
|
||||
for output in self.outputs.iter() {
|
||||
outputs_value = outputs_value.saturating_add(output.1.value);
|
||||
outputs_value = outputs_value.saturating_add(output.1.tx_out.value);
|
||||
}
|
||||
if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
|
||||
// The receiving node:
|
||||
@ -306,7 +438,10 @@ impl NegotiationContext {
|
||||
Err(AbortReason::DuplicateSerialId)
|
||||
},
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
|
||||
entry.insert(InteractiveTxOutput {
|
||||
serial_id: msg.serial_id,
|
||||
tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() },
|
||||
});
|
||||
Ok(())
|
||||
},
|
||||
}
|
||||
@ -340,13 +475,21 @@ impl NegotiationContext {
|
||||
// We have added an input that already exists
|
||||
return Err(AbortReason::PrevTxOutInvalid);
|
||||
}
|
||||
self.inputs.insert(msg.serial_id, TxInputWithPrevOutput { input, prev_output });
|
||||
self.inputs.insert(
|
||||
msg.serial_id,
|
||||
InteractiveTxInput { serial_id: msg.serial_id, input, prev_output },
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
|
||||
self.outputs
|
||||
.insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
|
||||
self.outputs.insert(
|
||||
msg.serial_id,
|
||||
InteractiveTxOutput {
|
||||
serial_id: msg.serial_id,
|
||||
tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() },
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -361,31 +504,12 @@ impl NegotiationContext {
|
||||
}
|
||||
|
||||
fn check_counterparty_fees(
|
||||
&self, counterparty_inputs_value: u64, counterparty_outputs_value: u64,
|
||||
&self, counterparty_fees_contributed: u64,
|
||||
) -> Result<(), AbortReason> {
|
||||
let mut counterparty_weight_contributed: u64 = self
|
||||
.counterparty_outputs_contributed()
|
||||
.map(|output| get_output_weight(&output.script_pubkey))
|
||||
.sum();
|
||||
// We don't know the counterparty's witnesses ahead of time obviously, so we use the lower bounds
|
||||
// specified in BOLT 3.
|
||||
let mut total_inputs_weight: u64 = 0;
|
||||
for TxInputWithPrevOutput { prev_output, .. } in self.counterparty_inputs_contributed() {
|
||||
total_inputs_weight =
|
||||
total_inputs_weight.saturating_add(if prev_output.script_pubkey.is_v0_p2wpkh() {
|
||||
P2WPKH_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else if prev_output.script_pubkey.is_v0_p2wsh() {
|
||||
P2WSH_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else if prev_output.script_pubkey.is_v1_p2tr() {
|
||||
P2TR_INPUT_WEIGHT_LOWER_BOUND
|
||||
} else {
|
||||
UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND
|
||||
});
|
||||
}
|
||||
counterparty_weight_contributed =
|
||||
counterparty_weight_contributed.saturating_add(total_inputs_weight);
|
||||
let counterparty_fees_contributed =
|
||||
counterparty_inputs_value.saturating_sub(counterparty_outputs_value);
|
||||
let counterparty_weight_contributed = self
|
||||
.remote_inputs_weight()
|
||||
.to_wu()
|
||||
.saturating_add(self.remote_outputs_weight().to_wu());
|
||||
let mut required_counterparty_contribution_fee =
|
||||
fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed);
|
||||
if !self.holder_is_initiator {
|
||||
@ -402,21 +526,14 @@ impl NegotiationContext {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_transaction(self) -> Result<Transaction, AbortReason> {
|
||||
fn validate_tx(self) -> Result<ConstructedTransaction, AbortReason> {
|
||||
// The receiving node:
|
||||
// MUST fail the negotiation if:
|
||||
|
||||
// - the peer's total input satoshis is less than their outputs
|
||||
let mut counterparty_inputs_value: u64 = 0;
|
||||
let mut counterparty_outputs_value: u64 = 0;
|
||||
for input in self.counterparty_inputs_contributed() {
|
||||
counterparty_inputs_value =
|
||||
counterparty_inputs_value.saturating_add(input.prev_output.value);
|
||||
}
|
||||
for output in self.counterparty_outputs_contributed() {
|
||||
counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value);
|
||||
}
|
||||
if counterparty_inputs_value < counterparty_outputs_value {
|
||||
let remote_inputs_value = self.remote_inputs_value();
|
||||
let remote_outputs_value = self.remote_outputs_value();
|
||||
if remote_inputs_value < remote_outputs_value {
|
||||
return Err(AbortReason::OutputsValueExceedsInputsValue);
|
||||
}
|
||||
|
||||
@ -429,25 +546,15 @@ impl NegotiationContext {
|
||||
}
|
||||
|
||||
// - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
|
||||
self.check_counterparty_fees(counterparty_inputs_value, counterparty_outputs_value)?;
|
||||
self.check_counterparty_fees(remote_inputs_value.saturating_sub(remote_outputs_value))?;
|
||||
|
||||
// Inputs and outputs must be sorted by serial_id
|
||||
let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
|
||||
let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
|
||||
inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
|
||||
outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
|
||||
let constructed_tx = ConstructedTransaction::new(self);
|
||||
|
||||
let tx_to_validate = Transaction {
|
||||
version: 2,
|
||||
lock_time: self.tx_locktime,
|
||||
input: inputs.into_iter().map(|(_, input)| input.input).collect(),
|
||||
output: outputs.into_iter().map(|(_, output)| output).collect(),
|
||||
};
|
||||
if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
|
||||
if constructed_tx.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
|
||||
return Err(AbortReason::TransactionTooLarge);
|
||||
}
|
||||
|
||||
Ok(tx_to_validate)
|
||||
Ok(constructed_tx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -535,7 +642,7 @@ define_state!(
|
||||
ReceivedTxComplete,
|
||||
"We have received a `tx_complete` message and the counterparty is awaiting ours."
|
||||
);
|
||||
define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
|
||||
define_state!(NegotiationComplete, ConstructedTransaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
|
||||
define_state!(
|
||||
NegotiationAborted,
|
||||
AbortReason,
|
||||
@ -577,7 +684,7 @@ macro_rules! define_state_transitions {
|
||||
impl StateTransition<NegotiationComplete, &msgs::TxComplete> for $tx_complete_state {
|
||||
fn transition(self, _data: &msgs::TxComplete) -> StateTransitionResult<NegotiationComplete> {
|
||||
let context = self.into_negotiation_context();
|
||||
let tx = context.build_transaction()?;
|
||||
let tx = context.validate_tx()?;
|
||||
Ok(NegotiationComplete(tx))
|
||||
}
|
||||
}
|
||||
@ -715,14 +822,14 @@ impl StateMachine {
|
||||
]);
|
||||
}
|
||||
|
||||
pub struct InteractiveTxConstructor {
|
||||
pub(crate) struct InteractiveTxConstructor {
|
||||
state_machine: StateMachine,
|
||||
channel_id: ChannelId,
|
||||
inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>,
|
||||
outputs_to_contribute: Vec<(SerialId, TxOut)>,
|
||||
}
|
||||
|
||||
pub enum InteractiveTxMessageSend {
|
||||
pub(crate) enum InteractiveTxMessageSend {
|
||||
TxAddInput(msgs::TxAddInput),
|
||||
TxAddOutput(msgs::TxAddOutput),
|
||||
TxComplete(msgs::TxComplete),
|
||||
@ -754,10 +861,10 @@ where
|
||||
serial_id
|
||||
}
|
||||
|
||||
pub enum HandleTxCompleteValue {
|
||||
pub(crate) enum HandleTxCompleteValue {
|
||||
SendTxMessage(InteractiveTxMessageSend),
|
||||
SendTxComplete(InteractiveTxMessageSend, Transaction),
|
||||
NegotiationComplete(Transaction),
|
||||
SendTxComplete(InteractiveTxMessageSend, ConstructedTransaction),
|
||||
NegotiationComplete(ConstructedTransaction),
|
||||
}
|
||||
|
||||
impl InteractiveTxConstructor {
|
||||
@ -1107,7 +1214,7 @@ mod tests {
|
||||
}
|
||||
assert!(message_send_a.is_none());
|
||||
assert!(message_send_b.is_none());
|
||||
assert_eq!(final_tx_a, final_tx_b);
|
||||
assert_eq!(final_tx_a.unwrap().into_unsigned_tx(), final_tx_b.unwrap().into_unsigned_tx());
|
||||
assert!(session.expect_error.is_none(), "Test: {}", session.description);
|
||||
}
|
||||
|
||||
@ -1280,7 +1387,7 @@ mod tests {
|
||||
let p2wpkh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WPKH_INPUT_WEIGHT_LOWER_BOUND);
|
||||
let outputs_fee = fee_for_weight(
|
||||
TEST_FEERATE_SATS_PER_KW,
|
||||
get_output_weight(&generate_p2wpkh_script_pubkey()),
|
||||
get_output_weight(&generate_p2wpkh_script_pubkey()).to_wu(),
|
||||
);
|
||||
let tx_common_fields_fee =
|
||||
fee_for_weight(TEST_FEERATE_SATS_PER_KW, TX_COMMON_FIELDS_WEIGHT);
|
||||
|
Loading…
Reference in New Issue
Block a user