sweep: add handleInitialBroadcast to handle initial broadcast

This commit adds a new method `handleInitialBroadcast` to handle the
initial broadcast. Previously we'd broadcast immediately inside
`Broadcast`, which soon will not work after the `blockbeat` is
implemented as the action to publish is now always triggered by a new
block. Meanwhile, we still keep the option to bypass the block trigger
so users can broadcast immediately by setting `Immediate` to true.
This commit is contained in:
yyforyongyu 2024-10-25 18:31:46 +08:00
parent 85010c832d
commit c37a3cd1d8
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 366 additions and 160 deletions

View File

@ -376,40 +376,52 @@ func (t *TxPublisher) isNeutrinoBackend() bool {
return t.cfg.Wallet.BackEnd() == "neutrino" return t.cfg.Wallet.BackEnd() == "neutrino"
} }
// Broadcast is used to publish the tx created from the given inputs. It will, // Broadcast is used to publish the tx created from the given inputs. It will
// 1. init a fee function based on the given strategy. // register the broadcast request and return a chan to the caller to subscribe
// 2. create an RBF-compliant tx and monitor it for confirmation. // the broadcast result. The initial broadcast is guaranteed to be
// 3. notify the initial broadcast result back to the caller. // RBF-compliant unless the budget specified cannot cover the fee.
// The initial broadcast is guaranteed to be RBF-compliant unless the budget
// specified cannot cover the fee.
// //
// NOTE: part of the Bumper interface. // NOTE: part of the Bumper interface.
func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure( log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure(
req)) req))
// Attempt an initial broadcast which is guaranteed to comply with the // Store the request.
// RBF rules. requestID, record := t.storeInitialRecord(req)
result, err := t.initialBroadcast(req)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)
return nil, err
}
// Create a chan to send the result to the caller. // Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1) subscriber := make(chan *BumpResult, 1)
t.subscriberChans.Store(result.requestID, subscriber) t.subscriberChans.Store(requestID, subscriber)
// Send the initial broadcast result to the caller. // Publish the tx immediately if specified.
t.handleResult(result) if req.Immediate {
t.handleInitialBroadcast(record, requestID)
}
return subscriber, nil return subscriber, nil
} }
// storeInitialRecord initializes a monitor record and saves it in the map.
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
uint64, *monitorRecord) {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the counter.
requestID := t.requestCounter.Add(1)
// Register the record.
record := &monitorRecord{req: req}
t.records.Store(requestID, record)
return requestID, record
}
// initialBroadcast initializes a fee function, creates an RBF-compliant tx and // initialBroadcast initializes a fee function, creates an RBF-compliant tx and
// broadcasts it. // broadcasts it.
func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { func (t *TxPublisher) initialBroadcast(requestID uint64,
req *BumpRequest) (*BumpResult, error) {
// Create a fee bumping algorithm to be used for future RBF. // Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(req) feeAlgo, err := t.initializeFeeFunction(req)
if err != nil { if err != nil {
@ -418,7 +430,7 @@ func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
// Create the initial tx to be broadcasted. This tx is guaranteed to // Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions. // comply with the RBF restrictions.
requestID, err := t.createRBFCompliantTx(req, feeAlgo) err = t.createRBFCompliantTx(requestID, req, feeAlgo)
if err != nil { if err != nil {
return nil, fmt.Errorf("create RBF-compliant tx: %w", err) return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
} }
@ -465,8 +477,8 @@ func (t *TxPublisher) initializeFeeFunction(
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF // and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up. // related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
f FeeFunction) (uint64, error) { f FeeFunction) error {
for { for {
// Create a new tx with the given fee rate and check its // Create a new tx with the given fee rate and check its
@ -475,17 +487,18 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
switch { switch {
case err == nil: case err == nil:
// The tx is valid, return the request ID. // The tx is valid, store it.
requestID := t.storeRecord( t.storeRecord(
sweepCtx.tx, req, f, sweepCtx.fee, requestID, sweepCtx.tx, req, f, sweepCtx.fee,
) )
log.Infof("Created tx %v for %v inputs: feerate=%v, "+ log.Infof("Created initial sweep tx=%v for %v inputs: "+
"fee=%v, inputs=%v", sweepCtx.tx.TxHash(), "feerate=%v, fee=%v, inputs:\n%v",
len(req.Inputs), f.FeeRate(), sweepCtx.fee, sweepCtx.tx.TxHash(), len(req.Inputs),
f.FeeRate(), sweepCtx.fee,
inputTypeSummary(req.Inputs)) inputTypeSummary(req.Inputs))
return requestID, nil return nil
// If the error indicates the fees paid is not enough, we will // If the error indicates the fees paid is not enough, we will
// ask the fee function to increase the fee rate and retry. // ask the fee function to increase the fee rate and retry.
@ -516,7 +529,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
// cluster these inputs differetly. // cluster these inputs differetly.
increased, err = f.Increment() increased, err = f.Increment()
if err != nil { if err != nil {
return 0, err return err
} }
} }
@ -526,20 +539,14 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
// mempool acceptance. // mempool acceptance.
default: default:
log.Debugf("Failed to create RBF-compliant tx: %v", err) log.Debugf("Failed to create RBF-compliant tx: %v", err)
return 0, err return err
} }
} }
} }
// storeRecord stores the given record in the records map. // storeRecord stores the given record in the records map.
func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx,
f FeeFunction, fee btcutil.Amount) uint64 { req *BumpRequest, f FeeFunction, fee btcutil.Amount) {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the
// counter.
requestID := t.requestCounter.Add(1)
// Register the record. // Register the record.
t.records.Store(requestID, &monitorRecord{ t.records.Store(requestID, &monitorRecord{
@ -548,8 +555,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
feeFunction: f, feeFunction: f,
fee: fee, fee: fee,
}) })
return requestID
} }
// createAndCheckTx creates a tx based on the given inputs, change output // createAndCheckTx creates a tx based on the given inputs, change output
@ -849,18 +854,27 @@ func (t *TxPublisher) processRecords() {
// confirmed. // confirmed.
confirmedRecords := make(map[uint64]*monitorRecord) confirmedRecords := make(map[uint64]*monitorRecord)
// feeBumpRecords stores a map of the records which need to be bumped. // feeBumpRecords stores a map of records which need to be bumped.
feeBumpRecords := make(map[uint64]*monitorRecord) feeBumpRecords := make(map[uint64]*monitorRecord)
// failedRecords stores a map of the records which has inputs being // failedRecords stores a map of records which has inputs being spent
// spent by a third party. // by a third party.
// //
// NOTE: this is only used for neutrino backend. // NOTE: this is only used for neutrino backend.
failedRecords := make(map[uint64]*monitorRecord) failedRecords := make(map[uint64]*monitorRecord)
// initialRecords stores a map of records which are being created and
// published for the first time.
initialRecords := make(map[uint64]*monitorRecord)
// visitor is a helper closure that visits each record and divides them // visitor is a helper closure that visits each record and divides them
// into two groups. // into two groups.
visitor := func(requestID uint64, r *monitorRecord) error { visitor := func(requestID uint64, r *monitorRecord) error {
if r.tx == nil {
initialRecords[requestID] = r
return nil
}
log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, log.Tracef("Checking monitor recordID=%v for tx=%v", requestID,
r.tx.TxHash()) r.tx.TxHash())
@ -888,9 +902,14 @@ func (t *TxPublisher) processRecords() {
return nil return nil
} }
// Iterate through all the records and divide them into two groups. // Iterate through all the records and divide them into four groups.
t.records.ForEach(visitor) t.records.ForEach(visitor)
// Handle the initial broadcast.
for requestID, r := range initialRecords {
t.handleInitialBroadcast(r, requestID)
}
// For records that are confirmed, we'll notify the caller about this // For records that are confirmed, we'll notify the caller about this
// result. // result.
for requestID, r := range confirmedRecords { for requestID, r := range confirmedRecords {
@ -946,6 +965,69 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
t.handleResult(result) t.handleResult(result)
} }
// handleInitialBroadcast is called when a new request is received. It will
// handle the initial tx creation and broadcast. In details,
// 1. init a fee function based on the given strategy.
// 2. create an RBF-compliant tx and monitor it for confirmation.
// 3. notify the initial broadcast result back to the caller.
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
requestID uint64) {
log.Debugf("Initial broadcast for requestID=%v", requestID)
var (
result *BumpResult
err error
)
// Attempt an initial broadcast which is guaranteed to comply with the
// RBF rules.
result, err = t.initialBroadcast(requestID, r.req)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)
// We now decide what type of event to send.
var event BumpEvent
switch {
// When the error is due to a dust output, we'll send a
// TxFailed so these inputs can be retried with a different
// group in the next block.
case errors.Is(err, ErrTxNoOutput):
event = TxFailed
// When the error is due to budget being used up, we'll send a
// TxFailed so these inputs can be retried with a different
// group in the next block.
case errors.Is(err, ErrMaxPosition):
event = TxFailed
// When the error is due to zero fee rate delta, we'll send a
// TxFailed so these inputs can be retried in the next block.
case errors.Is(err, ErrZeroFeeRateDelta):
event = TxFailed
// Otherwise this is not a fee-related error and the tx cannot
// be retried. In that case we will fail ALL the inputs in this
// tx, which means they will be removed from the sweeper and
// never be tried again.
//
// TODO(yy): Find out which input is causing the failure and
// fail that one only.
default:
event = TxFatal
}
result = &BumpResult{
Event: event,
Err: err,
requestID: requestID,
}
}
t.handleResult(result)
}
// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will // handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will
// attempt to bump the fee of the tx. // attempt to bump the fee of the tx.
// //

