diff --git a/channeldb/channel.go b/channeldb/channel.go index d36ded213..0d27c0eab 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -3657,3 +3657,38 @@ func storeThawHeight(chanBucket kvdb.RwBucket, height uint32) error { func deleteThawHeight(chanBucket kvdb.RwBucket) error { return chanBucket.Delete(frozenChanKey) } + +// EKeyLocator is an encoder for keychain.KeyLocator. +func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*keychain.KeyLocator); ok { + err := tlv.EUint32T(w, uint32(v.Family), buf) + if err != nil { + return err + } + + return tlv.EUint32T(w, v.Index, buf) + } + return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator") +} + +// DKeyLocator is a decoder for keychain.KeyLocator. +func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*keychain.KeyLocator); ok { + var family uint32 + err := tlv.DUint32(r, &family, buf, 4) + if err != nil { + return err + } + v.Family = keychain.KeyFamily(family) + + return tlv.DUint32(r, &v.Index, buf, 4) + } + return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) +} + +// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed +// Type and the EKeyLocator and DKeyLocator functions. The size will always be +// 8 as KeyFamily is uint32 and the Index is uint32. +func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record { + return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator) +} diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 437473723..482bc7af8 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -8,6 +8,8 @@ import ( "runtime" "testing" + "github.com/stretchr/testify/require" + "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" @@ -50,6 +52,9 @@ var ( IP: net.ParseIP("127.0.0.1"), Port: 18555, } + + // keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding. + keyLocIndex = uint32(2049) ) // testChannelParams is a struct which details the specifics of how a channel @@ -1587,3 +1592,33 @@ func TestHasChanStatus(t *testing.T) { }) } } + +// TestKeyLocatorEncoding tests that we are able to serialize a given +// keychain.KeyLocator. After successfully encoding, we check that the decode +// output arrives at the same initial KeyLocator. +func TestKeyLocatorEncoding(t *testing.T) { + keyLoc := keychain.KeyLocator{ + Family: keychain.KeyFamilyRevocationRoot, + Index: keyLocIndex, + } + + // First, we'll encode the KeyLocator into a buffer. + var ( + b bytes.Buffer + buf [8]byte + ) + + err := EKeyLocator(&b, &keyLoc, &buf) + require.NoError(t, err, "unable to encode key locator") + + // Next, we'll attempt to decode the bytes into a new KeyLocator. + r := bytes.NewReader(b.Bytes()) + var decodedKeyLoc keychain.KeyLocator + + err = DKeyLocator(r, &decodedKeyLoc, &buf, 8) + require.NoError(t, err, "unable to decode key locator") + + // Finally, we'll compare that the original KeyLocator and the decoded + // version are equal. + require.Equal(t, keyLoc, decodedKeyLoc) +}