mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
channeldb: add KeyLocator Record
This commit is contained in:
parent
56b61078c5
commit
986e69c81b
@ -3657,3 +3657,38 @@ func storeThawHeight(chanBucket kvdb.RwBucket, height uint32) error {
|
||||
func deleteThawHeight(chanBucket kvdb.RwBucket) error {
|
||||
return chanBucket.Delete(frozenChanKey)
|
||||
}
|
||||
|
||||
// EKeyLocator is an encoder for keychain.KeyLocator.
|
||||
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if v, ok := val.(*keychain.KeyLocator); ok {
|
||||
err := tlv.EUint32T(w, uint32(v.Family), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tlv.EUint32T(w, v.Index, buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
|
||||
}
|
||||
|
||||
// DKeyLocator is a decoder for keychain.KeyLocator.
|
||||
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if v, ok := val.(*keychain.KeyLocator); ok {
|
||||
var family uint32
|
||||
err := tlv.DUint32(r, &family, buf, 4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Family = keychain.KeyFamily(family)
|
||||
|
||||
return tlv.DUint32(r, &v.Index, buf, 4)
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
|
||||
}
|
||||
|
||||
// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
|
||||
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
|
||||
// 8 as KeyFamily is uint32 and the Index is uint32.
|
||||
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
|
||||
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
@ -50,6 +52,9 @@ var (
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
|
||||
// keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding.
|
||||
keyLocIndex = uint32(2049)
|
||||
)
|
||||
|
||||
// testChannelParams is a struct which details the specifics of how a channel
|
||||
@ -1587,3 +1592,33 @@ func TestHasChanStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyLocatorEncoding tests that we are able to serialize a given
|
||||
// keychain.KeyLocator. After successfully encoding, we check that the decode
|
||||
// output arrives at the same initial KeyLocator.
|
||||
func TestKeyLocatorEncoding(t *testing.T) {
|
||||
keyLoc := keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyRevocationRoot,
|
||||
Index: keyLocIndex,
|
||||
}
|
||||
|
||||
// First, we'll encode the KeyLocator into a buffer.
|
||||
var (
|
||||
b bytes.Buffer
|
||||
buf [8]byte
|
||||
)
|
||||
|
||||
err := EKeyLocator(&b, &keyLoc, &buf)
|
||||
require.NoError(t, err, "unable to encode key locator")
|
||||
|
||||
// Next, we'll attempt to decode the bytes into a new KeyLocator.
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
var decodedKeyLoc keychain.KeyLocator
|
||||
|
||||
err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
|
||||
require.NoError(t, err, "unable to decode key locator")
|
||||
|
||||
// Finally, we'll compare that the original KeyLocator and the decoded
|
||||
// version are equal.
|
||||
require.Equal(t, keyLoc, decodedKeyLoc)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user