View File

@ -344,13 +344,10 @@ func TestStoreRecord(t *testing.T) {
initialCounter := tp.requestCounter.Load() initialCounter := tp.requestCounter.Load()
// Call the method under test. // Call the method under test.
requestID := tp.storeRecord(tx, req, feeFunc, fee) tp.storeRecord(initialCounter, tx, req, feeFunc, fee)
// Check the request ID is as expected.
require.Equal(t, initialCounter+1, requestID)
// Read the saved record and compare. // Read the saved record and compare.
record, ok := tp.records.Load(requestID) record, ok := tp.records.Load(initialCounter)
require.True(t, ok) require.True(t, ok)
require.Equal(t, tx, record.tx) require.Equal(t, tx, record.tx)
require.Equal(t, feeFunc, record.feeFunction) require.Equal(t, feeFunc, record.feeFunction)
@ -646,23 +643,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
}, },
} }
var requestCounter atomic.Uint64
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
rid := requestCounter.Add(1)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
tc.setupMock() tc.setupMock()
// Call the method under test. // Call the method under test.
id, err := tp.createRBFCompliantTx(req, m.feeFunc) err := tp.createRBFCompliantTx(rid, req, m.feeFunc)
// Check the result is as expected. // Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr) require.ErrorIs(t, err, tc.expectedErr)
// If there's an error, expect the requestID to be
// empty.
if tc.expectedErr != nil {
require.Zero(t, id)
}
}) })
} }
} }
@ -687,7 +680,8 @@ func TestTxPublisherBroadcast(t *testing.T) {
// Create a testing record and put it in the map. // Create a testing record and put it in the map.
fee := btcutil.Amount(1000) fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee) requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
// Quickly check when the requestID cannot be found, an error is // Quickly check when the requestID cannot be found, an error is
// returned. // returned.
@ -774,6 +768,9 @@ func TestRemoveResult(t *testing.T) {
// Create a testing record and put it in the map. // Create a testing record and put it in the map.
fee := btcutil.Amount(1000) fee := btcutil.Amount(1000)
// Create a test request ID counter.
requestCounter := atomic.Uint64{}
testCases := []struct { testCases := []struct {
name string name string
setupRecord func() uint64 setupRecord func() uint64
@ -785,10 +782,11 @@ func TestRemoveResult(t *testing.T) {
// removed. // removed.
name: "remove on TxConfirmed", name: "remove on TxConfirmed",
setupRecord: func() uint64 { setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee) rid := requestCounter.Add(1)
tp.subscriberChans.Store(id, nil) tp.storeRecord(rid, tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(rid, nil)
return id return rid
}, },
result: &BumpResult{ result: &BumpResult{
Event: TxConfirmed, Event: TxConfirmed,
@ -800,10 +798,11 @@ func TestRemoveResult(t *testing.T) {
// When the tx is failed, the records will be removed. // When the tx is failed, the records will be removed.
name: "remove on TxFailed", name: "remove on TxFailed",
setupRecord: func() uint64 { setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee) rid := requestCounter.Add(1)
tp.subscriberChans.Store(id, nil) tp.storeRecord(rid, tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(rid, nil)
return id return rid
}, },
result: &BumpResult{ result: &BumpResult{
Event: TxFailed, Event: TxFailed,
@ -816,10 +815,11 @@ func TestRemoveResult(t *testing.T) {
// Noop when the tx is neither confirmed or failed. // Noop when the tx is neither confirmed or failed.
name: "noop when tx is not confirmed or failed", name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 { setupRecord: func() uint64 {
id := tp.storeRecord(tx, req, m.feeFunc, fee) rid := requestCounter.Add(1)
tp.subscriberChans.Store(id, nil) tp.storeRecord(rid, tx, req, m.feeFunc, fee)
tp.subscriberChans.Store(rid, nil)
return id return rid
}, },
result: &BumpResult{ result: &BumpResult{
Event: TxPublished, Event: TxPublished,
@ -866,7 +866,8 @@ func TestNotifyResult(t *testing.T) {
// Create a testing record and put it in the map. // Create a testing record and put it in the map.
fee := btcutil.Amount(1000) fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee) requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
// Create a subscription to the event. // Create a subscription to the event.
subscriber := make(chan *BumpResult, 1) subscriber := make(chan *BumpResult, 1)
@ -914,41 +915,17 @@ func TestNotifyResult(t *testing.T) {
} }
} }
// TestBroadcastSuccess checks the public `Broadcast` method can successfully // TestBroadcast checks the public `Broadcast` method can successfully register
// broadcast a tx based on the request. // a broadcast request.
func TestBroadcastSuccess(t *testing.T) { func TestBroadcast(t *testing.T) {
t.Parallel() t.Parallel()
// Create a publisher using the mocks. // Create a publisher using the mocks.
tp, m := createTestPublisher(t) tp, _ := createTestPublisher(t)
// Create a test feerate. // Create a test feerate.
feerate := chainfee.SatPerKWeight(1000) feerate := chainfee.SatPerKWeight(1000)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(nil).Once()
// Create a test request. // Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash) inp := createTestInput(1000, input.WitnessKeyHash)
@ -964,25 +941,23 @@ func TestBroadcastSuccess(t *testing.T) {
// Send the req and expect no error. // Send the req and expect no error.
resultChan, err := tp.Broadcast(req) resultChan, err := tp.Broadcast(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resultChan)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxPublished.
require.Equal(t, TxPublished, result.Event)
}
// Validate the record was stored. // Validate the record was stored.
require.Equal(t, 1, tp.records.Len()) require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len()) require.Equal(t, 1, tp.subscriberChans.Len())
// Validate the record.
rid := tp.requestCounter.Load()
record, found := tp.records.Load(rid)
require.True(t, found)
require.Equal(t, req, record.req)
} }
// TestBroadcastFail checks the public `Broadcast` returns the error or a // TestBroadcastImmediate checks the public `Broadcast` method can successfully
// failed result when the broadcast fails. // register a broadcast request and publish the tx when `Immediate` flag is
func TestBroadcastFail(t *testing.T) { // set.
func TestBroadcastImmediate(t *testing.T) {
t.Parallel() t.Parallel()
// Create a publisher using the mocks. // Create a publisher using the mocks.
@ -1001,64 +976,28 @@ func TestBroadcastFail(t *testing.T) {
Budget: btcutil.Amount(1000), Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10, MaxFeeRate: feerate * 10,
DeadlineHeight: 10, DeadlineHeight: 10,
Immediate: true,
} }
// Mock the fee estimator to return the testing fee rate. // Mock the fee estimator to return an error.
// //
// We are not testing `NewLinearFeeFunction` here, so the actual params // NOTE: We are not testing `handleInitialBroadcast` here, but only
// used are irrelevant. // interested in checking that this method is indeed called when
// `Immediate` is true. Thus we mock the method to return an error to
// quickly abort. As long as this mocked method is called, we know the
// `Immediate` flag works.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return( m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Twice() chainfee.SatPerKWeight(0), errDummy).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
// Mock the signer to always return a valid script. // Send the req and expect no error.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to return an error.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
// Send the req and expect an error returned.
resultChan, err := tp.Broadcast(req) resultChan, err := tp.Broadcast(req)
require.ErrorIs(t, err, errDummy)
require.Nil(t, resultChan)
// Validate the record was NOT stored.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Mock the testmempoolaccept again, this time it passes.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to fail on publish.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(errDummy).Once()
// Send the req and expect no error returned.
resultChan, err = tp.Broadcast(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resultChan)
// Check the result is sent back. // Validate the record was removed due to an error returned in initial
select { // broadcast.
case <-time.After(time.Second): require.Empty(t, tp.records.Len())
t.Fatal("timeout waiting for subscriber to receive result") require.Empty(t, tp.subscriberChans.Len())
case result := <-resultChan:
// We expect the result to be TxFailed and the error is set in
// the result.
require.Equal(t, TxFailed, result.Event)
require.ErrorIs(t, result.Err, errDummy)
}
// Validate the record was removed.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
} }
// TestCreateAnPublishFail checks all the error cases are handled properly in // TestCreateAnPublishFail checks all the error cases are handled properly in
@ -1223,7 +1162,8 @@ func TestHandleTxConfirmed(t *testing.T) {
// Create a testing record and put it in the map. // Create a testing record and put it in the map.
fee := btcutil.Amount(1000) fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee) requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
record, ok := tp.records.Load(requestID) record, ok := tp.records.Load(requestID)
require.True(t, ok) require.True(t, ok)
@ -1295,7 +1235,8 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Create a testing record and put it in the map. // Create a testing record and put it in the map.
fee := btcutil.Amount(1000) fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee) requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
// Create a subscription to the event. // Create a subscription to the event.
subscriber := make(chan *BumpResult, 1) subscriber := make(chan *BumpResult, 1)
@ -1496,3 +1437,186 @@ func TestProcessRecords(t *testing.T) {
require.Equal(t, requestID2, result.requestID) require.Equal(t, requestID2, result.requestID)
} }
} }
// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can
// successfully broadcast a tx based on the request.
func TestHandleInitialBroadcastSuccess(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(nil).Once()
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
}
// Register the testing record use `Broadcast`.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid := tp.requestCounter.Load()
rec, ok := tp.records.Load(rid)
require.True(t, ok)
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxPublished.
require.Equal(t, TxPublished, result.Event)
}
// Validate the record was stored.
require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len())
}
// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the
// error or a failed result when the broadcast fails.
func TestHandleInitialBroadcastFail(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
}
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Twice()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to return an error.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
// Register the testing record use `Broadcast`.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid := tp.requestCounter.Load()
rec, ok := tp.records.Load(rid)
require.True(t, ok)
// Call the method under test and expect an error returned.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxFatal.
require.Equal(t, TxFatal, result.Event)
}
// Validate the record was NOT stored.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Mock the testmempoolaccept again, this time it passes.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to fail on publish.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(errDummy).Once()
// Register the testing record use `Broadcast`.
resultChan, err = tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid = tp.requestCounter.Load()
rec, ok = tp.records.Load(rid)
require.True(t, ok)
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the result to be TxFailed and the error is set in
// the result.
require.Equal(t, TxFailed, result.Event)
require.ErrorIs(t, result.Err, errDummy)
}
// Validate the record was removed.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
}