routing: update mockers in unit test

This commit adds more mockers to be used in coming unit tests and
simplified the mockers to be more straightforward.
This commit is contained in:
yyforyongyu 2023-02-13 20:57:18 +08:00
parent 01e3bd87ab
commit ddad6ad4c4
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 98 additions and 43 deletions

View File

@ -12,7 +12,9 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/stretchr/testify/mock"
)
@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() (
type mockPaymentAttemptDispatcher struct {
mock.Mock
resultChan chan *htlcswitch.PaymentResult
}
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) {
m.Called(attemptID, paymentHash, deobfuscator)
args := m.Called(attemptID, paymentHash, deobfuscator)
// Instead of returning the mocked returned values, we need to return
// the chan resultChan so it can be converted into a read-only chan.
return m.resultChan, nil
resultChan := args.Get(0)
if resultChan == nil {
return nil, args.Error(1)
}
return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
}
func (m *mockPaymentAttemptDispatcher) CleanStore(
@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
type mockControlTower struct {
mock.Mock
sync.Mutex
}
var _ ControlTower = (*mockControlTower)(nil)
@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
a *channeldb.HTLCAttemptInfo) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, a)
return args.Error(0)
}
@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, settleInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
attempt := args.Get(0)
if attempt == nil {
return nil, args.Error(1)
}
return attempt.(*channeldb.HTLCAttempt), args.Error(1)
}
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, failInfo)
attempt := args.Get(0)
if attempt == nil {
return nil, args.Error(1)
}
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
reason channeldb.FailureReason) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, reason)
return args.Error(0)
}
@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool {
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
return m.mayAddOutgoingErr
}
type mockShardTracker struct {
mock.Mock
}
var _ shards.ShardTracker = (*mockShardTracker)(nil)
func (m *mockShardTracker) NewShard(attemptID uint64,
lastShard bool) (shards.PaymentShard, error) {
args := m.Called(attemptID, lastShard)
shard := args.Get(0)
if shard == nil {
return nil, args.Error(1)
}
return shard.(shards.PaymentShard), args.Error(1)
}
func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
args := m.Called(attemptID)
return args.Get(0).(lntypes.Hash), args.Error(1)
}
func (m *mockShardTracker) CancelShard(attemptID uint64) error {
args := m.Called(attemptID)
return args.Error(0)
}
type mockShard struct {
mock.Mock
}
var _ shards.PaymentShard = (*mockShard)(nil)
// Hash returns the hash used for the HTLC representing this shard.
func (m *mockShard) Hash() lntypes.Hash {
args := m.Called()
return args.Get(0).(lntypes.Hash)
}
// MPP returns any extra MPP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) MPP() *record.MPP {
args := m.Called()
r := args.Get(0)
if r == nil {
return nil
}
return r.(*record.MPP)
}
// AMP returns any extra AMP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) AMP() *record.AMP {
args := m.Called()
r := args.Get(0)
if r == nil {
return nil
}
return r.(*record.AMP)
}

View File

@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
resultChan := make(chan *htlcswitch.PaymentResult, 1)
payer.On("GetAttemptResult",
mock.Anything, mock.Anything, mock.Anything,
).Run(func(_ mock.Arguments) {
).Return(resultChan, nil).Run(func(_ mock.Arguments) {
// Send a successful payment result.
payer.resultChan <- &htlcswitch.PaymentResult{}
resultChan <- &htlcswitch.PaymentResult{}
})
missionControl.On("ReportPaymentSuccess",
@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
},
}}
// Create the error to be returned.
tempErr := htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{}, 1,
)
// Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
// Create the error to be returned.
tempErr := htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
1,
)
// Mock GetAttemptResult to return a failure.
payer.On("GetAttemptResult",
mock.Anything, mock.Anything, mock.Anything,
).Run(func(_ mock.Arguments) {
// Send an attempt failure.
payer.resultChan <- &htlcswitch.PaymentResult{
Error: tempErr,
}
})
).Return(tempErr)
// Mock the control tower to return the mocked payment.
payment := &mockMPPayment{}