routing: refactor attempt makers to return pointers

Thus adding following unit tests can be a bit easier.
This commit is contained in:
yyforyongyu 2023-03-08 01:14:32 +08:00
parent ddad6ad4c4
commit e46c689bf1
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 58 additions and 20 deletions

View File

@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lntypes"
@ -17,31 +18,68 @@ var (
dummyErr = errors.New("dummy")
)
func makeSettledAttempt(total, fee int,
preimage lntypes.Preimage) channeldb.HTLCAttempt {
// createDummyRoute builds a route a->b->c paying the given amt to c.
func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route {
t.Helper()
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
priv, err := btcec.NewPrivateKey()
require.NoError(t, err, "failed to create private key")
hop1 := route.NewVertex(priv.PubKey())
priv, err = btcec.NewPrivateKey()
require.NoError(t, err, "failed to create private key")
hop2 := route.NewVertex(priv.PubKey())
hopFee := lnwire.NewMSatFromSatoshis(3)
hops := []*route.Hop{
{
ChannelID: 1,
PubKeyBytes: hop1,
LegacyPayload: true,
AmtToForward: amt + hopFee,
},
{
ChannelID: 2,
PubKeyBytes: hop2,
LegacyPayload: true,
AmtToForward: amt,
},
}
priv, err = btcec.NewPrivateKey()
require.NoError(t, err, "failed to create private key")
source := route.NewVertex(priv.PubKey())
// We create a simple route that we will supply every time the router
// requests one.
rt, err := route.NewRouteFromHops(amt+2*hopFee, 100, source, hops)
require.NoError(t, err, "failed to create route")
return rt
}
func makeSettledAttempt(t *testing.T, total int,
preimage lntypes.Preimage) *channeldb.HTLCAttempt {
return &channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(t, total),
Settle: &channeldb.HTLCSettleInfo{Preimage: preimage},
}
}
func makeFailedAttempt(total, fee int) *channeldb.HTLCAttempt {
func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt {
return &channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
HTLCAttemptInfo: makeAttemptInfo(t, total),
Failure: &channeldb.HTLCFailInfo{
Reason: channeldb.HTLCFailInternal,
},
}
}
func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)}
func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
rt := createDummyRoute(t, lnwire.MilliSatoshi(amt))
return channeldb.HTLCAttemptInfo{
Route: route.Route{
TotalAmount: lnwire.MilliSatoshi(total),
Hops: []*route.Hop{hop},
},
Route: *rt,
}
}

View File

@ -3483,7 +3483,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
)
preimage := lntypes.Preimage{1}
testAttempt := makeSettledAttempt(int(payAmt), 0, preimage)
testAttempt := makeSettledAttempt(t, int(payAmt), preimage)
node, err := createTestNode()
require.NoError(t, err)
@ -3521,7 +3521,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
controlTower.On("SettleAttempt",
payHash, mock.Anything, mock.Anything,
).Return(&testAttempt, nil)
).Return(testAttempt, nil)
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
@ -3550,7 +3550,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
// Expect a successful send to route.
attempt, err := router.SendToRouteSkipTempErr(payHash, rt)
require.NoError(t, err)
require.Equal(t, &testAttempt, attempt)
require.Equal(t, testAttempt, attempt)
// Assert the above methods are called as expected.
controlTower.AssertExpectations(t)
@ -3563,11 +3563,11 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
// cause the payment to be failed.
func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
var (
payHash lntypes.Hash
payAmt = lnwire.MilliSatoshi(10000)
testAttempt = &channeldb.HTLCAttempt{}
payHash lntypes.Hash
payAmt = lnwire.MilliSatoshi(10000)
)
testAttempt := makeFailedAttempt(t, int(payAmt))
node, err := createTestNode()
require.NoError(t, err)
@ -3648,7 +3648,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {
payAmt = lnwire.MilliSatoshi(10000)
)
testAttempt := makeFailedAttempt(int(payAmt), 0)
testAttempt := makeFailedAttempt(t, int(payAmt))
node, err := createTestNode()
require.NoError(t, err)
@ -3733,7 +3733,7 @@ func TestSendToRouteTempFailure(t *testing.T) {
payAmt = lnwire.MilliSatoshi(10000)
)
testAttempt := makeFailedAttempt(int(payAmt), 0)
testAttempt := makeFailedAttempt(t, int(payAmt))
node, err := createTestNode()
require.NoError(t, err)