watchtower/wtdb: add Encode/Decode methods to wtclient structs

This commit is contained in:
Conner Fromknecht 2019-05-23 20:48:08 -07:00
parent 1db9bf2fd4
commit 2a904cb69f
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
4 changed files with 284 additions and 12 deletions

View File

@ -103,6 +103,11 @@ func WriteElement(w io.Writer, element interface{}) error {
return err return err
} }
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case uint64: case uint64:
if err := binary.Write(w, byteOrder, e); err != nil { if err := binary.Write(w, byteOrder, e); err != nil {
return err return err
@ -259,6 +264,11 @@ func ReadElement(r io.Reader, element interface{}) error {
} }
*e = lnwire.NewShortChanIDFromInt(a) *e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *uint64: case *uint64:
if err := binary.Read(r, byteOrder, e); err != nil { if err := binary.Read(r, byteOrder, e); err != nil {
return err return err

View File

@ -2,6 +2,7 @@ package wtdb
import ( import (
"errors" "errors"
"io"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -112,8 +113,38 @@ type ClientSessionBody struct {
// deposited to if a sweep transaction confirms and the sessions // deposited to if a sweep transaction confirms and the sessions
// specifies a reward output. // specifies a reward output.
RewardPkScript []byte RewardPkScript []byte
}
// Encode writes a ClientSessionBody to the passed io.Writer.
func (s *ClientSessionBody) Encode(w io.Writer) error {
return WriteElements(w,
s.SeqNum,
s.TowerLastApplied,
uint64(s.TowerID),
s.KeyIndex,
s.Policy,
s.RewardPkScript,
)
}
// Decode reads a ClientSessionBody from the passed io.Reader.
func (s *ClientSessionBody) Decode(r io.Reader) error {
var towerID uint64
err := ReadElements(r,
&s.SeqNum,
&s.TowerLastApplied,
&towerID,
&s.KeyIndex,
&s.Policy,
&s.RewardPkScript,
)
if err != nil {
return err
}
s.TowerID = TowerID(towerID)
return nil
} }
// BackupID identifies a particular revoked, remote commitment by channel id and // BackupID identifies a particular revoked, remote commitment by channel id and
@ -126,6 +157,22 @@ type BackupID struct {
CommitHeight uint64 CommitHeight uint64
} }
// Encode writes the BackupID from the passed io.Writer.
func (b *BackupID) Encode(w io.Writer) error {
return WriteElements(w,
b.ChanID,
b.CommitHeight,
)
}
// Decode reads a BackupID from the passed io.Reader.
func (b *BackupID) Decode(r io.Reader) error {
return ReadElements(r,
&b.ChanID,
&b.CommitHeight,
)
}
// CommittedUpdate holds a state update sent by a client along with its // CommittedUpdate holds a state update sent by a client along with its
// allocated sequence number and the exact remote commitment the encrypted // allocated sequence number and the exact remote commitment the encrypted
// justice transaction can rectify. // justice transaction can rectify.
@ -152,3 +199,29 @@ type CommittedUpdateBody struct {
// hint is broadcast. // hint is broadcast.
EncryptedBlob []byte EncryptedBlob []byte
} }
// Encode writes the CommittedUpdateBody to the passed io.Writer.
func (u *CommittedUpdateBody) Encode(w io.Writer) error {
err := u.BackupID.Encode(w)
if err != nil {
return err
}
return WriteElements(w,
u.Hint,
u.EncryptedBlob,
)
}
// Decode reads a CommittedUpdateBody from the passed io.Reader.
func (u *CommittedUpdateBody) Decode(r io.Reader) error {
err := u.BackupID.Decode(r)
if err != nil {
return err
}
return ReadElements(r,
&u.Hint,
&u.EncryptedBlob,
)
}

View File

@ -2,14 +2,122 @@ package wtdb_test
import ( import (
"bytes" "bytes"
"encoding/binary"
"io" "io"
"math/rand"
"net"
"reflect" "reflect"
"testing" "testing"
"testing/quick" "testing/quick"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
) )
func randPubKey() (*btcec.PublicKey, error) {
priv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
return nil, err
}
return priv.PubKey(), nil
}
func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) {
var ip [4]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
addrIP := net.IP(ip[:])
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
}
func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) {
var ip [16]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
addrIP := net.IP(ip[:])
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil
}
func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
var serviceID [tor.V2DecodedLen]byte
if _, err := r.Read(serviceID[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
onionService += tor.OnionSuffix
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
}
func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) {
var serviceID [tor.V3DecodedLen]byte
if _, err := r.Read(serviceID[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(serviceID[:])
onionService += tor.OnionSuffix
addrPort := int(binary.BigEndian.Uint16(port[:]))
return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil
}
func randAddrs(r *rand.Rand) ([]net.Addr, error) {
tcp4Addr, err := randTCP4Addr(r)
if err != nil {
return nil, err
}
tcp6Addr, err := randTCP6Addr(r)
if err != nil {
return nil, err
}
v2OnionAddr, err := randV2OnionAddr(r)
if err != nil {
return nil, err
}
v3OnionAddr, err := randV3OnionAddr(r)
if err != nil {
return nil, err
}
return []net.Addr{tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr}, nil
}
// dbObject is abstract object support encoding and decoding. // dbObject is abstract object support encoding and decoding.
type dbObject interface { type dbObject interface {
Encode(io.Writer) error Encode(io.Writer) error
@ -19,7 +127,9 @@ type dbObject interface {
// TestCodec serializes and deserializes wtdb objects in order to test that that // TestCodec serializes and deserializes wtdb objects in order to test that that
// the codec understands all of the required field types. The test also asserts // the codec understands all of the required field types. The test also asserts
// that decoding an object into another results in an equivalent object. // that decoding an object into another results in an equivalent object.
func TestCodec(t *testing.T) { func TestCodec(tt *testing.T) {
var t *testing.T
mainScenario := func(obj dbObject) bool { mainScenario := func(obj dbObject) bool {
// Ensure encoding the object succeeds. // Ensure encoding the object succeeds.
var b bytes.Buffer var b bytes.Buffer
@ -35,6 +145,14 @@ func TestCodec(t *testing.T) {
obj2 = &wtdb.SessionInfo{} obj2 = &wtdb.SessionInfo{}
case *wtdb.SessionStateUpdate: case *wtdb.SessionStateUpdate:
obj2 = &wtdb.SessionStateUpdate{} obj2 = &wtdb.SessionStateUpdate{}
case *wtdb.ClientSessionBody:
obj2 = &wtdb.ClientSessionBody{}
case *wtdb.CommittedUpdateBody:
obj2 = &wtdb.CommittedUpdateBody{}
case *wtdb.BackupID:
obj2 = &wtdb.BackupID{}
case *wtdb.Tower:
obj2 = &wtdb.Tower{}
default: default:
t.Fatalf("unknown type: %T", obj) t.Fatalf("unknown type: %T", obj)
return false return false
@ -57,6 +175,29 @@ func TestCodec(t *testing.T) {
return true return true
} }
customTypeGen := map[string]func([]reflect.Value, *rand.Rand){
"Tower": func(v []reflect.Value, r *rand.Rand) {
pk, err := randPubKey()
if err != nil {
t.Fatalf("unable to generate pubkey: %v", err)
return
}
addrs, err := randAddrs(r)
if err != nil {
t.Fatalf("unable to generate addrs: %v", err)
return
}
obj := wtdb.Tower{
IdentityKey: pk,
Addresses: addrs,
}
v[0] = reflect.ValueOf(obj)
},
}
tests := []struct { tests := []struct {
name string name string
scenario interface{} scenario interface{}
@ -73,11 +214,45 @@ func TestCodec(t *testing.T) {
return mainScenario(&obj) return mainScenario(&obj)
}, },
}, },
{
name: "ClientSessionBody",
scenario: func(obj wtdb.ClientSessionBody) bool {
return mainScenario(&obj)
},
},
{
name: "CommittedUpdateBody",
scenario: func(obj wtdb.CommittedUpdateBody) bool {
return mainScenario(&obj)
},
},
{
name: "BackupID",
scenario: func(obj wtdb.BackupID) bool {
return mainScenario(&obj)
},
},
{
name: "Tower",
scenario: func(obj wtdb.Tower) bool {
return mainScenario(&obj)
},
},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { tt.Run(test.name, func(h *testing.T) {
if err := quick.Check(test.scenario, nil); err != nil { t = h
var config *quick.Config
if valueGen, ok := customTypeGen[test.name]; ok {
config = &quick.Config{
Values: valueGen,
}
}
err := quick.Check(test.scenario, config)
if err != nil {
t.Fatalf("fuzz checks for msg=%s failed: %v", t.Fatalf("fuzz checks for msg=%s failed: %v",
test.name, err) test.name, err)
} }

View File

@ -2,8 +2,8 @@ package wtdb
import ( import (
"errors" "errors"
"io"
"net" "net"
"sync"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -47,18 +47,15 @@ type Tower struct {
// Addresses is a list of possible addresses to reach the tower. // Addresses is a list of possible addresses to reach the tower.
Addresses []net.Addr Addresses []net.Addr
mu sync.RWMutex
} }
// AddAddress adds the given address to the tower's in-memory list of addresses. // AddAddress adds the given address to the tower's in-memory list of addresses.
// If the address's string is already present, the Tower will be left // If the address's string is already present, the Tower will be left
// unmodified. Otherwise, the adddress is prepended to the beginning of the // unmodified. Otherwise, the adddress is prepended to the beginning of the
// Tower's addresses, on the assumption that it is fresher than the others. // Tower's addresses, on the assumption that it is fresher than the others.
//
// NOTE: This method is NOT safe for concurrent use.
func (t *Tower) AddAddress(addr net.Addr) { func (t *Tower) AddAddress(addr net.Addr) {
t.mu.Lock()
defer t.mu.Unlock()
// Ensure we don't add a duplicate address. // Ensure we don't add a duplicate address.
addrStr := addr.String() addrStr := addr.String()
for _, existingAddr := range t.Addresses { for _, existingAddr := range t.Addresses {
@ -75,10 +72,9 @@ func (t *Tower) AddAddress(addr net.Addr) {
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's // LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
// addresses. This can be used to have a client try multiple addresses for the // addresses. This can be used to have a client try multiple addresses for the
// same Tower. // same Tower.
//
// NOTE: This method is NOT safe for concurrent use.
func (t *Tower) LNAddrs() []*lnwire.NetAddress { func (t *Tower) LNAddrs() []*lnwire.NetAddress {
t.mu.RLock()
defer t.mu.RUnlock()
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
for _, addr := range t.Addresses { for _, addr := range t.Addresses {
addrs = append(addrs, &lnwire.NetAddress{ addrs = append(addrs, &lnwire.NetAddress{
@ -89,3 +85,21 @@ func (t *Tower) LNAddrs() []*lnwire.NetAddress {
return addrs return addrs
} }
// Encode writes the Tower to the passed io.Writer. The TowerID is not
// serialized, since it acts as the key.
func (t *Tower) Encode(w io.Writer) error {
return WriteElements(w,
t.IdentityKey,
t.Addresses,
)
}
// Decode reads a Tower from the passed io.Reader. The TowerID is meant to be
// decoded from the key.
func (t *Tower) Decode(r io.Reader) error {
return ReadElements(r,
&t.IdentityKey,
&t.Addresses,
)
}