mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 09:53:54 +01:00
sweep: add handleInitialBroadcast
to handle initial broadcast
This commit adds a new method `handleInitialBroadcast` to handle the initial broadcast. Previously we'd broadcast immediately inside `Broadcast`, which soon will not work after the `blockbeat` is implemented as the action to publish is now always triggered by a new block. Meanwhile, we still keep the option to bypass the block trigger so users can broadcast immediately by setting `Immediate` to true.
This commit is contained in:
parent
85010c832d
commit
c37a3cd1d8
@ -376,40 +376,52 @@ func (t *TxPublisher) isNeutrinoBackend() bool {
|
||||
return t.cfg.Wallet.BackEnd() == "neutrino"
|
||||
}
|
||||
|
||||
// Broadcast is used to publish the tx created from the given inputs. It will,
|
||||
// 1. init a fee function based on the given strategy.
|
||||
// 2. create an RBF-compliant tx and monitor it for confirmation.
|
||||
// 3. notify the initial broadcast result back to the caller.
|
||||
// The initial broadcast is guaranteed to be RBF-compliant unless the budget
|
||||
// specified cannot cover the fee.
|
||||
// Broadcast is used to publish the tx created from the given inputs. It will
|
||||
// register the broadcast request and return a chan to the caller to subscribe
|
||||
// the broadcast result. The initial broadcast is guaranteed to be
|
||||
// RBF-compliant unless the budget specified cannot cover the fee.
|
||||
//
|
||||
// NOTE: part of the Bumper interface.
|
||||
func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
|
||||
log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure(
|
||||
req))
|
||||
|
||||
// Attempt an initial broadcast which is guaranteed to comply with the
|
||||
// RBF rules.
|
||||
result, err := t.initialBroadcast(req)
|
||||
if err != nil {
|
||||
log.Errorf("Initial broadcast failed: %v", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
// Store the request.
|
||||
requestID, record := t.storeInitialRecord(req)
|
||||
|
||||
// Create a chan to send the result to the caller.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
t.subscriberChans.Store(result.requestID, subscriber)
|
||||
t.subscriberChans.Store(requestID, subscriber)
|
||||
|
||||
// Send the initial broadcast result to the caller.
|
||||
t.handleResult(result)
|
||||
// Publish the tx immediately if specified.
|
||||
if req.Immediate {
|
||||
t.handleInitialBroadcast(record, requestID)
|
||||
}
|
||||
|
||||
return subscriber, nil
|
||||
}
|
||||
|
||||
// storeInitialRecord initializes a monitor record and saves it in the map.
|
||||
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
|
||||
uint64, *monitorRecord) {
|
||||
|
||||
// Increase the request counter.
|
||||
//
|
||||
// NOTE: this is the only place where we increase the counter.
|
||||
requestID := t.requestCounter.Add(1)
|
||||
|
||||
// Register the record.
|
||||
record := &monitorRecord{req: req}
|
||||
t.records.Store(requestID, record)
|
||||
|
||||
return requestID, record
|
||||
}
|
||||
|
||||
// initialBroadcast initializes a fee function, creates an RBF-compliant tx and
|
||||
// broadcasts it.
|
||||
func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
|
||||
func (t *TxPublisher) initialBroadcast(requestID uint64,
|
||||
req *BumpRequest) (*BumpResult, error) {
|
||||
|
||||
// Create a fee bumping algorithm to be used for future RBF.
|
||||
feeAlgo, err := t.initializeFeeFunction(req)
|
||||
if err != nil {
|
||||
@ -418,7 +430,7 @@ func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
|
||||
|
||||
// Create the initial tx to be broadcasted. This tx is guaranteed to
|
||||
// comply with the RBF restrictions.
|
||||
requestID, err := t.createRBFCompliantTx(req, feeAlgo)
|
||||
err = t.createRBFCompliantTx(requestID, req, feeAlgo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
|
||||
}
|
||||
@ -465,8 +477,8 @@ func (t *TxPublisher) initializeFeeFunction(
|
||||
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
|
||||
// and redo the process until the tx is valid, or return an error when non-RBF
|
||||
// related errors occur or the budget has been used up.
|
||||
func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
|
||||
f FeeFunction) (uint64, error) {
|
||||
func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
|
||||
f FeeFunction) error {
|
||||
|
||||
for {
|
||||
// Create a new tx with the given fee rate and check its
|
||||
@ -475,17 +487,18 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
// The tx is valid, return the request ID.
|
||||
requestID := t.storeRecord(
|
||||
sweepCtx.tx, req, f, sweepCtx.fee,
|
||||
// The tx is valid, store it.
|
||||
t.storeRecord(
|
||||
requestID, sweepCtx.tx, req, f, sweepCtx.fee,
|
||||
)
|
||||
|
||||
log.Infof("Created tx %v for %v inputs: feerate=%v, "+
|
||||
"fee=%v, inputs=%v", sweepCtx.tx.TxHash(),
|
||||
len(req.Inputs), f.FeeRate(), sweepCtx.fee,
|
||||
log.Infof("Created initial sweep tx=%v for %v inputs: "+
|
||||
"feerate=%v, fee=%v, inputs:\n%v",
|
||||
sweepCtx.tx.TxHash(), len(req.Inputs),
|
||||
f.FeeRate(), sweepCtx.fee,
|
||||
inputTypeSummary(req.Inputs))
|
||||
|
||||
return requestID, nil
|
||||
return nil
|
||||
|
||||
// If the error indicates the fees paid is not enough, we will
|
||||
// ask the fee function to increase the fee rate and retry.
|
||||
@ -516,7 +529,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
|
||||
// cluster these inputs differetly.
|
||||
increased, err = f.Increment()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -526,20 +539,14 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
|
||||
// mempool acceptance.
|
||||
default:
|
||||
log.Debugf("Failed to create RBF-compliant tx: %v", err)
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// storeRecord stores the given record in the records map.
|
||||
func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
|
||||
f FeeFunction, fee btcutil.Amount) uint64 {
|
||||
|
||||
// Increase the request counter.
|
||||
//
|
||||
// NOTE: this is the only place where we increase the
|
||||
// counter.
|
||||
requestID := t.requestCounter.Add(1)
|
||||
func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx,
|
||||
req *BumpRequest, f FeeFunction, fee btcutil.Amount) {
|
||||
|
||||
// Register the record.
|
||||
t.records.Store(requestID, &monitorRecord{
|
||||
@ -548,8 +555,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
|
||||
feeFunction: f,
|
||||
fee: fee,
|
||||
})
|
||||
|
||||
return requestID
|
||||
}
|
||||
|
||||
// createAndCheckTx creates a tx based on the given inputs, change output
|
||||
@ -849,18 +854,27 @@ func (t *TxPublisher) processRecords() {
|
||||
// confirmed.
|
||||
confirmedRecords := make(map[uint64]*monitorRecord)
|
||||
|
||||
// feeBumpRecords stores a map of the records which need to be bumped.
|
||||
// feeBumpRecords stores a map of records which need to be bumped.
|
||||
feeBumpRecords := make(map[uint64]*monitorRecord)
|
||||
|
||||
// failedRecords stores a map of the records which has inputs being
|
||||
// spent by a third party.
|
||||
// failedRecords stores a map of records which has inputs being spent
|
||||
// by a third party.
|
||||
//
|
||||
// NOTE: this is only used for neutrino backend.
|
||||
failedRecords := make(map[uint64]*monitorRecord)
|
||||
|
||||
// initialRecords stores a map of records which are being created and
|
||||
// published for the first time.
|
||||
initialRecords := make(map[uint64]*monitorRecord)
|
||||
|
||||
// visitor is a helper closure that visits each record and divides them
|
||||
// into two groups.
|
||||
visitor := func(requestID uint64, r *monitorRecord) error {
|
||||
if r.tx == nil {
|
||||
initialRecords[requestID] = r
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Tracef("Checking monitor recordID=%v for tx=%v", requestID,
|
||||
r.tx.TxHash())
|
||||
|
||||
@ -888,9 +902,14 @@ func (t *TxPublisher) processRecords() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate through all the records and divide them into two groups.
|
||||
// Iterate through all the records and divide them into four groups.
|
||||
t.records.ForEach(visitor)
|
||||
|
||||
// Handle the initial broadcast.
|
||||
for requestID, r := range initialRecords {
|
||||
t.handleInitialBroadcast(r, requestID)
|
||||
}
|
||||
|
||||
// For records that are confirmed, we'll notify the caller about this
|
||||
// result.
|
||||
for requestID, r := range confirmedRecords {
|
||||
@ -946,6 +965,69 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
|
||||
t.handleResult(result)
|
||||
}
|
||||
|
||||
// handleInitialBroadcast is called when a new request is received. It will
|
||||
// handle the initial tx creation and broadcast. In details,
|
||||
// 1. init a fee function based on the given strategy.
|
||||
// 2. create an RBF-compliant tx and monitor it for confirmation.
|
||||
// 3. notify the initial broadcast result back to the caller.
|
||||
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
|
||||
requestID uint64) {
|
||||
|
||||
log.Debugf("Initial broadcast for requestID=%v", requestID)
|
||||
|
||||
var (
|
||||
result *BumpResult
|
||||
err error
|
||||
)
|
||||
|
||||
// Attempt an initial broadcast which is guaranteed to comply with the
|
||||
// RBF rules.
|
||||
result, err = t.initialBroadcast(requestID, r.req)
|
||||
if err != nil {
|
||||
log.Errorf("Initial broadcast failed: %v", err)
|
||||
|
||||
// We now decide what type of event to send.
|
||||
var event BumpEvent
|
||||
|
||||
switch {
|
||||
// When the error is due to a dust output, we'll send a
|
||||
// TxFailed so these inputs can be retried with a different
|
||||
// group in the next block.
|
||||
case errors.Is(err, ErrTxNoOutput):
|
||||
event = TxFailed
|
||||
|
||||
// When the error is due to budget being used up, we'll send a
|
||||
// TxFailed so these inputs can be retried with a different
|
||||
// group in the next block.
|
||||
case errors.Is(err, ErrMaxPosition):
|
||||
event = TxFailed
|
||||
|
||||
// When the error is due to zero fee rate delta, we'll send a
|
||||
// TxFailed so these inputs can be retried in the next block.
|
||||
case errors.Is(err, ErrZeroFeeRateDelta):
|
||||
event = TxFailed
|
||||
|
||||
// Otherwise this is not a fee-related error and the tx cannot
|
||||
// be retried. In that case we will fail ALL the inputs in this
|
||||
// tx, which means they will be removed from the sweeper and
|
||||
// never be tried again.
|
||||
//
|
||||
// TODO(yy): Find out which input is causing the failure and
|
||||
// fail that one only.
|
||||
default:
|
||||
event = TxFatal
|
||||
}
|
||||
|
||||
result = &BumpResult{
|
||||
Event: event,
|
||||
Err: err,
|
||||
requestID: requestID,
|
||||
}
|
||||
}
|
||||
|
||||
t.handleResult(result)
|
||||
}
|
||||
|
||||
// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will
|
||||
// attempt to bump the fee of the tx.
|
||||
//
|
||||
|
@ -344,13 +344,10 @@ func TestStoreRecord(t *testing.T) {
|
||||
initialCounter := tp.requestCounter.Load()
|
||||
|
||||
// Call the method under test.
|
||||
requestID := tp.storeRecord(tx, req, feeFunc, fee)
|
||||
|
||||
// Check the request ID is as expected.
|
||||
require.Equal(t, initialCounter+1, requestID)
|
||||
tp.storeRecord(initialCounter, tx, req, feeFunc, fee)
|
||||
|
||||
// Read the saved record and compare.
|
||||
record, ok := tp.records.Load(requestID)
|
||||
record, ok := tp.records.Load(initialCounter)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tx, record.tx)
|
||||
require.Equal(t, feeFunc, record.feeFunction)
|
||||
@ -646,23 +643,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
var requestCounter atomic.Uint64
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
rid := requestCounter.Add(1)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.setupMock()
|
||||
|
||||
// Call the method under test.
|
||||
id, err := tp.createRBFCompliantTx(req, m.feeFunc)
|
||||
err := tp.createRBFCompliantTx(rid, req, m.feeFunc)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
|
||||
// If there's an error, expect the requestID to be
|
||||
// empty.
|
||||
if tc.expectedErr != nil {
|
||||
require.Zero(t, id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -687,7 +680,8 @@ func TestTxPublisherBroadcast(t *testing.T) {
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
requestID := uint64(1)
|
||||
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
|
||||
|
||||
// Quickly check when the requestID cannot be found, an error is
|
||||
// returned.
|
||||
@ -774,6 +768,9 @@ func TestRemoveResult(t *testing.T) {
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
|
||||
// Create a test request ID counter.
|
||||
requestCounter := atomic.Uint64{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRecord func() uint64
|
||||
@ -785,10 +782,11 @@ func TestRemoveResult(t *testing.T) {
|
||||
// removed.
|
||||
name: "remove on TxConfirmed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return id
|
||||
return rid
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxConfirmed,
|
||||
@ -800,10 +798,11 @@ func TestRemoveResult(t *testing.T) {
|
||||
// When the tx is failed, the records will be removed.
|
||||
name: "remove on TxFailed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return id
|
||||
return rid
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxFailed,
|
||||
@ -816,10 +815,11 @@ func TestRemoveResult(t *testing.T) {
|
||||
// Noop when the tx is neither confirmed or failed.
|
||||
name: "noop when tx is not confirmed or failed",
|
||||
setupRecord: func() uint64 {
|
||||
id := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(id, nil)
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, tx, req, m.feeFunc, fee)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return id
|
||||
return rid
|
||||
},
|
||||
result: &BumpResult{
|
||||
Event: TxPublished,
|
||||
@ -866,7 +866,8 @@ func TestNotifyResult(t *testing.T) {
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
requestID := uint64(1)
|
||||
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
|
||||
|
||||
// Create a subscription to the event.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
@ -914,41 +915,17 @@ func TestNotifyResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcastSuccess checks the public `Broadcast` method can successfully
|
||||
// broadcast a tx based on the request.
|
||||
func TestBroadcastSuccess(t *testing.T) {
|
||||
// TestBroadcast checks the public `Broadcast` method can successfully register
|
||||
// a broadcast request.
|
||||
func TestBroadcast(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
tp, _ := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Once()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to pass.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Mock the wallet to publish successfully.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(nil).Once()
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
@ -964,25 +941,23 @@ func TestBroadcastSuccess(t *testing.T) {
|
||||
// Send the req and expect no error.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the first result to be TxPublished.
|
||||
require.Equal(t, TxPublished, result.Event)
|
||||
}
|
||||
require.NotNil(t, resultChan)
|
||||
|
||||
// Validate the record was stored.
|
||||
require.Equal(t, 1, tp.records.Len())
|
||||
require.Equal(t, 1, tp.subscriberChans.Len())
|
||||
|
||||
// Validate the record.
|
||||
rid := tp.requestCounter.Load()
|
||||
record, found := tp.records.Load(rid)
|
||||
require.True(t, found)
|
||||
require.Equal(t, req, record.req)
|
||||
}
|
||||
|
||||
// TestBroadcastFail checks the public `Broadcast` returns the error or a
|
||||
// failed result when the broadcast fails.
|
||||
func TestBroadcastFail(t *testing.T) {
|
||||
// TestBroadcastImmediate checks the public `Broadcast` method can successfully
|
||||
// register a broadcast request and publish the tx when `Immediate` flag is
|
||||
// set.
|
||||
func TestBroadcastImmediate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
@ -1001,64 +976,28 @@ func TestBroadcastFail(t *testing.T) {
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate * 10,
|
||||
DeadlineHeight: 10,
|
||||
Immediate: true,
|
||||
}
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
// Mock the fee estimator to return an error.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
// NOTE: We are not testing `handleInitialBroadcast` here, but only
|
||||
// interested in checking that this method is indeed called when
|
||||
// `Immediate` is true. Thus we mock the method to return an error to
|
||||
// quickly abort. As long as this mocked method is called, we know the
|
||||
// `Immediate` flag works.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Twice()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
|
||||
chainfee.SatPerKWeight(0), errDummy).Once()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to return an error.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Send the req and expect an error returned.
|
||||
// Send the req and expect no error.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.ErrorIs(t, err, errDummy)
|
||||
require.Nil(t, resultChan)
|
||||
|
||||
// Validate the record was NOT stored.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
|
||||
// Mock the testmempoolaccept again, this time it passes.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Mock the wallet to fail on publish.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Send the req and expect no error returned.
|
||||
resultChan, err = tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultChan)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the result to be TxFailed and the error is set in
|
||||
// the result.
|
||||
require.Equal(t, TxFailed, result.Event)
|
||||
require.ErrorIs(t, result.Err, errDummy)
|
||||
}
|
||||
|
||||
// Validate the record was removed.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
// Validate the record was removed due to an error returned in initial
|
||||
// broadcast.
|
||||
require.Empty(t, tp.records.Len())
|
||||
require.Empty(t, tp.subscriberChans.Len())
|
||||
}
|
||||
|
||||
// TestCreateAnPublishFail checks all the error cases are handled properly in
|
||||
@ -1223,7 +1162,8 @@ func TestHandleTxConfirmed(t *testing.T) {
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
requestID := uint64(1)
|
||||
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
|
||||
record, ok := tp.records.Load(requestID)
|
||||
require.True(t, ok)
|
||||
|
||||
@ -1295,7 +1235,8 @@ func TestHandleFeeBumpTx(t *testing.T) {
|
||||
|
||||
// Create a testing record and put it in the map.
|
||||
fee := btcutil.Amount(1000)
|
||||
requestID := tp.storeRecord(tx, req, m.feeFunc, fee)
|
||||
requestID := uint64(1)
|
||||
tp.storeRecord(requestID, tx, req, m.feeFunc, fee)
|
||||
|
||||
// Create a subscription to the event.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
@ -1496,3 +1437,186 @@ func TestProcessRecords(t *testing.T) {
|
||||
require.Equal(t, requestID2, result.requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can
|
||||
// successfully broadcast a tx based on the request.
|
||||
func TestHandleInitialBroadcastSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Once()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to pass.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Mock the wallet to publish successfully.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(nil).Once()
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a testing bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate * 10,
|
||||
DeadlineHeight: 10,
|
||||
}
|
||||
|
||||
// Register the testing record use `Broadcast`.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Grab the monitor record from the map.
|
||||
rid := tp.requestCounter.Load()
|
||||
rec, ok := tp.records.Load(rid)
|
||||
require.True(t, ok)
|
||||
|
||||
// Call the method under test.
|
||||
tp.wg.Add(1)
|
||||
tp.handleInitialBroadcast(rec, rid)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the first result to be TxPublished.
|
||||
require.Equal(t, TxPublished, result.Event)
|
||||
}
|
||||
|
||||
// Validate the record was stored.
|
||||
require.Equal(t, 1, tp.records.Len())
|
||||
require.Equal(t, 1, tp.subscriberChans.Len())
|
||||
}
|
||||
|
||||
// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the
|
||||
// error or a failed result when the broadcast fails.
|
||||
func TestHandleInitialBroadcastFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a publisher using the mocks.
|
||||
tp, m := createTestPublisher(t)
|
||||
|
||||
// Create a test feerate.
|
||||
feerate := chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Create a test request.
|
||||
inp := createTestInput(1000, input.WitnessKeyHash)
|
||||
|
||||
// Create a testing bump request.
|
||||
req := &BumpRequest{
|
||||
DeliveryAddress: changePkScript,
|
||||
Inputs: []input.Input{&inp},
|
||||
Budget: btcutil.Amount(1000),
|
||||
MaxFeeRate: feerate * 10,
|
||||
DeadlineHeight: 10,
|
||||
}
|
||||
|
||||
// Mock the fee estimator to return the testing fee rate.
|
||||
//
|
||||
// We are not testing `NewLinearFeeFunction` here, so the actual params
|
||||
// used are irrelevant.
|
||||
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
|
||||
feerate, nil).Twice()
|
||||
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
|
||||
|
||||
// Mock the signer to always return a valid script.
|
||||
//
|
||||
// NOTE: we are not testing the utility of creating valid txes here, so
|
||||
// this is fine to be mocked. This behaves essentially as skipping the
|
||||
// Signer check and alaways assume the tx has a valid sig.
|
||||
script := &input.Script{}
|
||||
m.signer.On("ComputeInputScript", mock.Anything,
|
||||
mock.Anything).Return(script, nil)
|
||||
|
||||
// Mock the testmempoolaccept to return an error.
|
||||
m.wallet.On("CheckMempoolAcceptance",
|
||||
mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Register the testing record use `Broadcast`.
|
||||
resultChan, err := tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Grab the monitor record from the map.
|
||||
rid := tp.requestCounter.Load()
|
||||
rec, ok := tp.records.Load(rid)
|
||||
require.True(t, ok)
|
||||
|
||||
// Call the method under test and expect an error returned.
|
||||
tp.wg.Add(1)
|
||||
tp.handleInitialBroadcast(rec, rid)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the first result to be TxFatal.
|
||||
require.Equal(t, TxFatal, result.Event)
|
||||
}
|
||||
|
||||
// Validate the record was NOT stored.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
|
||||
// Mock the testmempoolaccept again, this time it passes.
|
||||
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Mock the wallet to fail on publish.
|
||||
m.wallet.On("PublishTransaction",
|
||||
mock.Anything, mock.Anything).Return(errDummy).Once()
|
||||
|
||||
// Register the testing record use `Broadcast`.
|
||||
resultChan, err = tp.Broadcast(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Grab the monitor record from the map.
|
||||
rid = tp.requestCounter.Load()
|
||||
rec, ok = tp.records.Load(rid)
|
||||
require.True(t, ok)
|
||||
|
||||
// Call the method under test.
|
||||
tp.wg.Add(1)
|
||||
tp.handleInitialBroadcast(rec, rid)
|
||||
|
||||
// Check the result is sent back.
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for subscriber to receive result")
|
||||
|
||||
case result := <-resultChan:
|
||||
// We expect the result to be TxFailed and the error is set in
|
||||
// the result.
|
||||
require.Equal(t, TxFailed, result.Event)
|
||||
require.ErrorIs(t, result.Err, errDummy)
|
||||
}
|
||||
|
||||
// Validate the record was removed.
|
||||
require.Equal(t, 0, tp.records.Len())
|
||||
require.Equal(t, 0, tp.subscriberChans.Len())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user