From ded2352cf8b734fbbdb1c97e17431f926b06fe9e Mon Sep 17 00:00:00 2001 From: junderw Date: Sat, 24 Jun 2023 23:05:43 -0700 Subject: [PATCH] Use a class to hold state for Rust GbtGenerator --- backend/rust-gbt/index.d.ts | 7 ++- backend/rust-gbt/index.js | 5 +- backend/rust-gbt/src/gbt.rs | 6 +-- backend/rust-gbt/src/lib.rs | 89 +++++++++++++++++++------------ backend/rust-gbt/src/utils.rs | 58 +++++++++++++++++++- backend/src/api/mempool-blocks.ts | 7 +-- 6 files changed, 126 insertions(+), 46 deletions(-) diff --git a/backend/rust-gbt/index.d.ts b/backend/rust-gbt/index.d.ts index 793d78c4e..b02a27c45 100644 --- a/backend/rust-gbt/index.d.ts +++ b/backend/rust-gbt/index.d.ts @@ -3,8 +3,11 @@ /* auto-generated by NAPI-RS */ -export function make(mempoolBuffer: Uint8Array): Promise -export function update(newTxs: Uint8Array, removeTxs: Uint8Array): Promise +export class GbtGenerator { + constructor() + make(mempoolBuffer: Uint8Array): Promise + update(newTxs: Uint8Array, removeTxs: Uint8Array): Promise +} /** * The result from calling the gbt function. * diff --git a/backend/rust-gbt/index.js b/backend/rust-gbt/index.js index 5caf75b42..8680501d1 100644 --- a/backend/rust-gbt/index.js +++ b/backend/rust-gbt/index.js @@ -252,8 +252,7 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { make, update, GbtResult } = nativeBinding +const { GbtGenerator, GbtResult } = nativeBinding -module.exports.make = make -module.exports.update = update +module.exports.GbtGenerator = GbtGenerator module.exports.GbtResult = GbtResult diff --git a/backend/rust-gbt/src/gbt.rs b/backend/rust-gbt/src/gbt.rs index e78f81604..f657c013a 100644 --- a/backend/rust-gbt/src/gbt.rs +++ b/backend/rust-gbt/src/gbt.rs @@ -4,9 +4,7 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, }; -use crate::{ - audit_transaction::AuditTransaction, thread_transaction::ThreadTransaction, GbtResult, -}; +use crate::{audit_transaction::AuditTransaction, GbtResult, ThreadTransactionsMap}; const BLOCK_WEIGHT_UNITS: u32 = 4_000_000; const BLOCK_SIGOPS: u32 = 80_000; @@ -43,7 +41,7 @@ impl Ord for TxPriority { * (see BlockAssembler in https://github.com/bitcoin/bitcoin/blob/master/src/node/miner.cpp) * Ported from https://github.com/mempool/mempool/blob/master/backend/src/api/tx-selection-worker.ts */ -pub fn gbt(mempool: &mut HashMap) -> Option { +pub fn gbt(mempool: &mut ThreadTransactionsMap) -> Option { let mut audit_pool: HashMap = HashMap::new(); let mut mempool_array: VecDeque = VecDeque::new(); let mut clusters: Vec> = Vec::new(); diff --git a/backend/rust-gbt/src/lib.rs b/backend/rust-gbt/src/lib.rs index 0cdeb74e3..1044d5bd5 100644 --- a/backend/rust-gbt/src/lib.rs +++ b/backend/rust-gbt/src/lib.rs @@ -1,9 +1,9 @@ use napi::bindgen_prelude::*; use napi_derive::napi; -use once_cell::sync::Lazy; use std::collections::HashMap; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; +use utils::U32HasherState; mod audit_transaction; mod gbt; @@ -11,41 +11,48 @@ mod thread_transaction; mod utils; use thread_transaction::ThreadTransaction; -static THREAD_TRANSACTIONS: Lazy>> = - Lazy::new(|| Mutex::new(HashMap::new())); +type ThreadTransactionsMap = HashMap; -#[napi(ts_args_type = "mempoolBuffer: Uint8Array")] -pub async fn make(mempool_buffer: Uint8Array) -> Result { - let mut map = HashMap::new(); - for tx in ThreadTransaction::batch_from_buffer(&mempool_buffer) { - map.insert(tx.uid, tx); - } - - { - let mut global_map = THREAD_TRANSACTIONS - .lock() - .map_err(|_| napi::Error::from_reason("THREAD_TRANSACTIONS Mutex poisoned"))?; - *global_map = map; - } - - run_in_thread().await +#[napi] +pub struct GbtGenerator { + thread_transactions: Arc>, } -#[napi(ts_args_type = "newTxs: Uint8Array, removeTxs: Uint8Array")] -pub async fn update(new_txs: Uint8Array, remove_txs: Uint8Array) -> Result { - { - let mut map = THREAD_TRANSACTIONS - .lock() - .map_err(|_| napi::Error::from_reason("THREAD_TRANSACTIONS Mutex poisoned"))?; - for tx in ThreadTransaction::batch_from_buffer(&new_txs) { - map.insert(tx.uid, tx); - } - for txid in &utils::txids_from_buffer(&remove_txs) { - map.remove(txid); +#[napi] +impl GbtGenerator { + #[napi(constructor)] + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + thread_transactions: Arc::new(Mutex::new(HashMap::with_capacity_and_hasher( + 2048, + U32HasherState, + ))), } } - run_in_thread().await + #[napi] + pub async fn make(&self, mempool_buffer: Uint8Array) -> Result { + run_task(Arc::clone(&self.thread_transactions), move |map| { + for tx in ThreadTransaction::batch_from_buffer(&mempool_buffer) { + map.insert(tx.uid, tx); + } + }) + .await + } + + #[napi] + pub async fn update(&self, new_txs: Uint8Array, remove_txs: Uint8Array) -> Result { + run_task(Arc::clone(&self.thread_transactions), move |map| { + for tx in ThreadTransaction::batch_from_buffer(&new_txs) { + map.insert(tx.uid, tx); + } + for txid in &utils::txids_from_buffer(&remove_txs) { + map.remove(txid); + } + }) + .await + } } /// The result from calling the gbt function. @@ -61,11 +68,27 @@ pub struct GbtResult { pub rates: Vec>, // Tuples not supported. u32 fits inside f64 } -async fn run_in_thread() -> Result { +/// All on another thread, this runs an arbitrary task in between +/// taking the lock and running gbt. +/// +/// Rather than filling / updating the HashMap on the main thread, +/// this allows for HashMap modifying tasks to be run before running and returning gbt results. +/// +/// `thread_transactions` is a cloned Arc of the Mutex for the HashMap state. +/// `callback` is a `'static + Send` `FnOnce` closure/function that takes a mutable reference +/// to the HashMap as the only argument. (A move closure is recommended to meet the bounds) +async fn run_task( + thread_transactions: Arc>, + callback: F, +) -> Result +where + F: FnOnce(&mut ThreadTransactionsMap) + Send + 'static, +{ let handle = napi::tokio::task::spawn_blocking(move || { - let mut map = THREAD_TRANSACTIONS + let mut map = thread_transactions .lock() .map_err(|_| napi::Error::from_reason("THREAD_TRANSACTIONS Mutex poisoned"))?; + callback(&mut map); gbt::gbt(&mut map).ok_or_else(|| napi::Error::from_reason("gbt failed")) }); diff --git a/backend/rust-gbt/src/utils.rs b/backend/rust-gbt/src/utils.rs index c1b6063a1..b969c8361 100644 --- a/backend/rust-gbt/src/utils.rs +++ b/backend/rust-gbt/src/utils.rs @@ -1,5 +1,8 @@ use bytes::buf::Buf; -use std::io::Cursor; +use std::{ + hash::{BuildHasher, Hasher}, + io::Cursor, +}; pub fn txids_from_buffer(buffer: &[u8]) -> Vec { let mut txids: Vec = Vec::new(); @@ -11,3 +14,56 @@ pub fn txids_from_buffer(buffer: &[u8]) -> Vec { txids } + +pub struct U32HasherState; + +impl BuildHasher for U32HasherState { + type Hasher = U32Hasher; + + fn build_hasher(&self) -> Self::Hasher { + U32Hasher(0) + } +} + +pub struct U32Hasher(u32); + +impl Hasher for U32Hasher { + fn finish(&self) -> u64 { + // Safety: Two u32s next to each other will make a u64 + unsafe { core::mem::transmute::<(u32, u32), u64>((self.0, 0_u32)) } + } + + fn write(&mut self, bytes: &[u8]) { + // Assert in debug builds (testing too) that only 4 byte keys (u32, i32, f32, etc.) run + debug_assert!(bytes.len() == 4); + // Safety: We know that the size of the key is at least 4 bytes + self.0 = unsafe { *bytes.as_ptr().cast::() }; + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::U32HasherState; + + #[test] + fn test_me() { + let mut hm: HashMap = HashMap::with_hasher(U32HasherState); + + hm.insert(0, String::from("0")); + hm.insert(42, String::from("42")); + hm.insert(256, String::from("256")); + hm.insert(u32::MAX, String::from("MAX")); + hm.insert(u32::MAX >> 2, String::from("MAX >> 2")); + + assert_eq!(hm.get(&0), Some(&String::from("0"))); + assert_eq!(hm.get(&42), Some(&String::from("42"))); + assert_eq!(hm.get(&256), Some(&String::from("256"))); + assert_eq!(hm.get(&u32::MAX), Some(&String::from("MAX"))); + assert_eq!(hm.get(&(u32::MAX >> 2)), Some(&String::from("MAX >> 2"))); + assert_eq!(hm.get(&(u32::MAX >> 4)), None); + assert_eq!(hm.get(&3), None); + assert_eq!(hm.get(&43), None); + } +} diff --git a/backend/src/api/mempool-blocks.ts b/backend/src/api/mempool-blocks.ts index a4786af7c..e0362b6ce 100644 --- a/backend/src/api/mempool-blocks.ts +++ b/backend/src/api/mempool-blocks.ts @@ -1,4 +1,4 @@ -import * as napiAddon from '../../rust-gbt'; +import { GbtGenerator } from '../../rust-gbt'; import logger from '../logger'; import { MempoolBlock, MempoolTransactionExtended, TransactionStripped, MempoolBlockWithTransactions, MempoolBlockDelta, Ancestor, CompactThreadTransaction, EffectiveFeeStats, AuditTransaction } from '../mempool.interfaces'; import { Common, OnlineFeeStatsCalculator } from './common'; @@ -11,6 +11,7 @@ class MempoolBlocks { private mempoolBlockDeltas: MempoolBlockDelta[] = []; private txSelectionWorker: Worker | null = null; private rustInitialized: boolean = false; + private rustGbtGenerator: GbtGenerator = new GbtGenerator(); private nextUid: number = 1; private uidMap: Map = new Map(); // map short numerical uids to full txids @@ -342,7 +343,7 @@ class MempoolBlocks { // run the block construction algorithm in a separate thread, and wait for a result try { const { blocks, rates, clusters } = this.convertNapiResultTxids( - await napiAddon.make(new Uint8Array(mempoolBuffer)), + await this.rustGbtGenerator.make(new Uint8Array(mempoolBuffer)), ); this.rustInitialized = true; const processed = this.processBlockTemplates(newMempool, blocks, rates, clusters, saveResults); @@ -376,7 +377,7 @@ class MempoolBlocks { // run the block construction algorithm in a separate thread, and wait for a result try { const { blocks, rates, clusters } = this.convertNapiResultTxids( - await napiAddon.update( + await this.rustGbtGenerator.update( new Uint8Array(addedBuffer), new Uint8Array(removedBuffer), ),