mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
routing: update mockers in unit test
This commit adds more mockers to be used in coming unit tests and simplified the mockers to be more straightforward.
This commit is contained in:
parent
01e3bd87ab
commit
ddad6ad4c4
@ -12,7 +12,9 @@ import (
|
||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/routing/shards"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
@ -572,8 +574,6 @@ func (m *mockControlTowerOld) SubscribeAllPayments() (
|
||||
|
||||
type mockPaymentAttemptDispatcher struct {
|
||||
mock.Mock
|
||||
|
||||
resultChan chan *htlcswitch.PaymentResult
|
||||
}
|
||||
|
||||
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||
@ -589,11 +589,14 @@ func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
|
||||
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
|
||||
<-chan *htlcswitch.PaymentResult, error) {
|
||||
|
||||
m.Called(attemptID, paymentHash, deobfuscator)
|
||||
args := m.Called(attemptID, paymentHash, deobfuscator)
|
||||
|
||||
// Instead of returning the mocked returned values, we need to return
|
||||
// the chan resultChan so it can be converted into a read-only chan.
|
||||
return m.resultChan, nil
|
||||
resultChan := args.Get(0)
|
||||
if resultChan == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) CleanStore(
|
||||
@ -698,7 +701,6 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
|
||||
|
||||
type mockControlTower struct {
|
||||
mock.Mock
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
var _ ControlTower = (*mockControlTower)(nil)
|
||||
@ -718,9 +720,6 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
|
||||
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
|
||||
a *channeldb.HTLCAttemptInfo) error {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, a)
|
||||
return args.Error(0)
|
||||
}
|
||||
@ -729,29 +728,32 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
|
||||
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
|
||||
*channeldb.HTLCAttempt, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, pid, settleInfo)
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
|
||||
attempt := args.Get(0)
|
||||
if attempt == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return attempt.(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
|
||||
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, pid, failInfo)
|
||||
|
||||
attempt := args.Get(0)
|
||||
if attempt == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
|
||||
reason channeldb.FailureReason) error {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
args := m.Called(phash, reason)
|
||||
return args.Error(0)
|
||||
}
|
||||
@ -877,3 +879,70 @@ func (m *mockLink) EligibleToForward() bool {
|
||||
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
|
||||
return m.mayAddOutgoingErr
|
||||
}
|
||||
|
||||
type mockShardTracker struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ shards.ShardTracker = (*mockShardTracker)(nil)
|
||||
|
||||
func (m *mockShardTracker) NewShard(attemptID uint64,
|
||||
lastShard bool) (shards.PaymentShard, error) {
|
||||
|
||||
args := m.Called(attemptID, lastShard)
|
||||
|
||||
shard := args.Get(0)
|
||||
if shard == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return shard.(shards.PaymentShard), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
|
||||
args := m.Called(attemptID)
|
||||
return args.Get(0).(lntypes.Hash), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockShardTracker) CancelShard(attemptID uint64) error {
|
||||
args := m.Called(attemptID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type mockShard struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ shards.PaymentShard = (*mockShard)(nil)
|
||||
|
||||
// Hash returns the hash used for the HTLC representing this shard.
|
||||
func (m *mockShard) Hash() lntypes.Hash {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lntypes.Hash)
|
||||
}
|
||||
|
||||
// MPP returns any extra MPP records that should be set for the final
|
||||
// hop on the route used by this shard.
|
||||
func (m *mockShard) MPP() *record.MPP {
|
||||
args := m.Called()
|
||||
|
||||
r := args.Get(0)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.(*record.MPP)
|
||||
}
|
||||
|
||||
// AMP returns any extra AMP records that should be set for the final
|
||||
// hop on the route used by this shard.
|
||||
func (m *mockShard) AMP() *record.AMP {
|
||||
args := m.Called()
|
||||
|
||||
r := args.Get(0)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.(*record.AMP)
|
||||
}
|
||||
|
@ -3528,12 +3528,12 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
|
||||
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Run(func(_ mock.Arguments) {
|
||||
).Return(resultChan, nil).Run(func(_ mock.Arguments) {
|
||||
// Send a successful payment result.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{}
|
||||
resultChan <- &htlcswitch.PaymentResult{}
|
||||
})
|
||||
|
||||
missionControl.On("ReportPaymentSuccess",
|
||||
@ -3599,6 +3599,11 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
|
||||
},
|
||||
}}
|
||||
|
||||
// Create the error to be returned.
|
||||
tempErr := htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{}, 1,
|
||||
)
|
||||
|
||||
// Register mockers with the expected method calls.
|
||||
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
|
||||
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
|
||||
@ -3608,26 +3613,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 1)
|
||||
|
||||
// Create the error to be returned.
|
||||
tempErr := htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
1,
|
||||
)
|
||||
|
||||
// Mock GetAttemptResult to return a failure.
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Run(func(_ mock.Arguments) {
|
||||
// Send an attempt failure.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: tempErr,
|
||||
}
|
||||
})
|
||||
).Return(tempErr)
|
||||
|
||||
// Mock the control tower to return the mocked payment.
|
||||
payment := &mockMPPayment{}
|
||||
|
Loading…
Reference in New Issue
Block a user