diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go index ad7c0e68c..9fdafcef8 100644 --- a/lnrpc/invoicesrpc/utils.go +++ b/lnrpc/invoicesrpc/utils.go @@ -129,7 +129,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice, rpcHtlc.Amp = &lnrpc.AMP{ RootShare: rootShare[:], SetId: setID[:], - ChildIndex: uint32(htlc.AMP.Record.ChildIndex()), + ChildIndex: htlc.AMP.Record.ChildIndex(), Hash: htlc.AMP.Hash[:], Preimage: preimage, } diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 7944dddcd..66fab66ef 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -455,6 +455,11 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop return nil, err } + amp, err := UnmarshalAMP(rpcHop.AmpRecord) + if err != nil { + return nil, err + } + return &route.Hop{ OutgoingTimeLock: rpcHop.Expiry, AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat), @@ -463,6 +468,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop CustomRecords: customRecords, LegacyPayload: !rpcHop.TlvPayload, MPP: mpp, + AMP: amp, }, nil } @@ -895,6 +901,32 @@ func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) { return record.NewMPP(total, addr), nil } +func UnmarshalAMP(reqAMP *lnrpc.AMPRecord) (*record.AMP, error) { + if reqAMP == nil { + return nil, nil + } + + reqRootShare := reqAMP.RootShare + reqSetID := reqAMP.SetId + + switch { + case len(reqRootShare) != 32: + return nil, errors.New("AMP root_share must be 32 bytes") + + case len(reqSetID) != 32: + return nil, errors.New("AMP set_id must be 32 bytes") + } + + var ( + rootShare [32]byte + setID [32]byte + ) + copy(rootShare[:], reqRootShare) + copy(setID[:], reqSetID) + + return record.NewAMP(rootShare, setID, reqAMP.ChildIndex), nil +} + // MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation. func (r *RouterBackend) MarshalHTLCAttempt( htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) { diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index f100909ad..f92a6f15a 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" "github.com/lightningnetwork/lnd/lnrpc" ) @@ -239,18 +240,18 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode, return routing.TimedPairResult{} } -type mppOutcome byte +type recordParseOutcome byte const ( - valid mppOutcome = iota + valid recordParseOutcome = iota invalid - nompp + norecord ) type unmarshalMPPTest struct { name string mpp *lnrpc.MPPRecord - outcome mppOutcome + outcome recordParseOutcome } // TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to @@ -262,7 +263,7 @@ func TestUnmarshalMPP(t *testing.T) { { name: "nil record", mpp: nil, - outcome: nompp, + outcome: norecord, }, { name: "invalid total or addr", @@ -346,7 +347,7 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) { // Arguments that produce no MPP field should return no error and no MPP // record. - case nompp: + case norecord: if err != nil { t.Fatalf("failure for args resulting for no-mpp") } @@ -358,3 +359,95 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) { t.Fatalf("test case has non-standard outcome") } } + +type unmarshalAMPTest struct { + name string + amp *lnrpc.AMPRecord + outcome recordParseOutcome +} + +// TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord. +func TestUnmarshalAMP(t *testing.T) { + rootShare := bytes.Repeat([]byte{0x01}, 32) + setID := bytes.Repeat([]byte{0x02}, 32) + + // All child indexes are valid. + childIndex := uint32(3) + + tests := []unmarshalAMPTest{ + { + name: "nil record", + amp: nil, + outcome: norecord, + }, + { + name: "invalid root share invalid set id", + amp: &lnrpc.AMPRecord{ + RootShare: []byte{0x01}, + SetId: []byte{0x02}, + ChildIndex: childIndex, + }, + outcome: invalid, + }, + { + name: "valid root share invalid set id", + amp: &lnrpc.AMPRecord{ + RootShare: rootShare, + SetId: []byte{0x02}, + ChildIndex: childIndex, + }, + outcome: invalid, + }, + { + name: "invalid root share valid set id", + amp: &lnrpc.AMPRecord{ + RootShare: []byte{0x01}, + SetId: setID, + ChildIndex: childIndex, + }, + outcome: invalid, + }, + { + name: "valid root share valid set id", + amp: &lnrpc.AMPRecord{ + RootShare: rootShare, + SetId: setID, + ChildIndex: childIndex, + }, + outcome: valid, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testUnmarshalAMP(t, test) + }) + } +} + +func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) { + amp, err := UnmarshalAMP(test.amp) + switch test.outcome { + case valid: + require.NoError(t, err) + require.NotNil(t, amp) + + rootShare := amp.RootShare() + setID := amp.SetID() + require.Equal(t, test.amp.RootShare, rootShare[:]) + require.Equal(t, test.amp.SetId, setID[:]) + require.Equal(t, test.amp.ChildIndex, amp.ChildIndex()) + + case invalid: + require.Error(t, err) + require.Nil(t, amp) + + case norecord: + require.NoError(t, err) + require.Nil(t, amp) + + default: + t.Fatalf("test case has non-standard outcome") + } +}