channeldb/migration_01_to_11: remove unused code

This commit is contained in:
Joost Jager 2019-10-24 12:45:07 +02:00
parent f5191440c5
commit 60503d6c44
No known key found for this signature in database
GPG key ID: A61B9D4C393C59C7
33 changed files with 1 additions and 17048 deletions

View file

@ -1,24 +0,0 @@
channeldb
==========
[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd)
[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE)
[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb)
The channeldb implements the persistent storage engine for `lnd` and
generically a data storage layer for the required state within the Lightning
Network. The backing storage engine is
[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store
based off of LMDB.
The package implements an object-oriented storage model with queries and
mutations flowing through a particular object instance rather than the database
itself. The storage implemented by the objects includes: open channels, past
commitment revocation states, the channel graph which includes authenticated
node and channel announcements, outgoing payments, and invoices
## Installation and Updating
```bash
$ go get -u github.com/lightningnetwork/lnd/channeldb
```

View file

@ -1,149 +0,0 @@
package migration_01_to_11
import (
"bytes"
"net"
"strings"
"testing"
"github.com/lightningnetwork/lnd/tor"
)
type unknownAddrType struct{}
func (t unknownAddrType) Network() string { return "unknown" }
func (t unknownAddrType) String() string { return "unknown" }
var testIP4 = net.ParseIP("192.168.1.1")
var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
var addrTests = []struct {
expAddr net.Addr
serErr string
}{
// Valid addresses.
{
expAddr: &net.TCPAddr{
IP: testIP4,
Port: 12345,
},
},
{
expAddr: &net.TCPAddr{
IP: testIP6,
Port: 65535,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "3g2upl4pq6kufc4m.onion",
Port: 9735,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
Port: 80,
},
},
// Invalid addresses.
{
expAddr: unknownAddrType{},
serErr: ErrUnknownAddressType.Error(),
},
{
expAddr: &net.TCPAddr{
// Remove last byte of IPv4 address.
IP: testIP4[:len(testIP4)-1],
Port: 12345,
},
serErr: "unable to encode",
},
{
expAddr: &net.TCPAddr{
// Add an extra byte of IPv4 address.
IP: append(testIP4, 0xff),
Port: 12345,
},
serErr: "unable to encode",
},
{
expAddr: &net.TCPAddr{
// Remove last byte of IPv6 address.
IP: testIP6[:len(testIP6)-1],
Port: 65535,
},
serErr: "unable to encode",
},
{
expAddr: &net.TCPAddr{
// Add an extra byte to the IPv6 address.
IP: append(testIP6, 0xff),
Port: 65535,
},
serErr: "unable to encode",
},
{
expAddr: &tor.OnionAddr{
// Invalid suffix.
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion",
Port: 80,
},
serErr: "invalid suffix",
},
{
expAddr: &tor.OnionAddr{
// Invalid length.
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion",
Port: 80,
},
serErr: "unknown onion service length",
},
{
expAddr: &tor.OnionAddr{
// Invalid encoding.
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion",
Port: 80,
},
serErr: "illegal base32",
},
}
// TestAddrSerialization tests that the serialization method used by channeldb
// for net.Addr's works as intended.
func TestAddrSerialization(t *testing.T) {
t.Parallel()
var b bytes.Buffer
for _, test := range addrTests {
err := serializeAddr(&b, test.expAddr)
switch {
case err == nil && test.serErr != "":
t.Fatalf("expected serialization err for addr %v",
test.expAddr)
case err != nil && test.serErr == "":
t.Fatalf("unexpected serialization err for addr %v: %v",
test.expAddr, err)
case err != nil && !strings.Contains(err.Error(), test.serErr):
t.Fatalf("unexpected serialization err for addr %v, "+
"want: %v, got %v", test.expAddr, test.serErr,
err)
case err != nil:
continue
}
addr, err := deserializeAddr(&b)
if err != nil {
t.Fatalf("unable to deserialize address: %v", err)
}
if addr.String() != test.expAddr.String() {
t.Fatalf("expected address %v after serialization, "+
"got %v", addr, test.expAddr)
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,50 +0,0 @@
package migration_01_to_11
// channelCache is an in-memory cache used to improve the performance of
// ChanUpdatesInHorizon. It caches the chan info and edge policies for a
// particular channel.
type channelCache struct {
n int
channels map[uint64]ChannelEdge
}
// newChannelCache creates a new channelCache with maximum capacity of n
// channels.
func newChannelCache(n int) *channelCache {
return &channelCache{
n: n,
channels: make(map[uint64]ChannelEdge),
}
}
// get returns the channel from the cache, if it exists.
func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) {
channel, ok := c.channels[chanid]
return channel, ok
}
// insert adds the entry to the channel cache. If an entry for chanid already
// exists, it will be replaced with the new entry. If the entry doesn't exist,
// it will be inserted to the cache, performing a random eviction if the cache
// is at capacity.
func (c *channelCache) insert(chanid uint64, channel ChannelEdge) {
// If entry exists, replace it.
if _, ok := c.channels[chanid]; ok {
c.channels[chanid] = channel
return
}
// Otherwise, evict an entry at random and insert.
if len(c.channels) == c.n {
for id := range c.channels {
delete(c.channels, id)
break
}
}
c.channels[chanid] = channel
}
// remove deletes an edge for chanid from the cache, if it exists.
func (c *channelCache) remove(chanid uint64) {
delete(c.channels, chanid)
}

View file

@ -1,105 +0,0 @@
package migration_01_to_11
import (
"reflect"
"testing"
)
// TestChannelCache checks the behavior of the channelCache with respect to
// insertion, eviction, and removal of cache entries.
func TestChannelCache(t *testing.T) {
const cacheSize = 100
// Create a new channel cache with the configured max size.
c := newChannelCache(cacheSize)
// As a sanity check, assert that querying the empty cache does not
// return an entry.
_, ok := c.get(0)
if ok {
t.Fatalf("channel cache should be empty")
}
// Now, fill up the cache entirely.
for i := uint64(0); i < cacheSize; i++ {
c.insert(i, channelForInt(i))
}
// Assert that the cache has all of the entries just inserted, since no
// eviction should occur until we try to surpass the max size.
assertHasChanEntries(t, c, 0, cacheSize)
// Now, insert a new element that causes the cache to evict an element.
c.insert(cacheSize, channelForInt(cacheSize))
// Assert that the cache has this last entry, as the cache should evict
// some prior element and not the newly inserted one.
assertHasChanEntries(t, c, cacheSize, cacheSize)
// Iterate over all inserted elements and construct a set of the evicted
// elements.
evicted := make(map[uint64]struct{})
for i := uint64(0); i < cacheSize+1; i++ {
_, ok := c.get(i)
if !ok {
evicted[i] = struct{}{}
}
}
// Assert that exactly one element has been evicted.
numEvicted := len(evicted)
if numEvicted != 1 {
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
}
// Remove the highest item which initially caused the eviction and
// reinsert the element that was evicted prior.
c.remove(cacheSize)
for i := range evicted {
c.insert(i, channelForInt(i))
}
// Since the removal created an extra slot, the last insertion should
// not have caused an eviction and the entries for all channels in the
// original set that filled the cache should be present.
assertHasChanEntries(t, c, 0, cacheSize)
// Finally, reinsert the existing set back into the cache and test that
// the cache still has all the entries. If the randomized eviction were
// happening on inserts for existing cache items, we expect this to fail
// with high probability.
for i := uint64(0); i < cacheSize; i++ {
c.insert(i, channelForInt(i))
}
assertHasChanEntries(t, c, 0, cacheSize)
}
// assertHasEntries queries the edge cache for all channels in the range [start,
// end), asserting that they exist and their value matches the entry produced by
// entryForInt.
func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) {
t.Helper()
for i := start; i < end; i++ {
entry, ok := c.get(i)
if !ok {
t.Fatalf("channel cache should contain chan %d", i)
}
expEntry := channelForInt(i)
if !reflect.DeepEqual(entry, expEntry) {
t.Fatalf("entry mismatch, want: %v, got: %v",
expEntry, entry)
}
}
}
// channelForInt generates a unique ChannelEdge given an integer.
func channelForInt(i uint64) ChannelEdge {
return ChannelEdge{
Info: &ChannelEdgeInfo{
ChannelID: i,
},
}
}

View file

@ -4,18 +4,13 @@ import (
"bytes"
"io/ioutil"
"math/rand"
"net"
"os"
"reflect"
"runtime"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
@ -66,8 +61,6 @@ var (
LockTime: 5,
}
privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:])
wireSig, _ = lnwire.NewSigFromSignature(testSig)
)
// makeTestDB creates a new instance of the ChannelDB for testing purposes. A
@ -223,819 +216,6 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) {
RevocationProducer: producer,
RevocationStore: store,
Db: cdb,
Packager: NewChannelPackager(chanID),
FundingTxn: testTx,
}, nil
}
func TestOpenChannelPutGetDelete(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// Create the test channel state, then add an additional fake HTLC
// before syncing to disk.
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
state.LocalCommitment.Htlcs = []HTLC{
{
Signature: testSig.Serialize(),
Incoming: true,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: []byte("onionblob"),
},
}
state.RemoteCommitment.Htlcs = []HTLC{
{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: []byte("onionblob"),
},
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := state.SyncPending(addr, 101); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch open channel: %v", err)
}
newState := openChannels[0]
// The decoded channel state should be identical to what we stored
// above.
if !reflect.DeepEqual(state, newState) {
t.Fatalf("channel state doesn't match:: %v vs %v",
spew.Sdump(state), spew.Sdump(newState))
}
// We'll also test that the channel is properly able to hot swap the
// next revocation for the state machine. This tests the initial
// post-funding revocation exchange.
nextRevKey, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("unable to create new private key: %v", err)
}
if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil {
t.Fatalf("unable to update revocation: %v", err)
}
openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch open channel: %v", err)
}
updatedChan := openChannels[0]
// Ensure that the revocation was set properly.
if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) {
t.Fatalf("next revocation wasn't updated")
}
// Finally to wrap up the test, delete the state of the channel within
// the database. This involves "closing" the channel which removes all
// written state, and creates a small "summary" elsewhere within the
// database.
closeSummary := &ChannelCloseSummary{
ChanPoint: state.FundingOutpoint,
RemotePub: state.IdentityPub,
SettledBalance: btcutil.Amount(500),
TimeLockedBalance: btcutil.Amount(10000),
IsPending: false,
CloseType: CooperativeClose,
}
if err := state.CloseChannel(closeSummary); err != nil {
t.Fatalf("unable to close channel: %v", err)
}
// As the channel is now closed, attempting to fetch all open channels
// for our fake node ID should return an empty slice.
openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch open channels: %v", err)
}
if len(openChans) != 0 {
t.Fatalf("all channels not deleted, found %v", len(openChans))
}
// Additionally, attempting to fetch all the open channels globally
// should yield no results.
openChans, err = cdb.FetchAllChannels()
if err != nil {
t.Fatal("unable to fetch all open chans")
}
if len(openChans) != 0 {
t.Fatalf("all channels not deleted, found %v", len(openChans))
}
}
func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
if !reflect.DeepEqual(a, b) {
_, _, line, _ := runtime.Caller(1)
t.Fatalf("line %v: commitments don't match: %v vs %v",
line, spew.Sdump(a), spew.Sdump(b))
}
}
func TestChannelStateTransition(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// First create a minimal channel, then perform a full sync in order to
// persist the data.
channel, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := channel.SyncPending(addr, 101); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Add some HTLCs which were added during this new state transition.
// Half of the HTLCs are incoming, while the other half are outgoing.
var (
htlcs []HTLC
htlcAmt lnwire.MilliSatoshi
)
for i := uint32(0); i < 10; i++ {
var incoming bool
if i > 5 {
incoming = true
}
htlc := HTLC{
Signature: testSig.Serialize(),
Incoming: incoming,
Amt: 10,
RHash: key,
RefundTimeout: i,
OutputIndex: int32(i * 3),
LogIndex: uint64(i * 2),
HtlcIndex: uint64(i),
}
htlc.OnionBlob = make([]byte, 10)
copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10))
htlcs = append(htlcs, htlc)
htlcAmt += htlc.Amt
}
// Create a new channel delta which includes the above HTLCs, some
// balance updates, and an increment of the current commitment height.
// Additionally, modify the signature and commitment transaction.
newSequence := uint32(129498)
newSig := bytes.Repeat([]byte{3}, 71)
newTx := channel.LocalCommitment.CommitTx.Copy()
newTx.TxIn[0].Sequence = newSequence
commitment := ChannelCommitment{
CommitHeight: 1,
LocalLogIndex: 2,
LocalHtlcIndex: 1,
RemoteLogIndex: 2,
RemoteHtlcIndex: 1,
LocalBalance: lnwire.MilliSatoshi(1e8),
RemoteBalance: lnwire.MilliSatoshi(1e8),
CommitFee: 55,
FeePerKw: 99,
CommitTx: newTx,
CommitSig: newSig,
Htlcs: htlcs,
}
// First update the local node's broadcastable state and also add a
// CommitDiff remote node's as well in order to simulate a proper state
// transition.
if err := channel.UpdateCommitment(&commitment); err != nil {
t.Fatalf("unable to update commitment: %v", err)
}
// The balances, new update, the HTLCs and the changes to the fake
// commitment transaction along with the modified signature should all
// have been updated.
updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch updated channel: %v", err)
}
assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
if err != nil {
t.Fatalf("unable to read commitment height from disk: %v", err)
}
if numDiskUpdates != uint64(commitment.CommitHeight) {
t.Fatalf("num disk updates doesn't match: %v vs %v",
numDiskUpdates, commitment.CommitHeight)
}
// Attempting to query for a commitment diff should return
// ErrNoPendingCommit as we haven't yet created a new state for them.
_, err = channel.RemoteCommitChainTip()
if err != ErrNoPendingCommit {
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
}
// To simulate us extending a new state to the remote party, we'll also
// create a new commit diff for them.
remoteCommit := commitment
remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8)
remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8)
remoteCommit.CommitHeight = 1
commitDiff := &CommitDiff{
Commitment: remoteCommit,
CommitSig: &lnwire.CommitSig{
ChanID: lnwire.ChannelID(key),
CommitSig: wireSig,
HtlcSigs: []lnwire.Sig{
wireSig,
wireSig,
},
},
LogUpdates: []LogUpdate{
{
LogIndex: 1,
UpdateMsg: &lnwire.UpdateAddHTLC{
ID: 1,
Amount: lnwire.NewMSatFromSatoshis(100),
Expiry: 25,
},
},
{
LogIndex: 2,
UpdateMsg: &lnwire.UpdateAddHTLC{
ID: 2,
Amount: lnwire.NewMSatFromSatoshis(200),
Expiry: 50,
},
},
},
OpenedCircuitKeys: []CircuitKey{},
ClosedCircuitKeys: []CircuitKey{},
}
copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
bytes.Repeat([]byte{1}, 32))
copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
bytes.Repeat([]byte{2}, 32))
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
t.Fatalf("unable to add to commit chain: %v", err)
}
// The commitment tip should now match the commitment that we just
// inserted.
diskCommitDiff, err := channel.RemoteCommitChainTip()
if err != nil {
t.Fatalf("unable to fetch commit diff: %v", err)
}
if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit),
spew.Sdump(diskCommitDiff))
}
// We'll save the old remote commitment as this will be added to the
// revocation log shortly.
oldRemoteCommit := channel.RemoteCommitment
// Next, write to the log which tracks the necessary revocation state
// needed to rectify any fishy behavior by the remote party. Modify the
// current uncollapsed revocation state to simulate a state transition
// by the remote party.
channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
newPriv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("unable to generate key: %v", err)
}
channel.RemoteNextRevocation = newPriv.PubKey()
fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
diskCommitDiff.LogUpdates, nil)
err = channel.AdvanceCommitChainTail(fwdPkg)
if err != nil {
t.Fatalf("unable to append to revocation log: %v", err)
}
// At this point, the remote commit chain should be nil, and the posted
// remote commitment should match the one we added as a diff above.
if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
}
// We should be able to fetch the channel delta created above by its
// update number with all the state properly reconstructed.
diskPrevCommit, err := channel.FindPreviousState(
oldRemoteCommit.CommitHeight,
)
if err != nil {
t.Fatalf("unable to fetch past delta: %v", err)
}
// The two deltas (the original vs the on-disk version) should
// identical, and all HTLC data should properly be retained.
assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit)
// The state number recovered from the tail of the revocation log
// should be identical to this current state.
logTail, err := channel.RevocationLogTail()
if err != nil {
t.Fatalf("unable to retrieve log: %v", err)
}
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
t.Fatal("update number doesn't match")
}
oldRemoteCommit = channel.RemoteCommitment
// Next modify the posted diff commitment slightly, then create a new
// commitment diff and advance the tail.
commitDiff.Commitment.CommitHeight = 2
commitDiff.Commitment.LocalBalance -= htlcAmt
commitDiff.Commitment.RemoteBalance += htlcAmt
commitDiff.LogUpdates = []LogUpdate{}
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
t.Fatalf("unable to add to commit chain: %v", err)
}
fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)
err = channel.AdvanceCommitChainTail(fwdPkg)
if err != nil {
t.Fatalf("unable to append to revocation log: %v", err)
}
// Once again, fetch the state and ensure it has been properly updated.
prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight)
if err != nil {
t.Fatalf("unable to fetch past delta: %v", err)
}
assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)
// Once again, state number recovered from the tail of the revocation
// log should be identical to this current state.
logTail, err = channel.RevocationLogTail()
if err != nil {
t.Fatalf("unable to retrieve log: %v", err)
}
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
t.Fatal("update number doesn't match")
}
// The revocation state stored on-disk should now also be identical.
updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch updated channel: %v", err)
}
if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
t.Fatalf("revocation state was not synced")
}
if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
t.Fatalf("revocation state was not synced")
}
// Now attempt to delete the channel from the database.
closeSummary := &ChannelCloseSummary{
ChanPoint: channel.FundingOutpoint,
RemotePub: channel.IdentityPub,
SettledBalance: btcutil.Amount(500),
TimeLockedBalance: btcutil.Amount(10000),
IsPending: false,
CloseType: RemoteForceClose,
}
if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
t.Fatalf("unable to delete updated channel: %v", err)
}
// If we attempt to fetch the target channel again, it shouldn't be
// found.
channels, err := cdb.FetchOpenChannels(channel.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch updated channels: %v", err)
}
if len(channels) != 0 {
t.Fatalf("%v channels, found, but none should be",
len(channels))
}
// Attempting to find previous states on the channel should fail as the
// revocation log has been deleted.
_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight)
if err == nil {
t.Fatal("revocation log search should have failed")
}
}
func TestFetchPendingChannels(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// Create first test channel state
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
pendingChannels, err := cdb.FetchPendingChannels()
if err != nil {
t.Fatalf("unable to list pending channels: %v", err)
}
if len(pendingChannels) != 1 {
t.Fatalf("incorrect number of pending channels: expecting %v,"+
"got %v", 1, len(pendingChannels))
}
// The broadcast height of the pending channel should have been set
// properly.
if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
t.Fatalf("broadcast height mismatch: expected %v, got %v",
pendingChannels[0].FundingBroadcastHeight,
broadcastHeight)
}
chanOpenLoc := lnwire.ShortChannelID{
BlockHeight: 5,
TxIndex: 10,
TxPosition: 15,
}
err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
if err != nil {
t.Fatalf("unable to mark channel as open: %v", err)
}
if pendingChannels[0].IsPending {
t.Fatalf("channel marked open should no longer be pending")
}
if pendingChannels[0].ShortChanID() != chanOpenLoc {
t.Fatalf("channel opening height not updated: expected %v, "+
"got %v", spew.Sdump(pendingChannels[0].ShortChanID()),
chanOpenLoc)
}
// Next, we'll re-fetch the channel to ensure that the open height was
// properly set.
openChans, err := cdb.FetchAllChannels()
if err != nil {
t.Fatalf("unable to fetch channels: %v", err)
}
if openChans[0].ShortChanID() != chanOpenLoc {
t.Fatalf("channel opening heights don't match: expected %v, "+
"got %v", spew.Sdump(openChans[0].ShortChanID()),
chanOpenLoc)
}
if openChans[0].FundingBroadcastHeight != broadcastHeight {
t.Fatalf("broadcast height mismatch: expected %v, got %v",
openChans[0].FundingBroadcastHeight,
broadcastHeight)
}
pendingChannels, err = cdb.FetchPendingChannels()
if err != nil {
t.Fatalf("unable to list pending channels: %v", err)
}
if len(pendingChannels) != 0 {
t.Fatalf("incorrect number of pending channels: expecting %v,"+
"got %v", 0, len(pendingChannels))
}
}
func TestFetchClosedChannels(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// First create a test channel, that we'll be closing within this pull
// request.
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Next sync the channel to disk, marking it as being in a pending open
// state.
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Next, simulate the confirmation of the channel by marking it as
// pending within the database.
chanOpenLoc := lnwire.ShortChannelID{
BlockHeight: 5,
TxIndex: 10,
TxPosition: 15,
}
err = state.MarkAsOpen(chanOpenLoc)
if err != nil {
t.Fatalf("unable to mark channel as open: %v", err)
}
// Next, close the channel by including a close channel summary in the
// database.
summary := &ChannelCloseSummary{
ChanPoint: state.FundingOutpoint,
ClosingTXID: rev,
RemotePub: state.IdentityPub,
Capacity: state.Capacity,
SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(),
TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000,
CloseType: RemoteForceClose,
IsPending: true,
LocalChanConfig: state.LocalChanCfg,
}
if err := state.CloseChannel(summary); err != nil {
t.Fatalf("unable to close channel: %v", err)
}
// Query the database to ensure that the channel has now been properly
// closed. We should get the same result whether querying for pending
// channels only, or not.
pendingClosed, err := cdb.FetchClosedChannels(true)
if err != nil {
t.Fatalf("failed fetching closed channels: %v", err)
}
if len(pendingClosed) != 1 {
t.Fatalf("incorrect number of pending closed channels: expecting %v,"+
"got %v", 1, len(pendingClosed))
}
if !reflect.DeepEqual(summary, pendingClosed[0]) {
t.Fatalf("database summaries don't match: expected %v got %v",
spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
}
closed, err := cdb.FetchClosedChannels(false)
if err != nil {
t.Fatalf("failed fetching all closed channels: %v", err)
}
if len(closed) != 1 {
t.Fatalf("incorrect number of closed channels: expecting %v, "+
"got %v", 1, len(closed))
}
if !reflect.DeepEqual(summary, closed[0]) {
t.Fatalf("database summaries don't match: expected %v got %v",
spew.Sdump(summary), spew.Sdump(closed[0]))
}
// Mark the channel as fully closed.
err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
if err != nil {
t.Fatalf("failed fully closing channel: %v", err)
}
// The channel should no longer be considered pending, but should still
// be retrieved when fetching all the closed channels.
closed, err = cdb.FetchClosedChannels(false)
if err != nil {
t.Fatalf("failed fetching closed channels: %v", err)
}
if len(closed) != 1 {
t.Fatalf("incorrect number of closed channels: expecting %v, "+
"got %v", 1, len(closed))
}
pendingClose, err := cdb.FetchClosedChannels(true)
if err != nil {
t.Fatalf("failed fetching channels pending close: %v", err)
}
if len(pendingClose) != 0 {
t.Fatalf("incorrect number of closed channels: expecting %v, "+
"got %v", 0, len(closed))
}
}
// TestFetchWaitingCloseChannels ensures that the correct channels that are
// waiting to be closed are returned.
func TestFetchWaitingCloseChannels(t *testing.T) {
t.Parallel()
const numChannels = 2
const broadcastHeight = 99
addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555}
// We'll start by creating two channels within our test database. One of
// them will have their funding transaction confirmed on-chain, while
// the other one will remain unconfirmed.
db, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
channels := make([]*OpenChannel, numChannels)
for i := 0; i < numChannels; i++ {
channel, err := createTestChannelState(db)
if err != nil {
t.Fatalf("unable to create channel: %v", err)
}
err = channel.SyncPending(addr, broadcastHeight)
if err != nil {
t.Fatalf("unable to sync channel: %v", err)
}
channels[i] = channel
}
// We'll only confirm the first one.
channelConf := lnwire.ShortChannelID{
BlockHeight: broadcastHeight + 1,
TxIndex: 10,
TxPosition: 15,
}
if err := channels[0].MarkAsOpen(channelConf); err != nil {
t.Fatalf("unable to mark channel as open: %v", err)
}
// Then, we'll mark the channels as if their commitments were broadcast.
// This would happen in the event of a force close and should make the
// channels enter a state of waiting close.
for _, channel := range channels {
closeTx := wire.NewMsgTx(2)
closeTx.AddTxIn(
&wire.TxIn{
PreviousOutPoint: channel.FundingOutpoint,
},
)
if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil {
t.Fatalf("unable to mark commitment broadcast: %v", err)
}
}
// Now, we'll fetch all the channels waiting to be closed from the
// database. We should expect to see both channels above, even if any of
// them haven't had their funding transaction confirm on-chain.
waitingCloseChannels, err := db.FetchWaitingCloseChannels()
if err != nil {
t.Fatalf("unable to fetch all waiting close channels: %v", err)
}
if len(waitingCloseChannels) != 2 {
t.Fatalf("expected %d channels waiting to be closed, got %d", 2,
len(waitingCloseChannels))
}
expectedChannels := make(map[wire.OutPoint]struct{})
for _, channel := range channels {
expectedChannels[channel.FundingOutpoint] = struct{}{}
}
for _, channel := range waitingCloseChannels {
if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
t.Fatalf("expected channel %v to be waiting close",
channel.FundingOutpoint)
}
// Finally, make sure we can retrieve the closing tx for the
// channel.
closeTx, err := channel.BroadcastedCommitment()
if err != nil {
t.Fatalf("Unable to retrieve commitment: %v", err)
}
if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint {
t.Fatalf("expected outpoint %v, got %v",
channel.FundingOutpoint,
closeTx.TxIn[0].PreviousOutPoint)
}
}
}
// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
// short channel ID of another OpenChannel to reflect a preceding call to
// MarkOpen on a different OpenChannel.
func TestRefreshShortChanID(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// First create a test channel.
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
// Mark the channel as pending within the channeldb.
const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Next, locate the pending channel with the database.
pendingChannels, err := cdb.FetchPendingChannels()
if err != nil {
t.Fatalf("unable to load pending channels; %v", err)
}
var pendingChannel *OpenChannel
for _, channel := range pendingChannels {
if channel.FundingOutpoint == state.FundingOutpoint {
pendingChannel = channel
break
}
}
if pendingChannel == nil {
t.Fatalf("unable to find pending channel with funding "+
"outpoint=%v: %v", state.FundingOutpoint, err)
}
// Next, simulate the confirmation of the channel by marking it as
// pending within the database.
chanOpenLoc := lnwire.ShortChannelID{
BlockHeight: 105,
TxIndex: 10,
TxPosition: 15,
}
err = state.MarkAsOpen(chanOpenLoc)
if err != nil {
t.Fatalf("unable to mark channel open: %v", err)
}
// The short_chan_id of the receiver to MarkAsOpen should reflect the
// open location, but the other pending channel should remain unchanged.
if state.ShortChanID() == pendingChannel.ShortChanID() {
t.Fatalf("pending channel short_chan_ID should not have been " +
"updated before refreshing short_chan_id")
}
// Now that the receiver's short channel id has been updated, check to
// ensure that the channel packager's source has been updated as well.
// This ensures that the packager will read and write to buckets
// corresponding to the new short chan id, instead of the prior.
if state.Packager.(*ChannelPackager).source != chanOpenLoc {
t.Fatalf("channel packager source was not updated: want %v, "+
"got %v", chanOpenLoc,
state.Packager.(*ChannelPackager).source)
}
// Now, refresh the short channel ID of the pending channel.
err = pendingChannel.RefreshShortChanID()
if err != nil {
t.Fatalf("unable to refresh short_chan_id: %v", err)
}
// This should result in both OpenChannel's now having the same
// ShortChanID.
if state.ShortChanID() != pendingChannel.ShortChanID() {
t.Fatalf("expected pending channel short_chan_id to be "+
"refreshed: want %v, got %v", state.ShortChanID(),
pendingChannel.ShortChanID())
}
// Check to ensure that the _other_ OpenChannel channel packager's
// source has also been updated after the refresh. This ensures that the
// other packagers will read and write to buckets corresponding to the
// updated short chan id.
if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
t.Fatalf("channel packager source was not updated: want %v, "+
"got %v", chanOpenLoc,
pendingChannel.Packager.(*ChannelPackager).source)
}
}

View file

@ -48,12 +48,6 @@ type UnknownElementType struct {
element interface{}
}
// NewUnknownElementType creates a new UnknownElementType error from the passed
// method name and element.
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
return UnknownElementType{method: method, element: el}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {

View file

@ -4,16 +4,11 @@ import (
"bytes"
"encoding/binary"
"fmt"
"net"
"os"
"path/filepath"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
@ -87,57 +82,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) {
return chanDB, nil
}
// Path returns the file path to the channel database.
func (d *DB) Path() string {
return d.dbPath
}
// Wipe completely deletes all saved state within all used buckets within the
// database. The deletion is done in a single transaction, therefore this
// operation is fully atomic.
func (d *DB) Wipe() error {
return d.Update(func(tx *bbolt.Tx) error {
err := tx.DeleteBucket(openChannelBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(closedChannelBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(invoiceBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(nodeInfoBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(nodeBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(edgeBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(edgeIndexBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
err = tx.DeleteBucket(graphMetaBucket)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
return nil
})
}
// createChannelDB creates and initializes a fresh version of channeldb. In
// the case that the target path has not yet been created or doesn't yet exist,
// then the path is created. Additionally, all required top-level buckets used
@ -163,14 +107,6 @@ func createChannelDB(dbPath string) error {
return err
}
if _, err := tx.CreateBucket(forwardingLogBucket); err != nil {
return err
}
if _, err := tx.CreateBucket(fwdPackagesKey); err != nil {
return err
}
if _, err := tx.CreateBucket(invoiceBucket); err != nil {
return err
}
@ -179,10 +115,6 @@ func createChannelDB(dbPath string) error {
return err
}
if _, err := tx.CreateBucket(nodeInfoBucket); err != nil {
return err
}
nodes, err := tx.CreateBucket(nodeBucket)
if err != nil {
return err
@ -249,359 +181,6 @@ func fileExists(path string) bool {
return true
}
// FetchOpenChannels starts a new database transaction and returns all stored
// currently active/open channels associated with the target nodeID. In the case
// that no active channels are known to have been created with this node, then a
// zero-length slice is returned.
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
var channels []*OpenChannel
err := d.View(func(tx *bbolt.Tx) error {
var err error
channels, err = d.fetchOpenChannels(tx, nodeID)
return err
})
return channels, err
}
// fetchOpenChannels uses and existing database transaction and returns all
// stored currently active/open channels associated with the target nodeID. In
// the case that no active channels are known to have been created with this
// node, then a zero-length slice is returned.
func (d *DB) fetchOpenChannels(tx *bbolt.Tx,
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
// Get the bucket dedicated to storing the metadata for open channels.
openChanBucket := tx.Bucket(openChannelBucket)
if openChanBucket == nil {
return nil, nil
}
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
pub := nodeID.SerializeCompressed()
nodeChanBucket := openChanBucket.Bucket(pub)
if nodeChanBucket == nil {
return nil, nil
}
// Next, we'll need to go down an additional layer in order to retrieve
// the channels for each chain the node knows of.
var channels []*OpenChannel
err := nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so ignore it.
if v != nil {
return nil
}
// If we've found a valid chainhash bucket, then we'll retrieve
// that so we can extract all the channels.
chainBucket := nodeChanBucket.Bucket(chainHash)
if chainBucket == nil {
return fmt.Errorf("unable to read bucket for chain=%x",
chainHash[:])
}
// Finally, we both of the necessary buckets retrieved, fetch
// all the active channels related to this node.
nodeChannels, err := d.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read channel for "+
"chain_hash=%x, node_key=%x: %v",
chainHash[:], pub, err)
}
channels = append(channels, nodeChannels...)
return nil
})
return channels, err
}
// fetchNodeChannels retrieves all active channels from the target chainBucket
// which is under a node's dedicated channel bucket. This function is typically
// used to fetch all the active channels related to a particular node.
func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) {
var channels []*OpenChannel
// A node may have channels on several chains, so for each known chain,
// we'll extract all the channels.
err := chainBucket.ForEach(func(chanPoint, v []byte) error {
// If there's a value, it's not a bucket so ignore it.
if v != nil {
return nil
}
// Once we've found a valid channel bucket, we'll extract it
// from the node's chain bucket.
chanBucket := chainBucket.Bucket(chanPoint)
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(chanPoint), &outPoint)
if err != nil {
return err
}
oChannel, err := fetchOpenChannel(chanBucket, &outPoint)
if err != nil {
return fmt.Errorf("unable to read channel data for "+
"chan_point=%v: %v", outPoint, err)
}
oChannel.Db = d
channels = append(channels, oChannel)
return nil
})
if err != nil {
return nil, err
}
return channels, nil
}
// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
var (
targetChan *OpenChannel
targetChanPoint bytes.Buffer
)
if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
return nil, err
}
// chanScan will traverse the following bucket structure:
// * nodePub => chainHash => chanPoint
//
// At each level we go one further, ensuring that we're traversing the
// proper key (that's actually a bucket). By only reading the bucket
// structure and skipping fully decoding each channel, we save a good
// bit of CPU as we don't need to do things like decompress public
// keys.
chanScan := func(tx *bbolt.Tx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
openChanBucket := tx.Bucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoActiveChannels
}
// Within the node channel bucket, are the set of node pubkeys
// we have channels with, we don't know the entire set, so
// we'll check them all.
return openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey,
// and also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}
nodeChanBucket := openChanBucket.Bucket(nodePub)
if nodeChanBucket == nil {
return nil
}
// The next layer down is all the chains that this node
// has channels on with us.
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
chainBucket := nodeChanBucket.Bucket(chainHash)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash[:])
}
// Finally we reach the leaf bucket that stores
// all the chanPoints for this node.
chanBucket := chainBucket.Bucket(
targetChanPoint.Bytes(),
)
if chanBucket == nil {
return nil
}
channel, err := fetchOpenChannel(
chanBucket, &chanPoint,
)
if err != nil {
return err
}
targetChan = channel
targetChan.Db = d
return nil
})
})
}
err := d.View(chanScan)
if err != nil {
return nil, err
}
if targetChan != nil {
return targetChan, nil
}
// If we can't find the channel, then we return with an error, as we
// have nothing to backup.
return nil, ErrChannelNotFound
}
// FetchAllChannels attempts to retrieve all open channels currently stored
// within the database, including pending open, fully open and channels waiting
// for a closing transaction to confirm.
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) {
var channels []*OpenChannel
// TODO(halseth): fetch all in one db tx.
openChannels, err := d.FetchAllOpenChannels()
if err != nil {
return nil, err
}
channels = append(channels, openChannels...)
pendingChannels, err := d.FetchPendingChannels()
if err != nil {
return nil, err
}
channels = append(channels, pendingChannels...)
waitingClose, err := d.FetchWaitingCloseChannels()
if err != nil {
return nil, err
}
channels = append(channels, waitingClose...)
return channels, nil
}
// FetchAllOpenChannels will return all channels that have the funding
// transaction confirmed, and is not waiting for a closing transaction to be
// confirmed.
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
return fetchChannels(d, false, false)
}
// FetchPendingChannels will return channels that have completed the process of
// generating and broadcasting funding transactions, but whose funding
// transactions have yet to be confirmed on the blockchain.
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(d, true, false)
}
// FetchWaitingCloseChannels will return all channels that have been opened,
// but are now waiting for a closing transaction to be confirmed.
//
// NOTE: This includes channels that are also pending to be opened.
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
waitingClose, err := fetchChannels(d, false, true)
if err != nil {
return nil, err
}
pendingWaitingClose, err := fetchChannels(d, true, true)
if err != nil {
return nil, err
}
return append(waitingClose, pendingWaitingClose...), nil
}
// fetchChannels attempts to retrieve channels currently stored in the
// database. The pending parameter determines whether only pending channels
// will be returned, or only open channels will be returned. The waitingClose
// parameter determines whether only channels waiting for a closing transaction
// to be confirmed should be returned. If no active channels exist within the
// network, then ErrNoActiveChannels is returned.
func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) {
var channels []*OpenChannel
err := d.View(func(tx *bbolt.Tx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
openChanBucket := tx.Bucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoActiveChannels
}
// Next, fetch the bucket dedicated to storing metadata related
// to all nodes. All keys within this bucket are the serialized
// public keys of all our direct counterparties.
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return fmt.Errorf("node bucket not created")
}
// Finally for each node public key in the bucket, fetch all
// the channels related to this particular node.
return nodeMetaBucket.ForEach(func(k, v []byte) error {
nodeChanBucket := openChanBucket.Bucket(k)
if nodeChanBucket == nil {
return nil
}
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
// If we've found a valid chainhash bucket,
// then we'll retrieve that so we can extract
// all the channels.
chainBucket := nodeChanBucket.Bucket(chainHash)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash[:])
}
nodeChans, err := d.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read "+
"channel for chain_hash=%x, "+
"node_key=%x: %v", chainHash[:], k, err)
}
for _, channel := range nodeChans {
if channel.IsPending != pending {
continue
}
// If the channel is in any other state
// than Default, then it means it is
// waiting to be closed.
channelWaitingClose :=
channel.ChanStatus() != ChanStatusDefault
// Only include it if we requested
// channels with the same waitingClose
// status.
if channelWaitingClose != waitingClose {
continue
}
channels = append(channels, channel)
}
return nil
})
})
})
if err != nil {
return nil, err
}
return channels, nil
}
// FetchClosedChannels attempts to fetch all closed channels from the database.
// The pendingOnly bool toggles if channels that aren't yet fully closed should
// be returned in the response or not. When a channel was cooperatively closed,
@ -641,371 +220,6 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
return chanSummaries, nil
}
// ErrClosedChannelNotFound signals that a closed channel could not be found in
// the channeldb.
var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary")
// FetchClosedChannel queries for a channel close summary using the channel
// point of the channel in question.
func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := d.View(func(tx *bbolt.Tx) error {
closeBucket := tx.Bucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
}
var b bytes.Buffer
var err error
if err = writeOutpoint(&b, chanID); err != nil {
return err
}
summaryBytes := closeBucket.Get(b.Bytes())
if summaryBytes == nil {
return ErrClosedChannelNotFound
}
summaryReader := bytes.NewReader(summaryBytes)
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err
}); err != nil {
return nil, err
}
return chanSummary, nil
}
// FetchClosedChannelForID queries for a channel close summary using the
// channel ID of the channel in question.
func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := d.View(func(tx *bbolt.Tx) error {
closeBucket := tx.Bucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
}
// The first 30 bytes of the channel ID and outpoint will be
// equal.
cursor := closeBucket.Cursor()
op, c := cursor.Seek(cid[:30])
// We scan over all possible candidates for this channel ID.
for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() {
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(op), &outPoint)
if err != nil {
return err
}
// If the found outpoint does not correspond to this
// channel ID, we continue.
if !cid.IsChanPoint(&outPoint) {
continue
}
// Deserialize the close summary and return.
r := bytes.NewReader(c)
chanSummary, err = deserializeCloseChannelSummary(r)
if err != nil {
return err
}
return nil
}
return ErrClosedChannelNotFound
}); err != nil {
return nil, err
}
return chanSummary, nil
}
// MarkChanFullyClosed marks a channel as fully closed within the database. A
// channel should be marked as fully closed if the channel was initially
// cooperatively closed and it's reached a single confirmation, or after all
// the pending funds in a channel that has been forcibly closed have been
// swept.
func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
return d.Update(func(tx *bbolt.Tx) error {
var b bytes.Buffer
if err := writeOutpoint(&b, chanPoint); err != nil {
return err
}
chanID := b.Bytes()
closedChanBucket, err := tx.CreateBucketIfNotExists(
closedChannelBucket,
)
if err != nil {
return err
}
chanSummaryBytes := closedChanBucket.Get(chanID)
if chanSummaryBytes == nil {
return fmt.Errorf("no closed channel for "+
"chan_point=%v found", chanPoint)
}
chanSummaryReader := bytes.NewReader(chanSummaryBytes)
chanSummary, err := deserializeCloseChannelSummary(
chanSummaryReader,
)
if err != nil {
return err
}
chanSummary.IsPending = false
var newSummary bytes.Buffer
err = serializeChannelCloseSummary(&newSummary, chanSummary)
if err != nil {
return err
}
err = closedChanBucket.Put(chanID, newSummary.Bytes())
if err != nil {
return err
}
// Now that the channel is closed, we'll check if we have any
// other open channels with this peer. If we don't we'll
// garbage collect it to ensure we don't establish persistent
// connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub)
})
}
// pruneLinkNode determines whether we should garbage collect a link node from
// the database due to no longer having any open channels with it. If there are
// any left, then this acts as a no-op.
func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error {
openChannels, err := d.fetchOpenChannels(tx, remotePub)
if err != nil {
return fmt.Errorf("unable to fetch open channels for peer %x: "+
"%v", remotePub.SerializeCompressed(), err)
}
if len(openChannels) > 0 {
return nil
}
log.Infof("Pruning link node %x with zero open channels from database",
remotePub.SerializeCompressed())
return d.deleteLinkNode(tx, remotePub)
}
// PruneLinkNodes attempts to prune all link nodes found within the databse with
// whom we no longer have any open channels with.
func (d *DB) PruneLinkNodes() error {
return d.Update(func(tx *bbolt.Tx) error {
linkNodes, err := d.fetchAllLinkNodes(tx)
if err != nil {
return err
}
for _, linkNode := range linkNodes {
err := d.pruneLinkNode(tx, linkNode.IdentityPub)
if err != nil {
return err
}
}
return nil
})
}
// ChannelShell is a shell of a channel that is meant to be used for channel
// recovery purposes. It contains a minimal OpenChannel instance along with
// addresses for that target node.
type ChannelShell struct {
// NodeAddrs the set of addresses that this node has known to be
// reachable at in the past.
NodeAddrs []net.Addr
// Chan is a shell of an OpenChannel, it contains only the items
// required to restore the channel on disk.
Chan *OpenChannel
}
// RestoreChannelShells is a method that allows the caller to reconstruct the
// state of an OpenChannel from the ChannelShell. We'll attempt to write the
// new channel to disk, create a LinkNode instance with the passed node
// addresses, and finally create an edge within the graph for the channel as
// well. This method is idempotent, so repeated calls with the same set of
// channel shells won't modify the database after the initial call.
func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
chanGraph := d.ChannelGraph()
// TODO(conner): find way to do this w/o accessing internal members?
chanGraph.cacheMu.Lock()
defer chanGraph.cacheMu.Unlock()
var chansRestored []uint64
err := d.Update(func(tx *bbolt.Tx) error {
for _, channelShell := range channelShells {
channel := channelShell.Chan
// When we make a channel, we mark that the channel has
// been restored, this will signal to other sub-systems
// to not attempt to use the channel as if it was a
// regular one.
channel.chanStatus |= ChanStatusRestored
// First, we'll attempt to create a new open channel
// and link node for this channel. If the channel
// already exists, then in order to ensure this method
// is idempotent, we'll continue to the next step.
channel.Db = d
err := syncNewChannel(
tx, channel, channelShell.NodeAddrs,
)
if err != nil {
return err
}
// Next, we'll create an active edge in the graph
// database for this channel in order to restore our
// partial view of the network.
//
// TODO(roasbeef): if we restore *after* the channel
// has been closed on chain, then need to inform the
// router that it should try and prune these values as
// we can detect them
edgeInfo := ChannelEdgeInfo{
ChannelID: channel.ShortChannelID.ToUint64(),
ChainHash: channel.ChainHash,
ChannelPoint: channel.FundingOutpoint,
Capacity: channel.Capacity,
}
nodes := tx.Bucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
selfNode, err := chanGraph.sourceNode(nodes)
if err != nil {
return err
}
// Depending on which pub key is smaller, we'll assign
// our roles as "node1" and "node2".
chanPeer := channel.IdentityPub.SerializeCompressed()
selfIsSmaller := bytes.Compare(
selfNode.PubKeyBytes[:], chanPeer,
) == -1
if selfIsSmaller {
copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:])
copy(edgeInfo.NodeKey2Bytes[:], chanPeer)
} else {
copy(edgeInfo.NodeKey1Bytes[:], chanPeer)
copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:])
}
// With the edge info shell constructed, we'll now add
// it to the graph.
err = chanGraph.addChannelEdge(tx, &edgeInfo)
if err != nil && err != ErrEdgeAlreadyExist {
return err
}
// Similarly, we'll construct a channel edge shell and
// add that itself to the graph.
chanEdge := ChannelEdgePolicy{
ChannelID: edgeInfo.ChannelID,
LastUpdate: time.Now(),
}
// If their pubkey is larger, then we'll flip the
// direction bit to indicate that us, the "second" node
// is updating their policy.
if !selfIsSmaller {
chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection
}
_, err = updateEdgePolicy(tx, &chanEdge)
if err != nil {
return err
}
chansRestored = append(chansRestored, edgeInfo.ChannelID)
}
return nil
})
if err != nil {
return err
}
for _, chanid := range chansRestored {
chanGraph.rejectCache.remove(chanid)
chanGraph.chanCache.remove(chanid)
}
return nil
}
// AddrsForNode consults the graph and channel database for all addresses known
// to the passed node public key.
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
var (
linkNode *LinkNode
graphNode LightningNode
)
dbErr := d.View(func(tx *bbolt.Tx) error {
var err error
linkNode, err = fetchLinkNode(tx, nodePub)
if err != nil {
return err
}
// We'll also query the graph for this peer to see if they have
// any addresses that we don't currently have stored within the
// link node database.
nodes := tx.Bucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
compressedPubKey := nodePub.SerializeCompressed()
graphNode, err = fetchLightningNode(nodes, compressedPubKey)
if err != nil && err != ErrGraphNodeNotFound {
// If the node isn't found, then that's OK, as we still
// have the link node data.
return err
}
return nil
})
if dbErr != nil {
return nil, dbErr
}
// Now that we have both sources of addrs for this node, we'll use a
// map to de-duplicate any addresses between the two sources, and
// produce a final list of the combined addrs.
addrs := make(map[string]net.Addr)
for _, addr := range linkNode.Addresses {
addrs[addr.String()] = addr
}
for _, addr := range graphNode.Addresses {
addrs[addr.String()] = addr
}
dedupedAddrs := make([]net.Addr, 0, len(addrs))
for _, addr := range addrs {
dedupedAddrs = append(dedupedAddrs, addr)
}
return dedupedAddrs, nil
}
// syncVersions function is used for safe db version synchronization. It
// applies migration functions to the current database and recovers the
// previous state of db if at least one error/panic appeared during migration.

View file

@ -1,471 +0,0 @@
package migration_01_to_11
import (
"io/ioutil"
"math"
"math/rand"
"net"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
)
func TestOpenWithCreate(t *testing.T) {
t.Parallel()
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
t.Fatalf("unable to create temp dir: %v", err)
}
defer os.RemoveAll(tempDirName)
// Next, open thereby creating channeldb for the first time.
dbPath := filepath.Join(tempDirName, "cdb")
cdb, err := Open(dbPath)
if err != nil {
t.Fatalf("unable to create channeldb: %v", err)
}
if err := cdb.Close(); err != nil {
t.Fatalf("unable to close channeldb: %v", err)
}
// The path should have been successfully created.
if !fileExists(dbPath) {
t.Fatalf("channeldb failed to create data directory")
}
}
// TestWipe tests that the database wipe operation completes successfully
// and that the buckets are deleted. It also checks that attempts to fetch
// information while the buckets are not set return the correct errors.
func TestWipe(t *testing.T) {
t.Parallel()
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil {
t.Fatalf("unable to create temp dir: %v", err)
}
defer os.RemoveAll(tempDirName)
// Next, open thereby creating channeldb for the first time.
dbPath := filepath.Join(tempDirName, "cdb")
cdb, err := Open(dbPath)
if err != nil {
t.Fatalf("unable to create channeldb: %v", err)
}
defer cdb.Close()
if err := cdb.Wipe(); err != nil {
t.Fatalf("unable to wipe channeldb: %v", err)
}
// Check correct errors are returned
_, err = cdb.FetchAllOpenChannels()
if err != ErrNoActiveChannels {
t.Fatalf("fetching open channels: expected '%v' instead got '%v'",
ErrNoActiveChannels, err)
}
_, err = cdb.FetchClosedChannels(false)
if err != ErrNoClosedChannels {
t.Fatalf("fetching closed channels: expected '%v' instead got '%v'",
ErrNoClosedChannels, err)
}
}
// TestFetchClosedChannelForID tests that we are able to properly retrieve a
// ChannelCloseSummary from the DB given a ChannelID.
func TestFetchClosedChannelForID(t *testing.T) {
t.Parallel()
const numChans = 101
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// Create the test channel state, that we will mutate the index of the
// funding point.
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Now run through the number of channels, and modify the outpoint index
// to create new channel IDs.
for i := uint32(0); i < numChans; i++ {
// Save the open channel to disk.
state.FundingOutpoint.Index = i
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := state.SyncPending(addr, 101); err != nil {
t.Fatalf("unable to save and serialize channel "+
"state: %v", err)
}
// Close the channel. To make sure we retrieve the correct
// summary later, we make them differ in the SettledBalance.
closeSummary := &ChannelCloseSummary{
ChanPoint: state.FundingOutpoint,
RemotePub: state.IdentityPub,
SettledBalance: btcutil.Amount(500 + i),
}
if err := state.CloseChannel(closeSummary); err != nil {
t.Fatalf("unable to close channel: %v", err)
}
}
// Now run though them all again and make sure we are able to retrieve
// summaries from the DB.
for i := uint32(0); i < numChans; i++ {
state.FundingOutpoint.Index = i
// We calculate the ChannelID and use it to fetch the summary.
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
fetchedSummary, err := cdb.FetchClosedChannelForID(cid)
if err != nil {
t.Fatalf("unable to fetch close summary: %v", err)
}
// Make sure we retrieved the correct one by checking the
// SettledBalance.
if fetchedSummary.SettledBalance != btcutil.Amount(500+i) {
t.Fatalf("summaries don't match: expected %v got %v",
btcutil.Amount(500+i),
fetchedSummary.SettledBalance)
}
}
// As a final test we make sure that we get ErrClosedChannelNotFound
// for a ChannelID we didn't add to the DB.
state.FundingOutpoint.Index++
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
_, err = cdb.FetchClosedChannelForID(cid)
if err != ErrClosedChannelNotFound {
t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err)
}
}
// TestAddrsForNode tests the we're able to properly obtain all the addresses
// for a target node.
func TestAddrsForNode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
graph := cdb.ChannelGraph()
// We'll make a test vertex to insert into the database, as the source
// node, but this node will only have half the number of addresses it
// usually does.
testNode, err := createTestVertex(cdb)
if err != nil {
t.Fatalf("unable to create test node: %v", err)
}
testNode.Addresses = []net.Addr{testAddr}
if err := graph.SetSourceNode(testNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
// Next, we'll make a link node with the same pubkey, but with an
// additional address.
nodePub, err := testNode.PubKey()
if err != nil {
t.Fatalf("unable to recv node pub: %v", err)
}
linkNode := cdb.NewLinkNode(
wire.MainNet, nodePub, anotherAddr,
)
if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to sync link node: %v", err)
}
// Now that we've created a link node, as well as a vertex for the
// node, we'll query for all its addresses.
nodeAddrs, err := cdb.AddrsForNode(nodePub)
if err != nil {
t.Fatalf("unable to obtain node addrs: %v", err)
}
expectedAddrs := make(map[string]struct{})
expectedAddrs[testAddr.String()] = struct{}{}
expectedAddrs[anotherAddr.String()] = struct{}{}
// Finally, ensure that all the expected addresses are found.
if len(nodeAddrs) != len(expectedAddrs) {
t.Fatalf("expected %v addrs, got %v",
len(expectedAddrs), len(nodeAddrs))
}
for _, addr := range nodeAddrs {
if _, ok := expectedAddrs[addr.String()]; !ok {
t.Fatalf("unexpected addr: %v", addr)
}
}
}
// TestFetchChannel tests that we're able to fetch an arbitrary channel from
// disk.
func TestFetchChannel(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// Create the test channel state that we'll sync to the database
// shortly.
channelState, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Mark the channel as pending, then immediately mark it as open to it
// can be fully visible.
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
if err := channelState.SyncPending(addr, 9); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99))
if err != nil {
t.Fatalf("unable to mark channel open: %v", err)
}
// Next, attempt to fetch the channel by its chan point.
dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint)
if err != nil {
t.Fatalf("unable to fetch channel: %v", err)
}
// The decoded channel state should be identical to what we stored
// above.
if !reflect.DeepEqual(channelState, dbChannel) {
t.Fatalf("channel state doesn't match:: %v vs %v",
spew.Sdump(channelState), spew.Sdump(dbChannel))
}
// If we attempt to query for a non-exist ante channel, then we should
// get an error.
channelState2, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
channelState2.FundingOutpoint.Index ^= 1
_, err = cdb.FetchChannel(channelState2.FundingOutpoint)
if err == nil {
t.Fatalf("expected query to fail")
}
}
func genRandomChannelShell() (*ChannelShell, error) {
var testPriv [32]byte
if _, err := rand.Read(testPriv[:]); err != nil {
return nil, err
}
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
var chanPoint wire.OutPoint
if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
return nil, err
}
pub.Curve = nil
chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
chanStatus := ChanStatusDefault | ChanStatusRestored
var shaChainPriv [32]byte
if _, err := rand.Read(testPriv[:]); err != nil {
return nil, err
}
revRoot, err := chainhash.NewHash(shaChainPriv[:])
if err != nil {
return nil, err
}
shaChainProducer := shachain.NewRevocationProducer(*revRoot)
return &ChannelShell{
NodeAddrs: []net.Addr{&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}},
Chan: &OpenChannel{
chanStatus: chanStatus,
ChainHash: rev,
FundingOutpoint: chanPoint,
ShortChannelID: lnwire.NewShortChanIDFromInt(
uint64(rand.Int63()),
),
IdentityPub: pub,
LocalChanCfg: ChannelConfig{
ChannelConstraints: ChannelConstraints{
CsvDelay: uint16(rand.Int63()),
},
PaymentBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: keychain.KeyFamily(rand.Int63()),
Index: uint32(rand.Int63()),
},
},
},
RemoteCurrentRevocation: pub,
IsPending: false,
RevocationStore: shachain.NewRevocationStore(),
RevocationProducer: shaChainProducer,
},
}, nil
}
// TestRestoreChannelShells tests that we're able to insert a partially channel
// populated to disk. This is useful for channel recovery purposes. We should
// find the new channel shell on disk, and also the db should be populated with
// an edge for that channel.
func TestRestoreChannelShells(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// First, we'll make our channel shell, it will only have the minimal
// amount of information required for us to initiate the data loss
// protection feature.
channelShell, err := genRandomChannelShell()
if err != nil {
t.Fatalf("unable to gen channel shell: %v", err)
}
graph := cdb.ChannelGraph()
// Before we can restore the channel, we'll need to make a source node
// in the graph as the channel edge we create will need to have a
// origin.
testNode, err := createTestVertex(cdb)
if err != nil {
t.Fatalf("unable to create test node: %v", err)
}
if err := graph.SetSourceNode(testNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
// With the channel shell constructed, we'll now insert it into the
// database with the restoration method.
if err := cdb.RestoreChannelShells(channelShell); err != nil {
t.Fatalf("unable to restore channel shell: %v", err)
}
// Now that the channel has been inserted, we'll attempt to query for
// it to ensure we can properly locate it via various means.
//
// First, we'll attempt to query for all channels that we have with the
// node public key that was restored.
nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub)
if err != nil {
t.Fatalf("unable find channel: %v", err)
}
// We should now find a single channel from the database.
if len(nodeChans) != 1 {
t.Fatalf("unable to find restored channel by node "+
"pubkey: %v", err)
}
// Ensure that it isn't possible to modify the commitment state machine
// of this restored channel.
channel := nodeChans[0]
err = channel.UpdateCommitment(nil)
if err != ErrNoRestoredChannelMutation {
t.Fatalf("able to mutate restored channel")
}
err = channel.AppendRemoteCommitChain(nil)
if err != ErrNoRestoredChannelMutation {
t.Fatalf("able to mutate restored channel")
}
err = channel.AdvanceCommitChainTail(nil)
if err != ErrNoRestoredChannelMutation {
t.Fatalf("able to mutate restored channel")
}
// That single channel should have the proper channel point, and also
// the expected set of flags to indicate that it was a restored
// channel.
if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint {
t.Fatalf("wrong funding outpoint: expected %v, got %v",
nodeChans[0].FundingOutpoint,
channelShell.Chan.FundingOutpoint)
}
if !nodeChans[0].HasChanStatus(ChanStatusRestored) {
t.Fatalf("node has wrong status flags: %v",
nodeChans[0].chanStatus)
}
// We should also be able to find the channel if we query for it
// directly.
_, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint)
if err != nil {
t.Fatalf("unable to fetch channel: %v", err)
}
// We should also be able to find the link node that was inserted by
// its public key.
linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch link node: %v", err)
}
// The node should have the same address, as specified in the channel
// shell.
if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) {
t.Fatalf("addr mismach: expected %v, got %v",
linkNode.Addresses, channelShell.NodeAddrs)
}
// Finally, we'll ensure that the edge for the channel was properly
// inserted.
chanInfos, err := graph.FetchChanInfos(
[]uint64{channelShell.Chan.ShortChannelID.ToUint64()},
)
if err != nil {
t.Fatalf("unable to find edges: %v", err)
}
if len(chanInfos) != 1 {
t.Fatalf("wrong amount of chan infos: expected %v got %v",
len(chanInfos), 1)
}
// We should only find a single edge.
if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil {
t.Fatalf("only a single edge should be inserted: %v", err)
}
}

View file

@ -1 +0,0 @@
package migration_01_to_11

View file

@ -1,55 +1,23 @@
package migration_01_to_11
import (
"errors"
"fmt"
)
var (
// ErrNoChanDBExists is returned when a channel bucket hasn't been
// created.
ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created")
// ErrDBReversion is returned when detecting an attempt to revert to a
// prior database version.
ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version")
// ErrLinkNodesNotFound is returned when node info bucket hasn't been
// created.
ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist")
// ErrNoActiveChannels is returned when there is no active (open)
// channels within the database.
ErrNoActiveChannels = fmt.Errorf("no active channels exist")
// ErrNoPastDeltas is returned when the channel delta bucket hasn't been
// created.
ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas")
// ErrInvoiceNotFound is returned when a targeted invoice can't be
// found.
ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice")
// ErrNoInvoicesCreated is returned when we don't have invoices in
// our database to return.
ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices")
// ErrDuplicateInvoice is returned when an invoice with the target
// payment hash already exists.
ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists")
// ErrNoPaymentsCreated is returned when bucket of payments hasn't been
// created.
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")
// ErrNodeNotFound is returned when node bucket exists, but node with
// specific identity can't be found.
ErrNodeNotFound = fmt.Errorf("link node with target identity not found")
// ErrChannelNotFound is returned when we attempt to locate a channel
// for a specific chain, but it is not found.
ErrChannelNotFound = fmt.Errorf("channel not found")
// ErrMetaNotFound is returned when meta bucket hasn't been
// created.
ErrMetaNotFound = fmt.Errorf("unable to locate meta information")
@ -58,22 +26,11 @@ var (
// graph doesn't exist.
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
// ErrGraphNeverPruned is returned when graph was never pruned.
ErrGraphNeverPruned = fmt.Errorf("graph never pruned")
// ErrSourceNodeNotSet is returned if the source node of the graph
// hasn't been added The source node is the center node within a
// star-graph.
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
// ErrGraphNodesNotFound is returned in case none of the nodes has
// been added in graph node bucket.
ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist")
// ErrGraphNoEdgesFound is returned in case of none of the channel/edges
// has been added in graph edge bucket.
ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist")
// ErrGraphNodeNotFound is returned when we're unable to find the target
// node.
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
@ -82,17 +39,6 @@ var (
// can't be found.
ErrEdgeNotFound = fmt.Errorf("edge not found")
// ErrZombieEdge is an error returned when we attempt to look up an edge
// but it is marked as a zombie within the zombie index.
ErrZombieEdge = errors.New("edge marked as zombie")
// ErrEdgeAlreadyExist is returned when edge with specific
// channel id can't be added because it already exist.
ErrEdgeAlreadyExist = fmt.Errorf("edge already exist")
// ErrNodeAliasNotFound is returned when alias for node can't be found.
ErrNodeAliasNotFound = fmt.Errorf("alias for node not found")
// ErrUnknownAddressType is returned when a node's addressType is not
// an expected value.
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
@ -101,20 +47,11 @@ var (
// channels it has closed, but it hasn't yet closed any channels.
ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet")
// ErrNoForwardingEvents is returned in the case that a query fails due
// to the log not having any recorded events.
ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events")
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
// policy field is not found in the db even though its message flags
// indicate it should be.
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
"present")
// ErrChanAlreadyExists is return when the caller attempts to create a
// channel with a channel point that is already present in the
// database.
ErrChanAlreadyExists = fmt.Errorf("channel already exists")
)
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the

View file

@ -1 +0,0 @@
package migration_01_to_11

View file

@ -1,274 +0,0 @@
package migration_01_to_11
import (
"bytes"
"io"
"sort"
"time"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// forwardingLogBucket is the bucket that we'll use to store the
// forwarding log. The forwarding log contains a time series database
// of the forwarding history of a lightning daemon. Each key within the
// bucket is a timestamp (in nano seconds since the unix epoch), and
// the value a slice of a forwarding event for that timestamp.
forwardingLogBucket = []byte("circuit-fwd-log")
)
const (
// forwardingEventSize is the size of a forwarding event. The breakdown
// is as follows:
//
// * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in
// || 8 byte value out
//
// From the value in and value out, callers can easily compute the
// total fee extract from a forwarding event.
forwardingEventSize = 32
// MaxResponseEvents is the max number of forwarding events that will
// be returned by a single query response. This size was selected to
// safely remain under gRPC's 4MiB message size response limit. As each
// full forwarding event (including the timestamp) is 40 bytes, we can
// safely return 50k entries in a single response.
MaxResponseEvents = 50000
)
// ForwardingLog returns an instance of the ForwardingLog object backed by the
// target database instance.
func (d *DB) ForwardingLog() *ForwardingLog {
return &ForwardingLog{
db: d,
}
}
// ForwardingLog is a time series database that logs the fulfilment of payment
// circuits by a lightning network daemon. The log contains a series of
// forwarding events which map a timestamp to a forwarding event. A forwarding
// event describes which channels were used to create+settle a circuit, and the
// amount involved. Subtracting the outgoing amount from the incoming amount
// reveals the fee charged for the forwarding service.
type ForwardingLog struct {
db *DB
}
// ForwardingEvent is an event in the forwarding log's time series. Each
// forwarding event logs the creation and tear-down of a payment circuit. A
// circuit is created once an incoming HTLC has been fully forwarded, and
// destroyed once the payment has been settled.
type ForwardingEvent struct {
// Timestamp is the settlement time of this payment circuit.
Timestamp time.Time
// IncomingChanID is the incoming channel ID of the payment circuit.
IncomingChanID lnwire.ShortChannelID
// OutgoingChanID is the outgoing channel ID of the payment circuit.
OutgoingChanID lnwire.ShortChannelID
// AmtIn is the amount of the incoming HTLC. Subtracting this from the
// outgoing amount gives the total fees of this payment circuit.
AmtIn lnwire.MilliSatoshi
// AmtOut is the amount of the outgoing HTLC. Subtracting the incoming
// amount from this gives the total fees for this payment circuit.
AmtOut lnwire.MilliSatoshi
}
// encodeForwardingEvent writes out the target forwarding event to the passed
// io.Writer, using the expected DB format. Note that the timestamp isn't
// serialized as this will be the key value within the bucket.
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
return WriteElements(
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
)
}
// decodeForwardingEvent attempts to decode the raw bytes of a serialized
// forwarding event into the target ForwardingEvent. Note that the timestamp
// won't be decoded, as the caller is expected to set this due to the bucket
// structure of the forwarding log.
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
return ReadElements(
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
)
}
// AddForwardingEvents adds a series of forwarding events to the database.
// Before inserting, the set of events will be sorted according to their
// timestamp. This ensures that all writes to disk are sequential.
func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error {
// Before we create the database transaction, we'll ensure that the set
// of forwarding events are properly sorted according to their
// timestamp.
sort.Slice(events, func(i, j int) bool {
return events[i].Timestamp.Before(events[j].Timestamp)
})
var timestamp [8]byte
return f.db.Batch(func(tx *bbolt.Tx) error {
// First, we'll fetch the bucket that stores our time series
// log.
logBucket, err := tx.CreateBucketIfNotExists(
forwardingLogBucket,
)
if err != nil {
return err
}
// With the bucket obtained, we can now begin to write out the
// series of events.
for _, event := range events {
var eventBytes [forwardingEventSize]byte
eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize])
// First, we'll serialize this timestamp into our
// timestamp buffer.
byteOrder.PutUint64(
timestamp[:], uint64(event.Timestamp.UnixNano()),
)
// With the key encoded, we'll then encode the event
// into our buffer, then write it out to disk.
err := encodeForwardingEvent(eventBuf, &event)
if err != nil {
return err
}
err = logBucket.Put(timestamp[:], eventBuf.Bytes())
if err != nil {
return err
}
}
return nil
})
}
// ForwardingEventQuery represents a query to the forwarding log payment
// circuit time series database. The query allows a caller to retrieve all
// records for a particular time slice, offset in that time slice, limiting the
// total number of responses returned.
type ForwardingEventQuery struct {
// StartTime is the start time of the time slice.
StartTime time.Time
// EndTime is the end time of the time slice.
EndTime time.Time
// IndexOffset is the offset within the time slice to start at. This
// can be used to start the response at a particular record.
IndexOffset uint32
// NumMaxEvents is the max number of events to return.
NumMaxEvents uint32
}
// ForwardingLogTimeSlice is the response to a forwarding query. It includes
// the original query, the set events that match the query, and an integer
// which represents the offset index of the last item in the set of retuned
// events. This integer allows callers to resume their query using this offset
// in the event that the query's response exceeds the max number of returnable
// events.
type ForwardingLogTimeSlice struct {
ForwardingEventQuery
// ForwardingEvents is the set of events in our time series that answer
// the query embedded above.
ForwardingEvents []ForwardingEvent
// LastIndexOffset is the index of the last element in the set of
// returned ForwardingEvents above. Callers can use this to resume
// their query in the event that the time slice has too many events to
// fit into a single response.
LastIndexOffset uint32
}
// Query allows a caller to query the forwarding event time series for a
// particular time slice. The caller can control the precise time as well as
// the number of events to be returned.
//
// TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
resp := ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
// If the user provided an index offset, then we'll not know how many
// records we need to skip. We'll also keep track of the record offset
// as that's part of the final return value.
recordsToSkip := q.IndexOffset
recordOffset := q.IndexOffset
err := f.db.View(func(tx *bbolt.Tx) error {
// If the bucket wasn't found, then there aren't any events to
// be returned.
logBucket := tx.Bucket(forwardingLogBucket)
if logBucket == nil {
return ErrNoForwardingEvents
}
// We'll be using a cursor to seek into the database, so we'll
// populate byte slices that represent the start of the key
// space we're interested in, and the end.
var startTime, endTime [8]byte
byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano()))
byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano()))
// If we know that a set of log events exists, then we'll begin
// our seek through the log in order to satisfy the query.
// We'll continue until either we reach the end of the range,
// or reach our max number of events.
logCursor := logBucket.Cursor()
timestamp, events := logCursor.Seek(startTime[:])
for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() {
// If our current return payload exceeds the max number
// of events, then we'll exit now.
if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents {
return nil
}
// If we're not yet past the user defined offset, then
// we'll continue to seek forward.
if recordsToSkip > 0 {
recordsToSkip--
continue
}
currentTime := time.Unix(
0, int64(byteOrder.Uint64(timestamp)),
)
// At this point, we've skipped enough records to start
// to collate our query. For each record, we'll
// increment the final record offset so the querier can
// utilize pagination to seek further.
readBuf := bytes.NewReader(events)
for readBuf.Len() != 0 {
var event ForwardingEvent
err := decodeForwardingEvent(readBuf, &event)
if err != nil {
return err
}
event.Timestamp = currentTime
resp.ForwardingEvents = append(resp.ForwardingEvents, event)
recordOffset++
}
}
return nil
})
if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err
}
resp.LastIndexOffset = recordOffset
return resp, nil
}

View file

@ -1,265 +0,0 @@
package migration_01_to_11
import (
"math/rand"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire"
"time"
)
// TestForwardingLogBasicStorageAndQuery tests that we're able to store and
// then query for items that have previously been added to the event log.
func TestForwardingLogBasicStorageAndQuery(t *testing.T) {
t.Parallel()
// First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the
// test.
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
log := ForwardingLog{
db: db,
}
initialTime := time.Unix(1234, 0)
timestamp := time.Unix(1234, 0)
// We'll create 100 random events, which each event being spaced 10
// minutes after the prior event.
numEvents := 100
events := make([]ForwardingEvent, numEvents)
for i := 0; i < numEvents; i++ {
events[i] = ForwardingEvent{
Timestamp: timestamp,
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
}
timestamp = timestamp.Add(time.Minute * 10)
}
// Now that all of our set of events constructed, we'll add them to the
// database in a batch manner.
if err := log.AddForwardingEvents(events); err != nil {
t.Fatalf("unable to add events: %v", err)
}
// With our events added we'll now construct a basic query to retrieve
// all of the events.
eventQuery := ForwardingEventQuery{
StartTime: initialTime,
EndTime: timestamp,
IndexOffset: 0,
NumMaxEvents: 1000,
}
timeSlice, err := log.Query(eventQuery)
if err != nil {
t.Fatalf("unable to query for events: %v", err)
}
// The set of returned events should match identically, as they should
// be returned in sorted order.
if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) {
t.Fatalf("event mismatch: expected %v vs %v",
spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents))
}
// The offset index of the final entry should be numEvents, so the
// number of total events we've written.
if timeSlice.LastIndexOffset != uint32(numEvents) {
t.Fatalf("wrong final offset: expected %v, got %v",
timeSlice.LastIndexOffset, numEvents)
}
}
// TestForwardingLogQueryOptions tests that the query offset works properly. So
// if we add a series of events, then we should be able to seek within the
// timeslice accordingly. This exercises the index offset and num max event
// field in the query, and also the last index offset field int he response.
func TestForwardingLogQueryOptions(t *testing.T) {
t.Parallel()
// First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the
// test.
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
log := ForwardingLog{
db: db,
}
initialTime := time.Unix(1234, 0)
endTime := time.Unix(1234, 0)
// We'll create 20 random events, which each event being spaced 10
// minutes after the prior event.
numEvents := 20
events := make([]ForwardingEvent, numEvents)
for i := 0; i < numEvents; i++ {
events[i] = ForwardingEvent{
Timestamp: endTime,
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
}
endTime = endTime.Add(time.Minute * 10)
}
// Now that all of our set of events constructed, we'll add them to the
// database in a batch manner.
if err := log.AddForwardingEvents(events); err != nil {
t.Fatalf("unable to add events: %v", err)
}
// With all of our events added, we should be able to query for the
// first 10 events using the max event query field.
eventQuery := ForwardingEventQuery{
StartTime: initialTime,
EndTime: endTime,
IndexOffset: 0,
NumMaxEvents: 10,
}
timeSlice, err := log.Query(eventQuery)
if err != nil {
t.Fatalf("unable to query for events: %v", err)
}
// We should get exactly 10 events back.
if len(timeSlice.ForwardingEvents) != 10 {
t.Fatalf("wrong number of events: expected %v, got %v", 10,
len(timeSlice.ForwardingEvents))
}
// The set of events returned should be the first 10 events that we
// added.
if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) {
t.Fatalf("wrong response: expected %v, got %v",
spew.Sdump(events[:10]),
spew.Sdump(timeSlice.ForwardingEvents))
}
// The final offset should be the exact number of events returned.
if timeSlice.LastIndexOffset != 10 {
t.Fatalf("wrong index offset: expected %v, got %v", 10,
timeSlice.LastIndexOffset)
}
// If we use the final offset to query again, then we should get 10
// more events, that are the last 10 events we wrote.
eventQuery.IndexOffset = 10
timeSlice, err = log.Query(eventQuery)
if err != nil {
t.Fatalf("unable to query for events: %v", err)
}
// We should get exactly 10 events back once again.
if len(timeSlice.ForwardingEvents) != 10 {
t.Fatalf("wrong number of events: expected %v, got %v", 10,
len(timeSlice.ForwardingEvents))
}
// The events that we got back should be the last 10 events that we
// wrote out.
if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) {
t.Fatalf("wrong response: expected %v, got %v",
spew.Sdump(events[10:]),
spew.Sdump(timeSlice.ForwardingEvents))
}
// Finally, the last index offset should be 20, or the number of
// records we've written out.
if timeSlice.LastIndexOffset != 20 {
t.Fatalf("wrong index offset: expected %v, got %v", 20,
timeSlice.LastIndexOffset)
}
}
// TestForwardingLogQueryLimit tests that we're able to properly limit the
// number of events that are returned as part of a query.
func TestForwardingLogQueryLimit(t *testing.T) {
t.Parallel()
// First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the
// test.
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
log := ForwardingLog{
db: db,
}
initialTime := time.Unix(1234, 0)
endTime := time.Unix(1234, 0)
// We'll create 200 random events, which each event being spaced 10
// minutes after the prior event.
numEvents := 200
events := make([]ForwardingEvent, numEvents)
for i := 0; i < numEvents; i++ {
events[i] = ForwardingEvent{
Timestamp: endTime,
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
}
endTime = endTime.Add(time.Minute * 10)
}
// Now that all of our set of events constructed, we'll add them to the
// database in a batch manner.
if err := log.AddForwardingEvents(events); err != nil {
t.Fatalf("unable to add events: %v", err)
}
// Once the events have been written out, we'll issue a query over the
// entire range, but restrict the number of events to the first 100.
eventQuery := ForwardingEventQuery{
StartTime: initialTime,
EndTime: endTime,
IndexOffset: 0,
NumMaxEvents: 100,
}
timeSlice, err := log.Query(eventQuery)
if err != nil {
t.Fatalf("unable to query for events: %v", err)
}
// We should get exactly 100 events back.
if len(timeSlice.ForwardingEvents) != 100 {
t.Fatalf("wrong number of events: expected %v, got %v", 10,
len(timeSlice.ForwardingEvents))
}
// The set of events returned should be the first 100 events that we
// added.
if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) {
t.Fatalf("wrong response: expected %v, got %v",
spew.Sdump(events[:100]),
spew.Sdump(timeSlice.ForwardingEvents))
}
// The final offset should be the exact number of events returned.
if timeSlice.LastIndexOffset != 100 {
t.Fatalf("wrong index offset: expected %v, got %v", 100,
timeSlice.LastIndexOffset)
}
}

View file

@ -1,928 +0,0 @@
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lnwire"
)
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
// package has potentially been mangled.
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
// FwdState is an enum used to describe the lifecycle of a FwdPkg.
type FwdState byte
const (
// FwdStateLockedIn is the starting state for all forwarding packages.
// Packages in this state have not yet committed to the exact set of
// Adds to forward to the switch.
FwdStateLockedIn FwdState = iota
// FwdStateProcessed marks the state in which all Adds have been
// locally processed and the forwarding decision to the switch has been
// persisted.
FwdStateProcessed
// FwdStateCompleted signals that all Adds have been acked, and that all
// settles and fails have been delivered to their sources. Packages in
// this state can be removed permanently.
FwdStateCompleted
)
var (
// fwdPackagesKey is the root-level bucket that all forwarding packages
// are written. This bucket is further subdivided based on the short
// channel ID of each channel.
fwdPackagesKey = []byte("fwd-packages")
// addBucketKey is the bucket to which all Add log updates are written.
addBucketKey = []byte("add-updates")
// failSettleBucketKey is the bucket to which all Settle/Fail log
// updates are written.
failSettleBucketKey = []byte("fail-settle-updates")
// fwdFilterKey is a key used to write the set of Adds that passed
// validation and are to be forwarded to the switch.
// NOTE: The presence of this key within a forwarding package indicates
// that the package has reached FwdStateProcessed.
fwdFilterKey = []byte("fwd-filter-key")
// ackFilterKey is a key used to access the PkgFilter indicating which
// Adds have received a Settle/Fail. This response may come from a
// number of sources, including: exitHop settle/fails, switch failures,
// chain arbiter interjections, as well as settle/fails from the
// next hop in the route.
ackFilterKey = []byte("ack-filter-key")
// settleFailFilterKey is a key used to access the PkgFilter indicating
// which Settles/Fails in have been received and processed by the link
// that originally received the Add.
settleFailFilterKey = []byte("settle-fail-filter-key")
)
// PkgFilter is used to compactly represent a particular subset of the Adds in a
// forwarding package. Each filter is represented as a simple, statically-sized
// bitvector, where the elements are intended to be the indices of the Adds as
// they are written in the FwdPkg.
type PkgFilter struct {
count uint16
filter []byte
}
// NewPkgFilter initializes an empty PkgFilter supporting `count` elements.
func NewPkgFilter(count uint16) *PkgFilter {
// We add 7 to ensure that the integer division yields properly rounded
// values.
filterLen := (count + 7) / 8
return &PkgFilter{
count: count,
filter: make([]byte, filterLen),
}
}
// Count returns the number of elements represented by this PkgFilter.
func (f *PkgFilter) Count() uint16 {
return f.count
}
// Set marks the `i`-th element as included by this filter.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Set(i uint16) {
byt := i / 8
bit := i % 8
// Set the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
f.filter[byt] |= byte(1 << (7 - bit))
}
// Contains queries the filter for membership of index `i`.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Contains(i uint16) bool {
byt := i / 8
bit := i % 8
// Read the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
return f.filter[byt]&(1<<(7-bit)) != 0
}
// Equal checks two PkgFilters for equality.
func (f *PkgFilter) Equal(f2 *PkgFilter) bool {
if f == f2 {
return true
}
if f.count != f2.count {
return false
}
return bytes.Equal(f.filter, f2.filter)
}
// IsFull returns true if every element in the filter has been Set, and false
// otherwise.
func (f *PkgFilter) IsFull() bool {
// Batch validate bytes that are fully used.
for i := uint16(0); i < f.count/8; i++ {
if f.filter[i] != 0xFF {
return false
}
}
// If the count is not a multiple of 8, check that the filter contains
// all remaining bits.
rem := f.count % 8
for idx := f.count - rem; idx < f.count; idx++ {
if !f.Contains(idx) {
return false
}
}
return true
}
// Size returns number of bytes produced when the PkgFilter is serialized.
func (f *PkgFilter) Size() uint16 {
// 2 bytes for uint16 `count`, then round up number of bytes required to
// represent `count` bits.
return 2 + (f.count+7)/8
}
// Encode writes the filter to the provided io.Writer.
func (f *PkgFilter) Encode(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, f.count); err != nil {
return err
}
_, err := w.Write(f.filter)
return err
}
// Decode reads the filter from the provided io.Reader.
func (f *PkgFilter) Decode(r io.Reader) error {
if err := binary.Read(r, binary.BigEndian, &f.count); err != nil {
return err
}
f.filter = make([]byte, f.Size()-2)
_, err := io.ReadFull(r, f.filter)
return err
}
// FwdPkg records all adds, settles, and fails that were locked in as a result
// of the remote peer sending us a revocation. Each package is identified by
// the short chanid and remote commitment height corresponding to the revocation
// that locked in the HTLCs. For everything except a locally initiated payment,
// settles and fails in a forwarding package must have a corresponding Add in
// another package, and can be removed individually once the source link has
// received the fail/settle.
//
// Adds cannot be removed, as we need to present the same batch of Adds to
// properly handle replay protection. Instead, we use a PkgFilter to mark that
// we have finished processing a particular Add. A FwdPkg should only be deleted
// after the AckFilter is full and all settles and fails have been persistently
// removed.
type FwdPkg struct {
// Source identifies the channel that wrote this forwarding package.
Source lnwire.ShortChannelID
// Height is the height of the remote commitment chain that locked in
// this forwarding package.
Height uint64
// State signals the persistent condition of the package and directs how
// to reprocess the package in the event of failures.
State FwdState
// Adds contains all add messages which need to be processed and
// forwarded to the switch. Adds does not change over the life of a
// forwarding package.
Adds []LogUpdate
// FwdFilter is a filter containing the indices of all Adds that were
// forwarded to the switch.
FwdFilter *PkgFilter
// AckFilter is a filter containing the indices of all Adds for which
// the source has received a settle or fail and is reflected in the next
// commitment txn. A package should not be removed until IsFull()
// returns true.
AckFilter *PkgFilter
// SettleFails contains all settle and fail messages that should be
// forwarded to the switch.
SettleFails []LogUpdate
// SettleFailFilter is a filter containing the indices of all Settle or
// Fails originating in this package that have been received and locked
// into the incoming link's commitment state.
SettleFailFilter *PkgFilter
}
// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This
// should be used to create a package at the time we receive a revocation.
func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
addUpdates, settleFailUpdates []LogUpdate) *FwdPkg {
nAddUpdates := uint16(len(addUpdates))
nSettleFailUpdates := uint16(len(settleFailUpdates))
return &FwdPkg{
Source: source,
Height: height,
State: FwdStateLockedIn,
Adds: addUpdates,
FwdFilter: NewPkgFilter(nAddUpdates),
AckFilter: NewPkgFilter(nAddUpdates),
SettleFails: settleFailUpdates,
SettleFailFilter: NewPkgFilter(nSettleFailUpdates),
}
}
// ID returns an unique identifier for this package, used to ensure that sphinx
// replay processing of this batch is idempotent.
func (f *FwdPkg) ID() []byte {
var id = make([]byte, 16)
byteOrder.PutUint64(id[:8], f.Source.ToUint64())
byteOrder.PutUint64(id[8:], f.Height)
return id
}
// String returns a human-readable description of the forwarding package.
func (f *FwdPkg) String() string {
return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)",
f, f.Source, f.Height, len(f.Adds), len(f.SettleFails))
}
// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID
// is assumed to be that of the packager.
type AddRef struct {
// Height is the remote commitment height that locked in the Add.
Height uint64
// Index is the index of the Add within the fwd pkg's Adds.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// Encode serializes the AddRef to the given io.Writer.
func (a *AddRef) Encode(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, a.Height); err != nil {
return err
}
return binary.Write(w, binary.BigEndian, a.Index)
}
// Decode deserializes the AddRef from the given io.Reader.
func (a *AddRef) Decode(r io.Reader) error {
if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil {
return err
}
return binary.Read(r, binary.BigEndian, &a.Index)
}
// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A
// channel does not remove its own Settle/Fail htlcs, so the source is provided
// to locate a db bucket belonging to another channel.
type SettleFailRef struct {
// Source identifies the outgoing link that locked in the settle or
// fail. This is then used by the *incoming* link to find the settle
// fail in another link's forwarding packages.
Source lnwire.ShortChannelID
// Height is the remote commitment height that locked in this
// Settle/Fail.
Height uint64
// Index is the index of the Add with the fwd pkg's SettleFails.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// SettleFailAcker is a generic interface providing the ability to acknowledge
// settle/fail HTLCs stored in forwarding packages.
type SettleFailAcker interface {
// AckSettleFails atomically updates the settle-fail filters in *other*
// channels' forwarding packages.
AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error
}
// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages
// of any active channel.
type GlobalFwdPkgReader interface {
// LoadChannelFwdPkgs loads all known forwarding packages for the given
// channel.
LoadChannelFwdPkgs(tx *bbolt.Tx,
source lnwire.ShortChannelID) ([]*FwdPkg, error)
}
// FwdOperator defines the interfaces for managing forwarding packages that are
// external to a particular channel. This interface is used by the switch to
// read forwarding packages from arbitrary channels, and acknowledge settles and
// fails for locally-sourced payments.
type FwdOperator interface {
// GlobalFwdPkgReader provides read access to all known forwarding
// packages
GlobalFwdPkgReader
// SettleFailAcker grants the ability to acknowledge settles or fails
// residing in arbitrary forwarding packages.
SettleFailAcker
}
// SwitchPackager is a concrete implementation of the FwdOperator interface.
// A SwitchPackager offers the ability to read any forwarding package, and ack
// arbitrary settle and fail HTLCs.
type SwitchPackager struct{}
// NewSwitchPackager instantiates a new SwitchPackager.
func NewSwitchPackager() *SwitchPackager {
return &SwitchPackager{}
}
// AckSettleFails atomically updates the settle-fail filters in *other*
// channels' forwarding packages, to mark that the switch has received a settle
// or fail residing in the forwarding package of a link.
func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx,
settleFailRefs ...SettleFailRef) error {
return ackSettleFails(tx, settleFailRefs)
}
// LoadChannelFwdPkgs loads all forwarding packages for a particular channel.
func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx,
source lnwire.ShortChannelID) ([]*FwdPkg, error) {
return loadChannelFwdPkgs(tx, source)
}
// FwdPackager supports all operations required to modify fwd packages, such as
// creation, updates, reading, and removal. The interfaces are broken down in
// this way to support future delegation of the subinterfaces.
type FwdPackager interface {
// AddFwdPkg serializes and writes a FwdPkg for this channel at the
// remote commitment height included in the forwarding package.
AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error
// SetFwdFilter looks up the forwarding package at the remote `height`
// and sets the `fwdFilter`, marking the Adds for which:
// 1) We are not the exit node
// 2) Passed all validation
// 3) Should be forwarded to the switch immediately after a failure
SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error
// AckAddHtlcs atomically updates the add filters in this channel's
// forwarding packages to mark the resolution of an Add that was
// received from the remote party.
AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error
// SettleFailAcker allows a link to acknowledge settle/fail HTLCs
// belonging to other channels.
SettleFailAcker
// LoadFwdPkgs loads all known forwarding packages owned by this
// channel.
LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error)
// RemovePkg deletes a forwarding package owned by this channel at
// the provided remote `height`.
RemovePkg(tx *bbolt.Tx, height uint64) error
}
// ChannelPackager is used by a channel to manage the lifecycle of its forwarding
// packages. The packager is tied to a particular source channel ID, allowing it
// to create and edit its own packages. Each packager also has the ability to
// remove fail/settle htlcs that correspond to an add contained in one of
// source's packages.
type ChannelPackager struct {
source lnwire.ShortChannelID
}
// NewChannelPackager creates a new packager for a single channel.
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
return &ChannelPackager{
source: source,
}
}
// AddFwdPkg writes a newly locked in forwarding package to disk.
func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error {
fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey)
if err != nil {
return err
}
source := makeLogKey(fwdPkg.Source.ToUint64())
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
if err != nil {
return err
}
heightKey := makeLogKey(fwdPkg.Height)
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
if err != nil {
return err
}
// Write ADD updates we received at this commit height.
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
if err != nil {
return err
}
// Write SETTLE/FAIL updates we received at this commit height.
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
if err != nil {
return err
}
for i := range fwdPkg.Adds {
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
if err != nil {
return err
}
}
// Persist the initialized pkg filter, which will be used to determine
// when we can remove this forwarding package from disk.
var ackFilterBuf bytes.Buffer
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
return err
}
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
return err
}
for i := range fwdPkg.SettleFails {
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
if err != nil {
return err
}
}
var settleFailFilterBuf bytes.Buffer
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
if err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key.
func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error {
var b bytes.Buffer
if err := htlc.Encode(&b); err != nil {
return err
}
return bkt.Put(uint16Key(idx), b.Bytes())
}
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
// processed, and returns their deserialized log updates in a map indexed by the
// remote commitment height at which the updates were locked in.
func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) {
return loadChannelFwdPkgs(tx, p.source)
}
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) {
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil, nil
}
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
if sourceBkt == nil {
return nil, nil
}
var heights []uint64
if err := sourceBkt.ForEach(func(k, _ []byte) error {
if len(k) != 8 {
return ErrCorruptedFwdPkg
}
heights = append(heights, byteOrder.Uint64(k))
return nil
}); err != nil {
return nil, err
}
// Load the forwarding package for each retrieved height.
fwdPkgs := make([]*FwdPkg, 0, len(heights))
for _, height := range heights {
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
if err != nil {
return nil, err
}
fwdPkgs = append(fwdPkgs, fwdPkg)
}
return fwdPkgs, nil
}
// loadFwPkg reads the packager's fwd pkg at a given height, and determines the
// appropriate FwdState.
func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID,
height uint64) (*FwdPkg, error) {
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
if sourceBkt == nil {
return nil, ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.Bucket(heightKey[:])
if heightBkt == nil {
return nil, ErrCorruptedFwdPkg
}
// Load ADDs from disk.
addBkt := heightBkt.Bucket(addBucketKey)
if addBkt == nil {
return nil, ErrCorruptedFwdPkg
}
adds, err := loadHtlcs(addBkt)
if err != nil {
return nil, err
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
ackFilterReader := bytes.NewReader(ackFilterBytes)
ackFilter := &PkgFilter{}
if err := ackFilter.Decode(ackFilterReader); err != nil {
return nil, err
}
// Load SETTLE/FAILs from disk.
failSettleBkt := heightBkt.Bucket(failSettleBucketKey)
if failSettleBkt == nil {
return nil, ErrCorruptedFwdPkg
}
failSettles, err := loadHtlcs(failSettleBkt)
if err != nil {
return nil, err
}
// Load settle fail filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
settleFailFilter := &PkgFilter{}
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return nil, err
}
// Initialize the fwding package, which always starts in the
// FwdStateLockedIn. We can determine what state the package was left in
// by examining constraints on the information loaded from disk.
fwdPkg := &FwdPkg{
Source: source,
State: FwdStateLockedIn,
Height: height,
Adds: adds,
AckFilter: ackFilter,
SettleFails: failSettles,
SettleFailFilter: settleFailFilter,
}
// Check to see if we have written the set exported filter adds to
// disk. If we haven't, processing of this package was never started, or
// failed during the last attempt.
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
if fwdFilterBytes == nil {
nAdds := uint16(len(adds))
fwdPkg.FwdFilter = NewPkgFilter(nAdds)
return fwdPkg, nil
}
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
fwdPkg.FwdFilter = &PkgFilter{}
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
return nil, err
}
// Otherwise, a complete round of processing was completed, and we
// advance the package to FwdStateProcessed.
fwdPkg.State = FwdStateProcessed
// If every add, settle, and fail has been fully acknowledged, we can
// safely set the package's state to FwdStateCompleted, signalling that
// it can be garbage collected.
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
fwdPkg.State = FwdStateCompleted
}
return fwdPkg, nil
}
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
// them in order of the indexes they were written under.
func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) {
var htlcs []LogUpdate
if err := bkt.ForEach(func(_, v []byte) error {
var htlc LogUpdate
if err := htlc.Decode(bytes.NewReader(v)); err != nil {
return err
}
htlcs = append(htlcs, htlc)
return nil
}); err != nil {
return nil, err
}
return htlcs, nil
}
// SetFwdFilter writes the set of indexes corresponding to Adds at the
// `height` that are to be forwarded to the switch. Calling this method causes
// the forwarding package at `height` to be in FwdStateProcessed. We write this
// forwarding decision so that we always arrive at the same behavior for HTLCs
// leaving this channel. After a restart, we skip validation of these Adds,
// since they are assumed to have already been validated, and make the switch or
// outgoing link responsible for handling replays.
func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64,
fwdFilter *PkgFilter) error {
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
source := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.Bucket(source[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.Bucket(heightKey[:])
if heightBkt == nil {
return ErrCorruptedFwdPkg
}
// If the fwd filter has already been written, we return early to avoid
// modifying the persistent state.
forwardedAddsBytes := heightBkt.Get(fwdFilterKey)
if forwardedAddsBytes != nil {
return nil
}
// Otherwise we serialize and write the provided fwd filter.
var b bytes.Buffer
if err := fwdFilter.Encode(&b); err != nil {
return err
}
return heightBkt.Put(fwdFilterKey, b.Bytes())
}
// AckAddHtlcs accepts a list of references to add htlcs, and updates the
// AckAddFilter of those forwarding packages to indicate that a settle or fail
// has been received in response to the add.
func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error {
if len(addRefs) == 0 {
return nil
}
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
sourceKey := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
// Organize the forward references such that we just get a single slice
// of indexes for each unique height.
heightDiffs := make(map[uint64][]uint16)
for _, addRef := range addRefs {
heightDiffs[addRef.Height] = append(
heightDiffs[addRef.Height],
addRef.Index,
)
}
// Load each height bucket once and remove all acked htlcs at that
// height.
for height, indexes := range heightDiffs {
err := ackAddHtlcsAtHeight(sourceBkt, height, indexes)
if err != nil {
return err
}
}
return nil
}
// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package
// with a list of indexes, writing the resulting filter back in its place.
func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64,
indexes []uint16) error {
heightKey := makeLogKey(height)
heightBkt := sourceBkt.Bucket(heightKey[:])
if heightBkt == nil {
// If the height bucket isn't found, this could be because the
// forwarding package was already removed. We'll return nil to
// signal that the operation is successful, as there is nothing
// to ack.
return nil
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return ErrCorruptedFwdPkg
}
ackFilter := &PkgFilter{}
ackFilterReader := bytes.NewReader(ackFilterBytes)
if err := ackFilter.Decode(ackFilterReader); err != nil {
return err
}
// Update the ack filter for this height.
for _, index := range indexes {
ackFilter.Set(index)
}
// Write the resulting filter to disk.
var ackFilterBuf bytes.Buffer
if err := ackFilter.Encode(&ackFilterBuf); err != nil {
return err
}
return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes())
}
// AckSettleFails persistently acknowledges settles or fails from a remote forwarding
// package. This should only be called after the source of the Add has locked in
// the settle/fail, or it becomes otherwise safe to forgo retransmitting the
// settle/fail after a restart.
func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error {
return ackSettleFails(tx, settleFailRefs)
}
// ackSettleFails persistently acknowledges a batch of settle fail references.
func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error {
if len(settleFailRefs) == 0 {
return nil
}
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
// Organize the forward references such that we just get a single slice
// of indexes for each unique destination-height pair.
destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16)
for _, settleFailRef := range settleFailRefs {
destHeights, ok := destHeightDiffs[settleFailRef.Source]
if !ok {
destHeights = make(map[uint64][]uint16)
destHeightDiffs[settleFailRef.Source] = destHeights
}
destHeights[settleFailRef.Height] = append(
destHeights[settleFailRef.Height],
settleFailRef.Index,
)
}
// With the references organized by destination and height, we now load
// each remote bucket, and update the settle fail filter for any
// settle/fail htlcs.
for dest, destHeights := range destHeightDiffs {
destKey := makeLogKey(dest.ToUint64())
destBkt := fwdPkgBkt.Bucket(destKey[:])
if destBkt == nil {
// If the destination bucket is not found, this is
// likely the result of the destination channel being
// closed and having it's forwarding packages wiped. We
// won't treat this as an error, because the response
// will no longer be retransmitted internally.
continue
}
for height, indexes := range destHeights {
err := ackSettleFailsAtHeight(destBkt, height, indexes)
if err != nil {
return err
}
}
}
return nil
}
// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes
// at particular a height by updating the settle fail filter.
func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64,
indexes []uint16) error {
heightKey := makeLogKey(height)
heightBkt := destBkt.Bucket(heightKey[:])
if heightBkt == nil {
// If the height bucket isn't found, this could be because the
// forwarding package was already removed. We'll return nil to
// signal that the operation is as there is nothing to ack.
return nil
}
// Load ack filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return ErrCorruptedFwdPkg
}
settleFailFilter := &PkgFilter{}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return err
}
// Update the ack filter for this height.
for _, index := range indexes {
settleFailFilter.Set(index)
}
// Write the resulting filter to disk.
var settleFailFilterBuf bytes.Buffer
if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
// RemovePkg deletes the forwarding package at the given height from the
// packager's source bucket.
func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error {
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil
}
sourceBytes := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
return sourceBkt.DeleteBucket(heightKey[:])
}
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
func uint16Key(i uint16) []byte {
key := make([]byte, 2)
byteOrder.PutUint16(key, i)
return key
}
// Compile-time constraint to ensure that ChannelPackager implements the public
// FwdPackager interface.
var _ FwdPackager = (*ChannelPackager)(nil)
// Compile-time constraint to ensure that SwitchPackager implements the public
// FwdOperator interface.
var _ FwdOperator = (*SwitchPackager)(nil)

View file

@ -1,815 +0,0 @@
package migration_01_to_11_test
import (
"bytes"
"io/ioutil"
"path/filepath"
"runtime"
"testing"
"github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000,
// which is greater than the number of HTLCs we permit on a commitment txn.
// This should encapsulate every potential filter used in practice.
func TestPkgFilterBruteForce(t *testing.T) {
t.Parallel()
checkPkgFilterRange(t, 1000)
}
// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear
// insertion of `high` elements. This is primarily to test that IsFull functions
// properly for all relevant sizes of `high`.
func checkPkgFilterRange(t *testing.T, high int) {
for i := uint16(0); i < uint16(high); i++ {
f := channeldb.NewPkgFilter(i)
if f.Count() != i {
t.Fatalf("pkg filter count=%d is actually %d",
i, f.Count())
}
checkPkgFilterEncodeDecode(t, i, f)
for j := uint16(0); j < i; j++ {
if f.Contains(j) {
t.Fatalf("pkg filter count=%d contains %d "+
"before being added", i, j)
}
f.Set(j)
checkPkgFilterEncodeDecode(t, i, f)
if !f.Contains(j) {
t.Fatalf("pkg filter count=%d missing %d "+
"after being added", i, j)
}
if j < i-1 && f.IsFull() {
t.Fatalf("pkg filter count=%d already full", i)
}
}
if !f.IsFull() {
t.Fatalf("pkg filter count=%d not full", i)
}
checkPkgFilterEncodeDecode(t, i, f)
}
}
// TestPkgFilterRand uses a random permutation to verify the proper behavior of
// the pkg filter if the entries are not inserted in-order.
func TestPkgFilterRand(t *testing.T) {
t.Parallel()
checkPkgFilterRand(t, 3, 17)
}
// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting
// indices and asserting the invariants. The order in which indices are inserted
// is parameterized by a base `b` coprime to `p`, and using modular
// exponentiation to generate all elements in [1,p).
func checkPkgFilterRand(t *testing.T, b, p uint16) {
f := channeldb.NewPkgFilter(p)
var j = b
for i := uint16(1); i < p; i++ {
if f.Contains(j) {
t.Fatalf("pkg filter contains %d-%d "+
"before being added", i, j)
}
f.Set(j)
checkPkgFilterEncodeDecode(t, i, f)
if !f.Contains(j) {
t.Fatalf("pkg filter missing %d-%d "+
"after being added", i, j)
}
if i < p-1 && f.IsFull() {
t.Fatalf("pkg filter %d already full", i)
}
checkPkgFilterEncodeDecode(t, i, f)
j = (b * j) % p
}
// Set 0 independently, since it will never be emitted by the generator.
f.Set(0)
checkPkgFilterEncodeDecode(t, p, f)
if !f.IsFull() {
t.Fatalf("pkg filter count=%d not full", p)
}
checkPkgFilterEncodeDecode(t, p, f)
}
// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by:
// 1) writing it to a buffer
// 2) verifying the number of bytes written matches the filter's Size()
// 3) reconstructing the filter decoding the bytes
// 4) checking that the two filters are the same according to Equal
func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) {
var b bytes.Buffer
if err := f.Encode(&b); err != nil {
t.Fatalf("unable to serialize pkg filter: %v", err)
}
// +2 for uint16 length
size := uint16(len(b.Bytes()))
if size != f.Size() {
t.Fatalf("pkg filter count=%d serialized size differs, "+
"Size(): %d, len(bytes): %v", i, f.Size(), size)
}
reader := bytes.NewReader(b.Bytes())
f2 := &channeldb.PkgFilter{}
if err := f2.Decode(reader); err != nil {
t.Fatalf("unable to deserialize pkg filter: %v", err)
}
if !f.Equal(f2) {
t.Fatalf("pkg filter count=%v does is not equal "+
"after deserialization, want: %v, got %v",
i, f, f2)
}
}
var (
chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{})
adds = []channeldb.LogUpdate{
{
LogIndex: 0,
UpdateMsg: &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: 1,
Amount: 100,
Expiry: 1000,
PaymentHash: [32]byte{0},
},
},
{
LogIndex: 1,
UpdateMsg: &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: 1,
Amount: 101,
Expiry: 1001,
PaymentHash: [32]byte{1},
},
},
}
settleFails = []channeldb.LogUpdate{
{
LogIndex: 2,
UpdateMsg: &lnwire.UpdateFulfillHTLC{
ChanID: chanID,
ID: 0,
PaymentPreimage: [32]byte{0},
},
},
{
LogIndex: 3,
UpdateMsg: &lnwire.UpdateFailHTLC{
ChanID: chanID,
ID: 1,
Reason: []byte{},
},
},
}
)
// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a
// forwarding package that contains no adds, fails or settles. We expect that
// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding
// decision via SetFwdFilter.
func TestPackagerEmptyFwdPkg(t *testing.T) {
t.Parallel()
db := makeFwdPkgDB(t, "")
shortChanID := lnwire.NewShortChanIDFromInt(1)
packager := channeldb.NewChannelPackager(shortChanID)
// To begin, there should be no forwarding packages on disk.
fwdPkgs := loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
// Next, create and write a new forwarding package with no htlcs.
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil)
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
// There should now be one fwdpkg on disk. Since no forwarding decision
// has been written, we expect it to be FwdStateLockedIn. With no HTLCs,
// the ack filter will have no elements, and should always return true.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Now, write the forwarding decision. In this case, its just an empty
// fwd filter.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
// We should still have one package on disk. Since the forwarding
// decision has been written, it will minimally be in FwdStateProcessed.
// However with no htlcs, it should leap frog to FwdStateCompleted.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Lastly, remove the completed forwarding package from disk.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
// Check that the fwd package was actually removed.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
}
// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted
// as soon as all the adds in the package have been acked using AckAddHtlcs.
func TestPackagerOnlyAdds(t *testing.T) {
t.Parallel()
db := makeFwdPkgDB(t, "")
shortChanID := lnwire.NewShortChanIDFromInt(1)
packager := channeldb.NewChannelPackager(shortChanID)
// To begin, there should be no forwarding packages on disk.
fwdPkgs := loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
// Next, create and write a new forwarding package that only has add
// htlcs.
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil)
nAdds := len(adds)
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
// There should now be one fwdpkg on disk. Since no forwarding decision
// has been written, we expect it to be FwdStateLockedIn. The package
// has unacked add HTLCs, so the ack filter should not be full.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
assertAckFilterIsFull(t, fwdPkgs[0], false)
// Now, write the forwarding decision. Since we have not explicitly
// added any adds to the fwdfilter, this would indicate that all of the
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
// was failed locally.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
for i := range adds {
// We should still have one package on disk. Since the forwarding
// decision has been written, it will minimally be in FwdStateProcessed.
// However not allf of the HTLCs have been acked, so should not
// have advanced further.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
assertAckFilterIsFull(t, fwdPkgs[0], false)
addRef := channeldb.AddRef{
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
// We should still have one package on disk. Now that all adds have been
// acked, the ack filter should return true and the package should be
// FwdStateCompleted since there are no other settle/fail packets.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Lastly, remove the completed forwarding package from disk.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
// Check that the fwd package was actually removed.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
}
// TestPackagerOnlySettleFails asserts that the fwdpkg remains in
// FwdStateProcessed after writing the forwarding decision when there are no
// adds in the fwdpkg. We expect this because an empty FwdFilter will always
// return true, but we are still waiting for the remaining fails and settles to
// be deleted.
func TestPackagerOnlySettleFails(t *testing.T) {
t.Parallel()
db := makeFwdPkgDB(t, "")
shortChanID := lnwire.NewShortChanIDFromInt(1)
packager := channeldb.NewChannelPackager(shortChanID)
// To begin, there should be no forwarding packages on disk.
fwdPkgs := loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
// Next, create and write a new forwarding package that only has add
// htlcs.
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails)
nSettleFails := len(settleFails)
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
// There should now be one fwdpkg on disk. Since no forwarding decision
// has been written, we expect it to be FwdStateLockedIn. The package
// has unacked add HTLCs, so the ack filter should not be full.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Now, write the forwarding decision. Since we have not explicitly
// added any adds to the fwdfilter, this would indicate that all of the
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
// was failed locally.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
for i := range settleFails {
// We should still have one package on disk. Since the
// forwarding decision has been written, it will minimally be in
// FwdStateProcessed. However, not all of the HTLCs have been
// acked, so should not have advanced further.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
assertAckFilterIsFull(t, fwdPkgs[0], true)
failSettleRef := channeldb.SettleFailRef{
Source: shortChanID,
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
// We should still have one package on disk. Now that all settles and
// fails have been removed, package should be FwdStateCompleted since
// there are no other add packets.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Lastly, remove the completed forwarding package from disk.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
// Check that the fwd package was actually removed.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
}
// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and
// settle/fails, then checks the behavior when the adds are acked before any of
// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the
// remainder of the fail/settles are being deleted.
func TestPackagerAddsThenSettleFails(t *testing.T) {
t.Parallel()
db := makeFwdPkgDB(t, "")
shortChanID := lnwire.NewShortChanIDFromInt(1)
packager := channeldb.NewChannelPackager(shortChanID)
// To begin, there should be no forwarding packages on disk.
fwdPkgs := loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
// Next, create and write a new forwarding package that only has add
// htlcs.
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
nAdds := len(adds)
nSettleFails := len(settleFails)
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
// There should now be one fwdpkg on disk. Since no forwarding decision
// has been written, we expect it to be FwdStateLockedIn. The package
// has unacked add HTLCs, so the ack filter should not be full.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertAckFilterIsFull(t, fwdPkgs[0], false)
// Now, write the forwarding decision. Since we have not explicitly
// added any adds to the fwdfilter, this would indicate that all of the
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
// was failed locally.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
for i := range adds {
// We should still have one package on disk. Since the forwarding
// decision has been written, it will minimally be in FwdStateProcessed.
// However not allf of the HTLCs have been acked, so should not
// have advanced further.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
assertAckFilterIsFull(t, fwdPkgs[0], false)
addRef := channeldb.AddRef{
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
for i := range settleFails {
// We should still have one package on disk. Since the
// forwarding decision has been written, it will minimally be in
// FwdStateProcessed. However not allf of the HTLCs have been
// acked, so should not have advanced further.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
assertAckFilterIsFull(t, fwdPkgs[0], true)
failSettleRef := channeldb.SettleFailRef{
Source: shortChanID,
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err)
}
}
// We should still have one package on disk. Now that all settles and
// fails have been removed, package should be FwdStateCompleted since
// there are no other add packets.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Lastly, remove the completed forwarding package from disk.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
// Check that the fwd package was actually removed.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
}
// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and
// settle/fails, then checks the behavior when the settle/fails are removed
// before any of the adds have been acked. This should cause the fwdpkg to
// remain in FwdStateProcessed until the final ack is recorded, at which point
// it should be promoted directly to FwdStateCompleted.since all adds have been
// removed.
func TestPackagerSettleFailsThenAdds(t *testing.T) {
t.Parallel()
db := makeFwdPkgDB(t, "")
shortChanID := lnwire.NewShortChanIDFromInt(1)
packager := channeldb.NewChannelPackager(shortChanID)
// To begin, there should be no forwarding packages on disk.
fwdPkgs := loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
// Next, create and write a new forwarding package that has both add
// and settle/fail htlcs.
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
nAdds := len(adds)
nSettleFails := len(settleFails)
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
// There should now be one fwdpkg on disk. Since no forwarding decision
// has been written, we expect it to be FwdStateLockedIn. The package
// has unacked add HTLCs, so the ack filter should not be full.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertAckFilterIsFull(t, fwdPkgs[0], false)
// Now, write the forwarding decision. Since we have not explicitly
// added any adds to the fwdfilter, this would indicate that all of the
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
// was failed locally.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
// Simulate another channel deleting the settle/fails it received from
// the original fwd pkg.
// TODO(conner): use different packager/s?
for i := range settleFails {
// We should still have one package on disk. Since the
// forwarding decision has been written, it will minimally be in
// FwdStateProcessed. However none all of the add HTLCs have
// been acked, so should not have advanced further.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
assertAckFilterIsFull(t, fwdPkgs[0], false)
failSettleRef := channeldb.SettleFailRef{
Source: shortChanID,
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err)
}
}
// Now simulate this channel receiving a fail/settle for the adds in the
// fwdpkg.
for i := range adds {
// Again, we should still have one package on disk and be in
// FwdStateProcessed. This should not change until all of the
// add htlcs have been acked.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
assertAckFilterIsFull(t, fwdPkgs[0], false)
addRef := channeldb.AddRef{
Height: fwdPkg.Height,
Index: uint16(i),
}
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
// We should still have one package on disk. Now that all settles and
// fails have been removed, package should be FwdStateCompleted since
// there are no other add packets.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 1 {
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
}
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
assertAckFilterIsFull(t, fwdPkgs[0], true)
// Lastly, remove the completed forwarding package from disk.
if err := db.Update(func(tx *bbolt.Tx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
// Check that the fwd package was actually removed.
fwdPkgs = loadFwdPkgs(t, db, packager)
if len(fwdPkgs) != 0 {
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
}
}
// assertFwdPkgState checks the current state of a fwdpkg meets our
// expectations.
func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg,
state channeldb.FwdState) {
_, _, line, _ := runtime.Caller(1)
if fwdPkg.State != state {
t.Fatalf("line %d: expected fwdpkg in state %v, found %v",
line, state, fwdPkg.State)
}
}
// assertFwdPkgNumAddsSettleFails checks that the number of adds and
// settle/fail log updates are correct.
func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg,
expectedNumAdds, expectedNumSettleFails int) {
_, _, line, _ := runtime.Caller(1)
if len(fwdPkg.Adds) != expectedNumAdds {
t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d",
line, expectedNumAdds, len(fwdPkg.Adds))
}
if len(fwdPkg.SettleFails) != expectedNumSettleFails {
t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d",
line, expectedNumSettleFails, len(fwdPkg.SettleFails))
}
}
// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our
// expected full-ness.
func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
_, _, line, _ := runtime.Caller(1)
if fwdPkg.AckFilter.IsFull() != expected {
t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+
"found %v", line, expected, fwdPkg.AckFilter.IsFull())
}
}
// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail
// filter matches our expected full-ness.
func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
_, _, line, _ := runtime.Caller(1)
if fwdPkg.SettleFailFilter.IsFull() != expected {
t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+
"found %v", line, expected, fwdPkg.SettleFailFilter.IsFull())
}
}
// loadFwdPkgs is a helper method that reads all forwarding packages for a
// particular packager.
func loadFwdPkgs(t *testing.T, db *bbolt.DB,
packager channeldb.FwdPackager) []*channeldb.FwdPkg {
var fwdPkgs []*channeldb.FwdPkg
if err := db.View(func(tx *bbolt.Tx) error {
var err error
fwdPkgs, err = packager.LoadFwdPkgs(tx)
return err
}); err != nil {
t.Fatalf("unable to load fwd pkgs: %v", err)
}
return fwdPkgs
}
// makeFwdPkgDB initializes a test database for forwarding packages. If the
// provided path is an empty, it will create a temp dir/file to use.
func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB {
if path == "" {
var err error
path, err = ioutil.TempDir("", "fwdpkgdb")
if err != nil {
t.Fatalf("unable to create temp path: %v", err)
}
path = filepath.Join(path, "fwdpkg.db")
}
db, err := bbolt.Open(path, 0600, nil)
if err != nil {
t.Fatalf("unable to open boltdb: %v", err)
}
return db
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,694 +0,0 @@
package migration_01_to_11
import (
"crypto/rand"
"reflect"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire"
)
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
var pre [32]byte
if _, err := rand.Read(pre[:]); err != nil {
return nil, err
}
i := &Invoice{
// Use single second precision to avoid false positive test
// failures due to the monotonic time component.
CreationDate: time.Unix(time.Now().Unix(), 0),
Terms: ContractTerm{
PaymentPreimage: pre,
Value: value,
},
Htlcs: map[CircuitKey]*InvoiceHTLC{},
Expiry: 4000,
}
i.Memo = []byte("memo")
i.Receipt = []byte("receipt")
// Create a random byte slice of MaxPaymentRequestSize bytes to be used
// as a dummy paymentrequest, and determine if it should be set based
// on one of the random bytes.
var r [MaxPaymentRequestSize]byte
if _, err := rand.Read(r[:]); err != nil {
return nil, err
}
if r[0]&1 == 0 {
i.PaymentRequest = r[:]
} else {
i.PaymentRequest = []byte("")
}
return i, nil
}
func TestInvoiceWorkflow(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
// Create a fake invoice which we'll use several times in the tests
// below.
fakeInvoice := &Invoice{
// Use single second precision to avoid false positive test
// failures due to the monotonic time component.
CreationDate: time.Unix(time.Now().Unix(), 0),
Htlcs: map[CircuitKey]*InvoiceHTLC{},
}
fakeInvoice.Memo = []byte("memo")
fakeInvoice.Receipt = []byte("receipt")
fakeInvoice.PaymentRequest = []byte("")
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash()
// Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment
// hash.
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil {
t.Fatalf("unable to find invoice: %v", err)
}
// Attempt to retrieve the invoice which was just added to the
// database. It should be found, and the invoice returned should be
// identical to the one created above.
dbInvoice, err := db.LookupInvoice(paymentHash)
if err != nil {
t.Fatalf("unable to find invoice: %v", err)
}
if !reflect.DeepEqual(*fakeInvoice, dbInvoice) {
t.Fatalf("invoice fetched from db doesn't match original %v vs %v",
spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice))
}
// The add index of the invoice retrieved from the database should now
// be fully populated. As this is the first index written to the DB,
// the addIndex should be 1.
if dbInvoice.AddIndex != 1 {
t.Fatalf("wrong add index: expected %v, got %v", 1,
dbInvoice.AddIndex)
}
// Settle the invoice, the version retrieved from the database should
// now have the settled bit toggle to true and a non-default
// SettledDate
payAmt := fakeInvoice.Terms.Value * 2
_, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt))
if err != nil {
t.Fatalf("unable to settle invoice: %v", err)
}
dbInvoice2, err := db.LookupInvoice(paymentHash)
if err != nil {
t.Fatalf("unable to fetch invoice: %v", err)
}
if dbInvoice2.Terms.State != ContractSettled {
t.Fatalf("invoice should now be settled but isn't")
}
if dbInvoice2.SettleDate.IsZero() {
t.Fatalf("invoice should have non-zero SettledDate but isn't")
}
// Our 2x payment should be reflected, and also the settle index of 1
// should also have been committed for this index.
if dbInvoice2.AmtPaid != payAmt {
t.Fatalf("wrong amt paid: expected %v, got %v", payAmt,
dbInvoice2.AmtPaid)
}
if dbInvoice2.SettleIndex != 1 {
t.Fatalf("wrong settle index: expected %v, got %v", 1,
dbInvoice2.SettleIndex)
}
// Attempt to insert generated above again, this should fail as
// duplicates are rejected by the processing logic.
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice {
t.Fatalf("invoice insertion should fail due to duplication, "+
"instead %v", err)
}
// Attempt to look up a non-existent invoice, this should also fail but
// with a "not found" error.
var fakeHash [32]byte
if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound {
t.Fatalf("lookup should have failed, instead %v", err)
}
// Add 10 random invoices.
const numInvoices = 10
amt := lnwire.NewMSatFromSatoshis(1000)
invoices := make([]*Invoice, numInvoices+1)
invoices[0] = &dbInvoice2
for i := 1; i < len(invoices)-1; i++ {
invoice, err := randInvoice(amt)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
hash := invoice.Terms.PaymentPreimage.Hash()
if _, err := db.AddInvoice(invoice, hash); err != nil {
t.Fatalf("unable to add invoice %v", err)
}
invoices[i] = invoice
}
// Perform a scan to collect all the active invoices.
dbInvoices, err := db.FetchAllInvoices(false)
if err != nil {
t.Fatalf("unable to fetch all invoices: %v", err)
}
// The retrieve list of invoices should be identical as since we're
// using big endian, the invoices should be retrieved in ascending
// order (and the primary key should be incremented with each
// insertion).
for i := 0; i < len(invoices)-1; i++ {
if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) {
t.Fatalf("retrieved invoices don't match %v vs %v",
spew.Sdump(invoices[i]),
spew.Sdump(dbInvoices[i]))
}
}
}
// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as
// settled invoices are added to the database are properly placed in the add
// add or settle index which serves as an event time series.
func TestInvoiceAddTimeSeries(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
// We'll start off by creating 20 random invoices, and inserting them
// into the database.
const numInvoices = 20
amt := lnwire.NewMSatFromSatoshis(1000)
invoices := make([]Invoice, numInvoices)
for i := 0; i < len(invoices); i++ {
invoice, err := randInvoice(amt)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
paymentHash := invoice.Terms.PaymentPreimage.Hash()
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice %v", err)
}
invoices[i] = *invoice
}
// With the invoices constructed, we'll now create a series of queries
// that we'll use to assert expected return values of
// InvoicesAddedSince.
addQueries := []struct {
sinceAddIndex uint64
resp []Invoice
}{
// If we specify a value of zero, we shouldn't get any invoices
// back.
{
sinceAddIndex: 0,
},
// If we specify a value well beyond the number of inserted
// invoices, we shouldn't get any invoices back.
{
sinceAddIndex: 99999999,
},
// Using an index of 1 should result in all values, but the
// first one being returned.
{
sinceAddIndex: 1,
resp: invoices[1:],
},
// If we use an index of 10, then we should retrieve the
// reaming 10 invoices.
{
sinceAddIndex: 10,
resp: invoices[10:],
},
}
for i, query := range addQueries {
resp, err := db.InvoicesAddedSince(query.sinceAddIndex)
if err != nil {
t.Fatalf("unable to query: %v", err)
}
if !reflect.DeepEqual(query.resp, resp) {
t.Fatalf("test #%v: expected %v, got %v", i,
spew.Sdump(query.resp), spew.Sdump(resp))
}
}
// We'll now only settle the latter half of each of those invoices.
for i := 10; i < len(invoices); i++ {
invoice := &invoices[i]
paymentHash := invoice.Terms.PaymentPreimage.Hash()
_, err := db.UpdateInvoice(
paymentHash, getUpdateInvoice(0),
)
if err != nil {
t.Fatalf("unable to settle invoice: %v", err)
}
}
invoices, err = db.FetchAllInvoices(false)
if err != nil {
t.Fatalf("unable to fetch invoices: %v", err)
}
// We'll slice off the first 10 invoices, as we only settled the last
// 10.
invoices = invoices[10:]
// We'll now prepare an additional set of queries to ensure the settle
// time series has properly been maintained in the database.
settleQueries := []struct {
sinceSettleIndex uint64
resp []Invoice
}{
// If we specify a value of zero, we shouldn't get any settled
// invoices back.
{
sinceSettleIndex: 0,
},
// If we specify a value well beyond the number of settled
// invoices, we shouldn't get any invoices back.
{
sinceSettleIndex: 99999999,
},
// Using an index of 1 should result in the final 10 invoices
// being returned, as we only settled those.
{
sinceSettleIndex: 1,
resp: invoices[1:],
},
}
for i, query := range settleQueries {
resp, err := db.InvoicesSettledSince(query.sinceSettleIndex)
if err != nil {
t.Fatalf("unable to query: %v", err)
}
if !reflect.DeepEqual(query.resp, resp) {
t.Fatalf("test #%v: expected %v, got %v", i,
spew.Sdump(query.resp), spew.Sdump(resp))
}
}
}
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
// twice, then the second time we also receive the invoice that we settled as a
// return argument.
func TestDuplicateSettleInvoice(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
db.now = func() time.Time { return time.Unix(1, 0) }
// We'll start out by creating an invoice and writing it to the DB.
amt := lnwire.NewMSatFromSatoshis(1000)
invoice, err := randInvoice(amt)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
payHash := invoice.Terms.PaymentPreimage.Hash()
if _, err := db.AddInvoice(invoice, payHash); err != nil {
t.Fatalf("unable to add invoice %v", err)
}
// With the invoice in the DB, we'll now attempt to settle the invoice.
dbInvoice, err := db.UpdateInvoice(
payHash, getUpdateInvoice(amt),
)
if err != nil {
t.Fatalf("unable to settle invoice: %v", err)
}
// We'll update what we expect the settle invoice to be so that our
// comparison below has the correct assumption.
invoice.SettleIndex = 1
invoice.Terms.State = ContractSettled
invoice.AmtPaid = amt
invoice.SettleDate = dbInvoice.SettleDate
invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{
{}: {
Amt: amt,
AcceptTime: time.Unix(1, 0),
ResolveTime: time.Unix(1, 0),
State: HtlcStateSettled,
},
}
// We should get back the exact same invoice that we just inserted.
if !reflect.DeepEqual(dbInvoice, invoice) {
t.Fatalf("wrong invoice after settle, expected %v got %v",
spew.Sdump(invoice), spew.Sdump(dbInvoice))
}
// If we try to settle the invoice again, then we should get the very
// same invoice back, but with an error this time.
dbInvoice, err = db.UpdateInvoice(
payHash, getUpdateInvoice(amt),
)
if err != ErrInvoiceAlreadySettled {
t.Fatalf("expected ErrInvoiceAlreadySettled")
}
if dbInvoice == nil {
t.Fatalf("invoice from db is nil after settle!")
}
invoice.SettleDate = dbInvoice.SettleDate
if !reflect.DeepEqual(dbInvoice, invoice) {
t.Fatalf("wrong invoice after second settle, expected %v got %v",
spew.Sdump(invoice), spew.Sdump(dbInvoice))
}
}
// TestQueryInvoices ensures that we can properly query the invoice database for
// invoices using different types of queries.
func TestQueryInvoices(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
// To begin the test, we'll add 50 invoices to the database. We'll
// assume that the index of the invoice within the database is the same
// as the amount of the invoice itself.
const numInvoices = 50
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
invoice, err := randInvoice(i)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}
paymentHash := invoice.Terms.PaymentPreimage.Hash()
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
// We'll only settle half of all invoices created.
if i%2 == 0 {
_, err := db.UpdateInvoice(
paymentHash, getUpdateInvoice(i),
)
if err != nil {
t.Fatalf("unable to settle invoice: %v", err)
}
}
}
// We'll then retrieve the set of all invoices and pending invoices.
// This will serve useful when comparing the expected responses of the
// query with the actual ones.
invoices, err := db.FetchAllInvoices(false)
if err != nil {
t.Fatalf("unable to retrieve invoices: %v", err)
}
pendingInvoices, err := db.FetchAllInvoices(true)
if err != nil {
t.Fatalf("unable to retrieve pending invoices: %v", err)
}
// The test will consist of several queries along with their respective
// expected response. Each query response should match its expected one.
testCases := []struct {
query InvoiceQuery
expected []Invoice
}{
// Fetch all invoices with a single query.
{
query: InvoiceQuery{
NumMaxInvoices: numInvoices,
},
expected: invoices,
},
// Fetch all invoices with a single query, reversed.
{
query: InvoiceQuery{
Reversed: true,
NumMaxInvoices: numInvoices,
},
expected: invoices,
},
// Fetch the first 25 invoices.
{
query: InvoiceQuery{
NumMaxInvoices: numInvoices / 2,
},
expected: invoices[:numInvoices/2],
},
// Fetch the first 10 invoices, but this time iterating
// backwards.
{
query: InvoiceQuery{
IndexOffset: 11,
Reversed: true,
NumMaxInvoices: numInvoices,
},
expected: invoices[:10],
},
// Fetch the last 40 invoices.
{
query: InvoiceQuery{
IndexOffset: 10,
NumMaxInvoices: numInvoices,
},
expected: invoices[10:],
},
// Fetch all but the first invoice.
{
query: InvoiceQuery{
IndexOffset: 1,
NumMaxInvoices: numInvoices,
},
expected: invoices[1:],
},
// Fetch one invoice, reversed, with index offset 3. This
// should give us the second invoice in the array.
{
query: InvoiceQuery{
IndexOffset: 3,
Reversed: true,
NumMaxInvoices: 1,
},
expected: invoices[1:2],
},
// Same as above, at index 2.
{
query: InvoiceQuery{
IndexOffset: 2,
Reversed: true,
NumMaxInvoices: 1,
},
expected: invoices[0:1],
},
// Fetch one invoice, at index 1, reversed. Since invoice#1 is
// the very first, there won't be any left in a reverse search,
// so we expect no invoices to be returned.
{
query: InvoiceQuery{
IndexOffset: 1,
Reversed: true,
NumMaxInvoices: 1,
},
expected: nil,
},
// Same as above, but don't restrict the number of invoices to
// 1.
{
query: InvoiceQuery{
IndexOffset: 1,
Reversed: true,
NumMaxInvoices: numInvoices,
},
expected: nil,
},
// Fetch one invoice, reversed, with no offset set. We expect
// the last invoice in the response.
{
query: InvoiceQuery{
Reversed: true,
NumMaxInvoices: 1,
},
expected: invoices[numInvoices-1:],
},
// Fetch one invoice, reversed, the offset set at numInvoices+1.
// We expect this to return the last invoice.
{
query: InvoiceQuery{
IndexOffset: numInvoices + 1,
Reversed: true,
NumMaxInvoices: 1,
},
expected: invoices[numInvoices-1:],
},
// Same as above, at offset numInvoices.
{
query: InvoiceQuery{
IndexOffset: numInvoices,
Reversed: true,
NumMaxInvoices: 1,
},
expected: invoices[numInvoices-2 : numInvoices-1],
},
// Fetch one invoice, at no offset (same as offset 0). We
// expect the first invoice only in the response.
{
query: InvoiceQuery{
NumMaxInvoices: 1,
},
expected: invoices[:1],
},
// Same as above, at offset 1.
{
query: InvoiceQuery{
IndexOffset: 1,
NumMaxInvoices: 1,
},
expected: invoices[1:2],
},
// Same as above, at offset 2.
{
query: InvoiceQuery{
IndexOffset: 2,
NumMaxInvoices: 1,
},
expected: invoices[2:3],
},
// Same as above, at offset numInvoices-1. Expect the last
// invoice to be returned.
{
query: InvoiceQuery{
IndexOffset: numInvoices - 1,
NumMaxInvoices: 1,
},
expected: invoices[numInvoices-1:],
},
// Same as above, at offset numInvoices. No invoices should be
// returned, as there are no invoices after this offset.
{
query: InvoiceQuery{
IndexOffset: numInvoices,
NumMaxInvoices: 1,
},
expected: nil,
},
// Fetch all pending invoices with a single query.
{
query: InvoiceQuery{
PendingOnly: true,
NumMaxInvoices: numInvoices,
},
expected: pendingInvoices,
},
// Fetch the first 12 pending invoices.
{
query: InvoiceQuery{
PendingOnly: true,
NumMaxInvoices: numInvoices / 4,
},
expected: pendingInvoices[:len(pendingInvoices)/2],
},
// Fetch the first 5 pending invoices, but this time iterating
// backwards.
{
query: InvoiceQuery{
IndexOffset: 10,
PendingOnly: true,
Reversed: true,
NumMaxInvoices: numInvoices,
},
// Since we seek to the invoice with index 10 and
// iterate backwards, there should only be 5 pending
// invoices before it as every other invoice within the
// index is settled.
expected: pendingInvoices[:5],
},
// Fetch the last 15 invoices.
{
query: InvoiceQuery{
IndexOffset: 20,
PendingOnly: true,
NumMaxInvoices: numInvoices,
},
// Since we seek to the invoice with index 20, there are
// 30 invoices left. From these 30, only 15 of them are
// still pending.
expected: pendingInvoices[len(pendingInvoices)-15:],
},
}
for i, testCase := range testCases {
response, err := db.QueryInvoices(testCase.query)
if err != nil {
t.Fatalf("unable to query invoice database: %v", err)
}
if !reflect.DeepEqual(response.Invoices, testCase.expected) {
t.Fatalf("test #%d: query returned incorrect set of "+
"invoices: expcted %v, got %v", i,
spew.Sdump(response.Invoices),
spew.Sdump(testCase.expected))
}
}
}
// getUpdateInvoice returns an invoice update callback that, when called,
// settles the invoice with the given amount.
func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
return func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
if invoice.Terms.State == ContractSettled {
return nil, ErrInvoiceAlreadySettled
}
update := &InvoiceUpdateDesc{
Preimage: invoice.Terms.PaymentPreimage,
State: ContractSettled,
Htlcs: map[CircuitKey]*HtlcAcceptDesc{
{}: {
Amt: amt,
},
},
}
return update, nil
}
}

View file

@ -3,7 +3,6 @@ package migration_01_to_11
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"time"
@ -16,9 +15,6 @@ import (
)
var (
// UnknownPreimage is an all-zeroes preimage that indicates that the
// preimage for this invoice is not yet known.
UnknownPreimage lntypes.Preimage
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
@ -26,23 +22,6 @@ var (
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// paymentHashIndexBucket is the name of the sub-bucket within the
// invoiceBucket which indexes all invoices by their payment hash. The
// payment hash is the sha256 of the invoice's payment preimage. This
// index is used to detect duplicates, and also to provide a fast path
// for looking up incoming HTLCs to determine if we're able to settle
// them fully.
//
// maps: payHash => invoiceKey
invoiceIndexBucket = []byte("paymenthashes")
// numInvoicesKey is the name of key which houses the auto-incrementing
// invoice ID which is essentially used as a primary key. With each
// invoice inserted, the primary key is incremented by one. This key is
// stored within the invoiceIndexBucket. Within the invoiceBucket
// invoices are uniquely identified by the invoice ID.
numInvoicesKey = []byte("nik")
// addIndexBucket is an index bucket that we'll use to create a
// monotonically increasing set of add indexes. Each time we add a new
// invoice, this sequence number will be incremented and then populated
@ -62,21 +41,6 @@ var (
//
// settleIndexNo => invoiceKey
settleIndexBucket = []byte("invoice-settle-index")
// ErrInvoiceAlreadySettled is returned when the invoice is already
// settled.
ErrInvoiceAlreadySettled = errors.New("invoice already settled")
// ErrInvoiceAlreadyCanceled is returned when the invoice is already
// canceled.
ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled")
// ErrInvoiceAlreadyAccepted is returned when the invoice is already
// accepted.
ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted")
// ErrInvoiceStillOpen is returned when the invoice is still open.
ErrInvoiceStillOpen = errors.New("invoice still open")
)
const (
@ -237,18 +201,6 @@ type Invoice struct {
// HtlcState defines the states an htlc paying to an invoice can be in.
type HtlcState uint8
const (
// HtlcStateAccepted indicates the htlc is locked-in, but not resolved.
HtlcStateAccepted HtlcState = iota
// HtlcStateCanceled indicates the htlc is canceled back to the
// sender.
HtlcStateCanceled
// HtlcStateSettled indicates the htlc is settled.
HtlcStateSettled
)
// InvoiceHTLC contains details about an htlc paying to this invoice.
type InvoiceHTLC struct {
// Amt is the amount that is carried by this htlc.
@ -276,37 +228,6 @@ type InvoiceHTLC struct {
State HtlcState
}
// HtlcAcceptDesc describes the details of a newly accepted htlc.
type HtlcAcceptDesc struct {
// AcceptHeight is the block height at which this htlc was accepted.
AcceptHeight int32
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// Expiry is the expiry height of this htlc.
Expiry uint32
}
// InvoiceUpdateDesc describes the changes that should be applied to the
// invoice.
type InvoiceUpdateDesc struct {
// State is the new state that this invoice should progress to.
State ContractState
// Htlcs describes the changes that need to be made to the invoice htlcs
// in the database. Htlc map entries with their value set should be
// added. If the map value is nil, the htlc should be canceled.
Htlcs map[CircuitKey]*HtlcAcceptDesc
// Preimage must be set to the preimage when state is settled.
Preimage lntypes.Preimage
}
// InvoiceUpdateCallback is a callback used in the db transaction to update the
// invoice.
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
func validateInvoice(i *Invoice) error {
if len(i.Memo) > MaxMemoSize {
return fmt.Errorf("max length a memo is %v, and invoice "+
@ -325,186 +246,6 @@ func validateInvoice(i *Invoice) error {
return nil
}
// AddInvoice inserts the targeted invoice into the database. If the invoice has
// *any* payment hashes which already exists within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes. A side effect of this function is that it sets
// AddIndex on newInvoice.
func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
uint64, error) {
if err := validateInvoice(newInvoice); err != nil {
return 0, err
}
var invoiceAddIndex uint64
err := d.Update(func(tx *bbolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil {
return err
}
invoiceIndex, err := invoices.CreateBucketIfNotExists(
invoiceIndexBucket,
)
if err != nil {
return err
}
addIndex, err := invoices.CreateBucketIfNotExists(
addIndexBucket,
)
if err != nil {
return err
}
// Ensure that an invoice an identical payment hash doesn't
// already exist within the index.
if invoiceIndex.Get(paymentHash[:]) != nil {
return ErrDuplicateInvoice
}
// If the current running payment ID counter hasn't yet been
// created, then create it now.
var invoiceNum uint32
invoiceCounter := invoiceIndex.Get(numInvoicesKey)
if invoiceCounter == nil {
var scratch [4]byte
byteOrder.PutUint32(scratch[:], invoiceNum)
err := invoiceIndex.Put(numInvoicesKey, scratch[:])
if err != nil {
return err
}
} else {
invoiceNum = byteOrder.Uint32(invoiceCounter)
}
newIndex, err := putInvoice(
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum,
paymentHash,
)
if err != nil {
return err
}
invoiceAddIndex = newIndex
return nil
})
if err != nil {
return 0, err
}
return invoiceAddIndex, err
}
// InvoicesAddedSince can be used by callers to seek into the event time series
// of all the invoices added in the database. The specified sinceAddIndex
// should be the highest add index that the caller knows of. This method will
// return all invoices with an add index greater than the specified
// sinceAddIndex.
//
// NOTE: The index starts from 1, as a result. We enforce that specifying a
// value below the starting index value is a noop.
func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
var newInvoices []Invoice
// If an index of zero was specified, then in order to maintain
// backwards compat, we won't send out any new invoices.
if sinceAddIndex == 0 {
return newInvoices, nil
}
var startIndex [8]byte
byteOrder.PutUint64(startIndex[:], sinceAddIndex)
err := d.DB.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
addIndex := invoices.Bucket(addIndexBucket)
if addIndex == nil {
return ErrNoInvoicesCreated
}
// We'll now run through each entry in the add index starting
// at our starting index. We'll continue until we reach the
// very end of the current key space.
invoiceCursor := addIndex.Cursor()
// We'll seek to the starting index, then manually advance the
// cursor in order to skip the entry with the since add index.
invoiceCursor.Seek(startIndex[:])
addSeqNo, invoiceKey := invoiceCursor.Next()
for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices)
if err != nil {
return err
}
newInvoices = append(newInvoices, invoice)
}
return nil
})
switch {
// If no invoices have been created, then we'll return the empty set of
// invoices.
case err == ErrNoInvoicesCreated:
case err != nil:
return nil, err
}
return newInvoices, nil
}
// LookupInvoice attempts to look up an invoice according to its 32 byte
// payment hash. If an invoice which can settle the HTLC identified by the
// passed payment hash isn't found, then an error is returned. Otherwise, the
// full invoice is returned. Before setting the incoming HTLC, the values
// SHOULD be checked to ensure the payer meets the agreed upon contractual
// terms of the payment.
func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
var invoice Invoice
err := d.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
if invoiceIndex == nil {
return ErrNoInvoicesCreated
}
// Check the invoice index to see if an invoice paying to this
// hash exists within the DB.
invoiceNum := invoiceIndex.Get(paymentHash[:])
if invoiceNum == nil {
return ErrInvoiceNotFound
}
// An invoice matching the payment hash has been found, so
// retrieve the record of the invoice itself.
i, err := fetchInvoice(invoiceNum, invoices)
if err != nil {
return err
}
invoice = i
return nil
})
if err != nil {
return invoice, err
}
return invoice, nil
}
// FetchAllInvoices returns all invoices currently stored within the database.
// If the pendingOnly param is true, then only unsettled invoices will be
// returned, skipping all invoices that are fully settled.
@ -549,343 +290,6 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
return invoices, nil
}
// InvoiceQuery represents a query to the invoice database. The query allows a
// caller to retrieve all invoices starting from a particular add index and
// limit the number of results returned.
type InvoiceQuery struct {
// IndexOffset is the offset within the add indices to start at. This
// can be used to start the response at a particular invoice.
IndexOffset uint64
// NumMaxInvoices is the maximum number of invoices that should be
// starting from the add index.
NumMaxInvoices uint64
// PendingOnly, if set, returns unsettled invoices starting from the
// add index.
PendingOnly bool
// Reversed, if set, indicates that the invoices returned should start
// from the IndexOffset and go backwards.
Reversed bool
}
// InvoiceSlice is the response to a invoice query. It includes the original
// query, the set of invoices that match the query, and an integer which
// represents the offset index of the last item in the set of returned invoices.
// This integer allows callers to resume their query using this offset in the
// event that the query's response exceeds the maximum number of returnable
// invoices.
type InvoiceSlice struct {
InvoiceQuery
// Invoices is the set of invoices that matched the query above.
Invoices []Invoice
// FirstIndexOffset is the index of the first element in the set of
// returned Invoices above. Callers can use this to resume their query
// in the event that the slice has too many events to fit into a single
// response.
FirstIndexOffset uint64
// LastIndexOffset is the index of the last element in the set of
// returned Invoices above. Callers can use this to resume their query
// in the event that the slice has too many events to fit into a single
// response.
LastIndexOffset uint64
}
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
resp := InvoiceSlice{
InvoiceQuery: q,
}
err := d.View(func(tx *bbolt.Tx) error {
// If the bucket wasn't found, then there aren't any invoices
// within the database yet, so we can simply exit.
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceAddIndex := invoices.Bucket(addIndexBucket)
if invoiceAddIndex == nil {
return ErrNoInvoicesCreated
}
// keyForIndex is a helper closure that retrieves the invoice
// key for the given add index of an invoice.
keyForIndex := func(c *bbolt.Cursor, index uint64) []byte {
var keyIndex [8]byte
byteOrder.PutUint64(keyIndex[:], index)
_, invoiceKey := c.Seek(keyIndex[:])
return invoiceKey
}
// nextKey is a helper closure to determine what the next
// invoice key is when iterating over the invoice add index.
nextKey := func(c *bbolt.Cursor) ([]byte, []byte) {
if q.Reversed {
return c.Prev()
}
return c.Next()
}
// We'll be using a cursor to seek into the database and return
// a slice of invoices. We'll need to determine where to start
// our cursor depending on the parameters set within the query.
c := invoiceAddIndex.Cursor()
invoiceKey := keyForIndex(c, q.IndexOffset+1)
// If the query is specifying reverse iteration, then we must
// handle a few offset cases.
if q.Reversed {
switch q.IndexOffset {
// This indicates the default case, where no offset was
// specified. In that case we just start from the last
// invoice.
case 0:
_, invoiceKey = c.Last()
// This indicates the offset being set to the very
// first invoice. Since there are no invoices before
// this offset, and the direction is reversed, we can
// return without adding any invoices to the response.
case 1:
return nil
// Otherwise we start iteration at the invoice prior to
// the offset.
default:
invoiceKey = keyForIndex(c, q.IndexOffset-1)
}
}
// If we know that a set of invoices exists, then we'll begin
// our seek through the bucket in order to satisfy the query.
// We'll continue until either we reach the end of the range, or
// reach our max number of invoices.
for ; invoiceKey != nil; _, invoiceKey = nextKey(c) {
// If our current return payload exceeds the max number
// of invoices, then we'll exit now.
if uint64(len(resp.Invoices)) >= q.NumMaxInvoices {
break
}
invoice, err := fetchInvoice(invoiceKey, invoices)
if err != nil {
return err
}
// Skip any settled invoices if the caller is only
// interested in unsettled.
if q.PendingOnly &&
invoice.Terms.State == ContractSettled {
continue
}
// At this point, we've exhausted the offset, so we'll
// begin collecting invoices found within the range.
resp.Invoices = append(resp.Invoices, invoice)
}
// If we iterated through the add index in reverse order, then
// we'll need to reverse the slice of invoices to return them in
// forward order.
if q.Reversed {
numInvoices := len(resp.Invoices)
for i := 0; i < numInvoices/2; i++ {
opposite := numInvoices - i - 1
resp.Invoices[i], resp.Invoices[opposite] =
resp.Invoices[opposite], resp.Invoices[i]
}
}
return nil
})
if err != nil && err != ErrNoInvoicesCreated {
return resp, err
}
// Finally, record the indexes of the first and last invoices returned
// so that the caller can resume from this point later on.
if len(resp.Invoices) > 0 {
resp.FirstIndexOffset = resp.Invoices[0].AddIndex
resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex
}
return resp, nil
}
// UpdateInvoice attempts to update an invoice corresponding to the passed
// payment hash. If an invoice matching the passed payment hash doesn't exist
// within the database, then the action will fail with a "not found" error.
//
// The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
callback InvoiceUpdateCallback) (*Invoice, error) {
var updatedInvoice *Invoice
err := d.Update(func(tx *bbolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil {
return err
}
invoiceIndex, err := invoices.CreateBucketIfNotExists(
invoiceIndexBucket,
)
if err != nil {
return err
}
settleIndex, err := invoices.CreateBucketIfNotExists(
settleIndexBucket,
)
if err != nil {
return err
}
// Check the invoice index to see if an invoice paying to this
// hash exists within the DB.
invoiceNum := invoiceIndex.Get(paymentHash[:])
if invoiceNum == nil {
return ErrInvoiceNotFound
}
updatedInvoice, err = d.updateInvoice(
paymentHash, invoices, settleIndex, invoiceNum,
callback,
)
return err
})
return updatedInvoice, err
}
// InvoicesSettledSince can be used by callers to catch up any settled invoices
// they missed within the settled invoice time series. We'll return all known
// settled invoice that have a settle index higher than the passed
// sinceSettleIndex.
//
// NOTE: The index starts from 1, as a result. We enforce that specifying a
// value below the starting index value is a noop.
func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
var settledInvoices []Invoice
// If an index of zero was specified, then in order to maintain
// backwards compat, we won't send out any new invoices.
if sinceSettleIndex == 0 {
return settledInvoices, nil
}
var startIndex [8]byte
byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
err := d.DB.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
settleIndex := invoices.Bucket(settleIndexBucket)
if settleIndex == nil {
return ErrNoInvoicesCreated
}
// We'll now run through each entry in the add index starting
// at our starting index. We'll continue until we reach the
// very end of the current key space.
invoiceCursor := settleIndex.Cursor()
// We'll seek to the starting index, then manually advance the
// cursor in order to skip the entry with the since add index.
invoiceCursor.Seek(startIndex[:])
seqNo, invoiceKey := invoiceCursor.Next()
for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() {
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices)
if err != nil {
return err
}
settledInvoices = append(settledInvoices, invoice)
}
return nil
})
if err != nil {
return nil, err
}
return settledInvoices, nil
}
func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket,
i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
uint64, error) {
// Create the invoice key which is just the big-endian representation
// of the invoice number.
var invoiceKey [4]byte
byteOrder.PutUint32(invoiceKey[:], invoiceNum)
// Increment the num invoice counter index so the next invoice bares
// the proper ID.
var scratch [4]byte
invoiceCounter := invoiceNum + 1
byteOrder.PutUint32(scratch[:], invoiceCounter)
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
return 0, err
}
// Add the payment hash to the invoice index. This will let us quickly
// identify if we can settle an incoming payment, and also to possibly
// allow a single invoice to have multiple payment installations.
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
if err != nil {
return 0, err
}
// Next, we'll obtain the next add invoice index (sequence
// number), so we can properly place this invoice within this
// event stream.
nextAddSeqNo, err := addIndex.NextSequence()
if err != nil {
return 0, err
}
// With the next sequence obtained, we'll updating the event series in
// the add index bucket to map this current add counter to the index of
// this new invoice.
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
return 0, err
}
i.AddIndex = nextAddSeqNo
// Finally, serialize the invoice itself to be written to the disk.
var buf bytes.Buffer
if err := serializeInvoice(&buf, i); err != nil {
return 0, err
}
if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
return 0, err
}
return nextAddSeqNo, nil
}
// serializeInvoice serializes an invoice to a writer.
//
// Note: this function is in use for a migration. Before making changes that
@ -1006,17 +410,6 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
return nil
}
func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) {
invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
return Invoice{}, ErrInvoiceNotFound
}
invoiceReader := bytes.NewReader(invoiceBytes)
return deserializeInvoice(invoiceReader)
}
func deserializeInvoice(r io.Reader) (Invoice, error) {
var err error
invoice := Invoice{}
@ -1155,166 +548,3 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
return htlcs, nil
}
// copySlice allocates a new slice and copies the source into it.
func copySlice(src []byte) []byte {
dest := make([]byte, len(src))
copy(dest, src)
return dest
}
// copyInvoice makes a deep copy of the supplied invoice.
func copyInvoice(src *Invoice) *Invoice {
dest := Invoice{
Memo: copySlice(src.Memo),
Receipt: copySlice(src.Receipt),
PaymentRequest: copySlice(src.PaymentRequest),
FinalCltvDelta: src.FinalCltvDelta,
CreationDate: src.CreationDate,
SettleDate: src.SettleDate,
Terms: src.Terms,
AddIndex: src.AddIndex,
SettleIndex: src.SettleIndex,
AmtPaid: src.AmtPaid,
Htlcs: make(
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
),
}
for k, v := range src.Htlcs {
dest.Htlcs[k] = v
}
return &dest
}
// updateInvoice fetches the invoice, obtains the update descriptor from the
// callback and applies the updates in a single db transaction.
func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket,
invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) {
invoice, err := fetchInvoice(invoiceNum, invoices)
if err != nil {
return nil, err
}
preUpdateState := invoice.Terms.State
// Create deep copy to prevent any accidental modification in the
// callback.
copy := copyInvoice(&invoice)
// Call the callback and obtain the update descriptor.
update, err := callback(copy)
if err != nil {
return &invoice, err
}
// Update invoice state.
invoice.Terms.State = update.State
now := d.now()
// Update htlc set.
for key, htlcUpdate := range update.Htlcs {
htlc, ok := invoice.Htlcs[key]
// No update means the htlc needs to be canceled.
if htlcUpdate == nil {
if !ok {
return nil, fmt.Errorf("unknown htlc %v", key)
}
if htlc.State != HtlcStateAccepted {
return nil, fmt.Errorf("can only cancel " +
"accepted htlcs")
}
htlc.State = HtlcStateCanceled
htlc.ResolveTime = now
invoice.AmtPaid -= htlc.Amt
continue
}
// Add new htlc paying to the invoice.
if ok {
return nil, fmt.Errorf("htlc %v already exists", key)
}
htlc = &InvoiceHTLC{
Amt: htlcUpdate.Amt,
Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
AcceptTime: now,
}
if preUpdateState == ContractSettled {
htlc.State = HtlcStateSettled
htlc.ResolveTime = now
} else {
htlc.State = HtlcStateAccepted
}
invoice.Htlcs[key] = htlc
invoice.AmtPaid += htlc.Amt
}
// If invoice moved to the settled state, update settle index and settle
// time.
if preUpdateState != invoice.Terms.State &&
invoice.Terms.State == ContractSettled {
if update.Preimage.Hash() != hash {
return nil, fmt.Errorf("preimage does not match")
}
invoice.Terms.PaymentPreimage = update.Preimage
// Settle all accepted htlcs.
for _, htlc := range invoice.Htlcs {
if htlc.State != HtlcStateAccepted {
continue
}
htlc.State = HtlcStateSettled
htlc.ResolveTime = now
}
err := setSettleFields(settleIndex, invoiceNum, &invoice, now)
if err != nil {
return nil, err
}
}
var buf bytes.Buffer
if err := serializeInvoice(&buf, &invoice); err != nil {
return nil, err
}
if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil {
return nil, err
}
return &invoice, nil
}
func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte,
invoice *Invoice, now time.Time) error {
// Now that we know the invoice hasn't already been settled, we'll
// update the settle index so we can place this settle event in the
// proper location within our time series.
nextSettleSeqNo, err := settleIndex.NextSequence()
if err != nil {
return err
}
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil {
return err
}
invoice.Terms.State = ContractSettled
invoice.SettleDate = now
invoice.SettleIndex = nextSettleSeqNo
return nil
}

View file

@ -1,316 +0,0 @@
package migration_01_to_11
import (
"bytes"
"io"
"net"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
)
var (
// nodeInfoBucket stores metadata pertaining to nodes that we've had
// direct channel-based correspondence with. This bucket allows one to
// query for all open channels pertaining to the node by exploring each
// node's sub-bucket within the openChanBucket.
nodeInfoBucket = []byte("nib")
)
// LinkNode stores metadata related to node's that we have/had a direct
// channel open with. Information such as the Bitcoin network the node
// advertised, and its identity public key are also stored. Additionally, this
// struct and the bucket its stored within have store data similar to that of
// Bitcoin's addrmanager. The TCP address information stored within the struct
// can be used to establish persistent connections will all channel
// counterparties on daemon startup.
//
// TODO(roasbeef): also add current OnionKey plus rotation schedule?
// TODO(roasbeef): add bitfield for supported services
// * possibly add a wire.NetAddress type, type
type LinkNode struct {
// Network indicates the Bitcoin network that the LinkNode advertises
// for incoming channel creation.
Network wire.BitcoinNet
// IdentityPub is the node's current identity public key. Any
// channel/topology related information received by this node MUST be
// signed by this public key.
IdentityPub *btcec.PublicKey
// LastSeen tracks the last time this node was seen within the network.
// A node should be marked as seen if the daemon either is able to
// establish an outgoing connection to the node or receives a new
// incoming connection from the node. This timestamp (stored in unix
// epoch) may be used within a heuristic which aims to determine when a
// channel should be unilaterally closed due to inactivity.
//
// TODO(roasbeef): replace with block hash/height?
// * possibly add a time-value metric into the heuristic?
LastSeen time.Time
// Addresses is a list of IP address in which either we were able to
// reach the node over in the past, OR we received an incoming
// authenticated connection for the stored identity public key.
Addresses []net.Addr
db *DB
}
// NewLinkNode creates a new LinkNode from the provided parameters, which is
// backed by an instance of channeldb.
func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
addrs ...net.Addr) *LinkNode {
return &LinkNode{
Network: bitNet,
IdentityPub: pub,
LastSeen: time.Now(),
Addresses: addrs,
db: db,
}
}
// UpdateLastSeen updates the last time this node was directly encountered on
// the Lightning Network.
func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error {
l.LastSeen = lastSeen
return l.Sync()
}
// AddAddress appends the specified TCP address to the list of known addresses
// this node is/was known to be reachable at.
func (l *LinkNode) AddAddress(addr net.Addr) error {
for _, a := range l.Addresses {
if a.String() == addr.String() {
return nil
}
}
l.Addresses = append(l.Addresses, addr)
return l.Sync()
}
// Sync performs a full database sync which writes the current up-to-date data
// within the struct to the database.
func (l *LinkNode) Sync() error {
// Finally update the database by storing the link node and updating
// any relevant indexes.
return l.db.Update(func(tx *bbolt.Tx) error {
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
}
return putLinkNode(nodeMetaBucket, l)
})
}
// putLinkNode serializes then writes the encoded version of the passed link
// node into the nodeMetaBucket. This function is provided in order to allow
// the ability to re-use a database transaction across many operations.
func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error {
// First serialize the LinkNode into its raw-bytes encoding.
var b bytes.Buffer
if err := serializeLinkNode(&b, l); err != nil {
return err
}
// Finally insert the link-node into the node metadata bucket keyed
// according to the its pubkey serialized in compressed form.
nodePub := l.IdentityPub.SerializeCompressed()
return nodeMetaBucket.Put(nodePub, b.Bytes())
}
// DeleteLinkNode removes the link node with the given identity from the
// database.
func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
return db.Update(func(tx *bbolt.Tx) error {
return db.deleteLinkNode(tx, identity)
})
}
func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error {
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
}
pubKey := identity.SerializeCompressed()
return nodeMetaBucket.Delete(pubKey)
}
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target
// identity public key. If a particular LinkNode for the passed identity public
// key cannot be found, then ErrNodeNotFound if returned.
func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
var linkNode *LinkNode
err := db.View(func(tx *bbolt.Tx) error {
node, err := fetchLinkNode(tx, identity)
if err != nil {
return err
}
linkNode = node
return nil
})
return linkNode, err
}
func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) {
// First fetch the bucket for storing node metadata, bailing out early
// if it hasn't been created yet.
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound
}
// If a link node for that particular public key cannot be located,
// then exit early with an ErrNodeNotFound.
pubKey := targetPub.SerializeCompressed()
nodeBytes := nodeMetaBucket.Get(pubKey)
if nodeBytes == nil {
return nil, ErrNodeNotFound
}
// Finally, decode and allocate a fresh LinkNode object to be returned
// to the caller.
nodeReader := bytes.NewReader(nodeBytes)
return deserializeLinkNode(nodeReader)
}
// TODO(roasbeef): update link node addrs in server upon connection
// FetchAllLinkNodes starts a new database transaction to fetch all nodes with
// whom we have active channels with.
func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
var linkNodes []*LinkNode
err := db.View(func(tx *bbolt.Tx) error {
nodes, err := db.fetchAllLinkNodes(tx)
if err != nil {
return err
}
linkNodes = nodes
return nil
})
if err != nil {
return nil, err
}
return linkNodes, nil
}
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
// with whom we have active channels with.
func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) {
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound
}
var linkNodes []*LinkNode
err := nodeMetaBucket.ForEach(func(k, v []byte) error {
if v == nil {
return nil
}
nodeReader := bytes.NewReader(v)
linkNode, err := deserializeLinkNode(nodeReader)
if err != nil {
return err
}
linkNodes = append(linkNodes, linkNode)
return nil
})
if err != nil {
return nil, err
}
return linkNodes, nil
}
func serializeLinkNode(w io.Writer, l *LinkNode) error {
var buf [8]byte
byteOrder.PutUint32(buf[:4], uint32(l.Network))
if _, err := w.Write(buf[:4]); err != nil {
return err
}
serializedID := l.IdentityPub.SerializeCompressed()
if _, err := w.Write(serializedID); err != nil {
return err
}
seenUnix := uint64(l.LastSeen.Unix())
byteOrder.PutUint64(buf[:], seenUnix)
if _, err := w.Write(buf[:]); err != nil {
return err
}
numAddrs := uint32(len(l.Addresses))
byteOrder.PutUint32(buf[:4], numAddrs)
if _, err := w.Write(buf[:4]); err != nil {
return err
}
for _, addr := range l.Addresses {
if err := serializeAddr(w, addr); err != nil {
return err
}
}
return nil
}
func deserializeLinkNode(r io.Reader) (*LinkNode, error) {
var (
err error
buf [8]byte
)
node := &LinkNode{}
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return nil, err
}
node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4]))
var pub [33]byte
if _, err := io.ReadFull(r, pub[:]); err != nil {
return nil, err
}
node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256())
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, buf[:]); err != nil {
return nil, err
}
node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0)
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return nil, err
}
numAddrs := byteOrder.Uint32(buf[:4])
node.Addresses = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := deserializeAddr(r)
if err != nil {
return nil, err
}
node.Addresses[i] = addr
}
return node, nil
}

View file

@ -1,140 +0,0 @@
package migration_01_to_11
import (
"bytes"
"net"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire"
)
func TestLinkNodeEncodeDecode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
// First we'll create some initial data to use for populating our test
// LinkNode instances.
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
_, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:])
addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000")
if err != nil {
t.Fatalf("unable to create test addr: %v", err)
}
addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000")
if err != nil {
t.Fatalf("unable to create test addr: %v", err)
}
// Create two fresh link node instances with the above dummy data, then
// fully sync both instances to disk.
node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1)
node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2)
if err := node1.Sync(); err != nil {
t.Fatalf("unable to sync node: %v", err)
}
if err := node2.Sync(); err != nil {
t.Fatalf("unable to sync node: %v", err)
}
// Fetch all current link nodes from the database, they should exactly
// match the two created above.
originalNodes := []*LinkNode{node2, node1}
linkNodes, err := cdb.FetchAllLinkNodes()
if err != nil {
t.Fatalf("unable to fetch nodes: %v", err)
}
for i, node := range linkNodes {
if originalNodes[i].Network != node.Network {
t.Fatalf("node networks don't match: expected %v, got %v",
originalNodes[i].Network, node.Network)
}
originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed()
dbPubkey := node.IdentityPub.SerializeCompressed()
if !bytes.Equal(originalPubkey, dbPubkey) {
t.Fatalf("node pubkeys don't match: expected %x, got %x",
originalPubkey, dbPubkey)
}
if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() {
t.Fatalf("last seen timestamps don't match: expected %v got %v",
originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix())
}
if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() {
t.Fatalf("addresses don't match: expected %v, got %v",
originalNodes[i].Addresses, node.Addresses)
}
}
// Next, we'll exercise the methods to append additional IP
// addresses, and also to update the last seen time.
if err := node1.UpdateLastSeen(time.Now()); err != nil {
t.Fatalf("unable to update last seen: %v", err)
}
if err := node1.AddAddress(addr2); err != nil {
t.Fatalf("unable to update addr: %v", err)
}
// Fetch the same node from the database according to its public key.
node1DB, err := cdb.FetchLinkNode(pub1)
if err != nil {
t.Fatalf("unable to find node: %v", err)
}
// Both the last seen timestamp and the list of reachable addresses for
// the node should be updated.
if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() {
t.Fatalf("last seen timestamps don't match: expected %v got %v",
node1.LastSeen.Unix(), node1DB.LastSeen.Unix())
}
if len(node1DB.Addresses) != 2 {
t.Fatalf("wrong length for node1 addresses: expected %v, got %v",
2, len(node1DB.Addresses))
}
if node1DB.Addresses[0].String() != addr1.String() {
t.Fatalf("wrong address for node: expected %v, got %v",
addr1.String(), node1DB.Addresses[0].String())
}
if node1DB.Addresses[1].String() != addr2.String() {
t.Fatalf("wrong address for node: expected %v, got %v",
addr2.String(), node1DB.Addresses[1].String())
}
}
func TestDeleteLinkNode(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1337,
}
linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr)
if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to write link node to db: %v", err)
}
if _, err := cdb.FetchLinkNode(pubKey); err != nil {
t.Fatalf("unable to find link node: %v", err)
}
if err := cdb.DeleteLinkNode(pubKey); err != nil {
t.Fatalf("unable to delete link node from db: %v", err)
}
if _, err := cdb.FetchLinkNode(pubKey); err == nil {
t.Fatal("should not have found link node in db, but did")
}
}

View file

@ -39,24 +39,3 @@ func DefaultOptions() Options {
// OptionModifier is a function signature for modifying the default Options.
type OptionModifier func(*Options)
// OptionSetRejectCacheSize sets the RejectCacheSize to n.
func OptionSetRejectCacheSize(n int) OptionModifier {
return func(o *Options) {
o.RejectCacheSize = n
}
}
// OptionSetChannelCacheSize sets the ChannelCacheSize to n.
func OptionSetChannelCacheSize(n int) OptionModifier {
return func(o *Options) {
o.ChannelCacheSize = n
}
}
// OptionSetSyncFreelist allows the database to sync its freelist.
func OptionSetSyncFreelist(b bool) OptionModifier {
return func(o *Options) {
o.NoFreelistSync = !b
}
}

View file

@ -1,373 +1,9 @@
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/routing/route"
)
var (
// ErrAlreadyPaid signals we have already paid this payment hash.
ErrAlreadyPaid = errors.New("invoice is already paid")
// ErrPaymentInFlight signals that payment for this payment hash is
// already "in flight" on the network.
ErrPaymentInFlight = errors.New("payment is in transition")
// ErrPaymentNotInitiated is returned if payment wasn't initiated in
// switch.
ErrPaymentNotInitiated = errors.New("payment isn't initiated")
// ErrPaymentAlreadySucceeded is returned in the event we attempt to
// change the status of a payment already succeeded.
ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded")
// ErrPaymentAlreadyFailed is returned in the event we attempt to
// re-fail a failed payment.
ErrPaymentAlreadyFailed = errors.New("payment has already failed")
// ErrUnknownPaymentStatus is returned when we do not recognize the
// existing state of a payment.
ErrUnknownPaymentStatus = errors.New("unknown payment status")
// errNoAttemptInfo is returned when no attempt info is stored yet.
errNoAttemptInfo = errors.New("unable to find attempt info for " +
"inflight payment")
)
// PaymentControl implements persistence for payments and payment attempts.
type PaymentControl struct {
db *DB
}
// NewPaymentControl creates a new instance of the PaymentControl.
func NewPaymentControl(db *DB) *PaymentControl {
return &PaymentControl{
db: db,
}
}
// InitPayment checks or records the given PaymentCreationInfo with the DB,
// making sure it does not already exist as an in-flight payment. Then this
// method returns successfully, the payment is guranteeed to be in the InFlight
// state.
func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
info *PaymentCreationInfo) error {
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, info); err != nil {
return err
}
infoBytes := b.Bytes()
var updateErr error
err := p.db.Batch(func(tx *bbolt.Tx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
bucket, err := createPaymentBucket(tx, paymentHash)
if err != nil {
return err
}
// Get the existing status of this payment, if any.
paymentStatus := fetchPaymentStatus(bucket)
switch paymentStatus {
// We allow retrying failed payments.
case StatusFailed:
// This is a new payment that is being initialized for the
// first time.
case StatusUnknown:
// We already have an InFlight payment on the network. We will
// disallow any new payments.
case StatusInFlight:
updateErr = ErrPaymentInFlight
return nil
// We've already succeeded a payment to this payment hash,
// forbid the switch from sending another.
case StatusSucceeded:
updateErr = ErrAlreadyPaid
return nil
default:
updateErr = ErrUnknownPaymentStatus
return nil
}
// Obtain a new sequence number for this payment. This is used
// to sort the payments in order of creation, and also acts as
// a unique identifier for each payment.
sequenceNum, err := nextPaymentSequence(tx)
if err != nil {
return err
}
err = bucket.Put(paymentSequenceKey, sequenceNum)
if err != nil {
return err
}
// Add the payment info to the bucket, which contains the
// static information for this payment
err = bucket.Put(paymentCreationInfoKey, infoBytes)
if err != nil {
return err
}
// We'll delete any lingering attempt info to start with, in
// case we are initializing a payment that was attempted
// earlier, but left in a state where we could retry.
err = bucket.Delete(paymentAttemptInfoKey)
if err != nil {
return err
}
// Also delete any lingering failure info now that we are
// re-attempting.
return bucket.Delete(paymentFailInfoKey)
})
if err != nil {
return err
}
return updateErr
}
// RegisterAttempt atomically records the provided PaymentAttemptInfo to the
// DB.
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
attempt *PaymentAttemptInfo) error {
// Serialize the information before opening the db transaction.
var a bytes.Buffer
if err := serializePaymentAttemptInfo(&a, attempt); err != nil {
return err
}
attemptBytes := a.Bytes()
var updateErr error
err := p.db.Batch(func(tx *bbolt.Tx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err == ErrPaymentNotInitiated {
updateErr = ErrPaymentNotInitiated
return nil
} else if err != nil {
return err
}
// We can only register attempts for payments that are
// in-flight.
if err := ensureInFlight(bucket); err != nil {
updateErr = err
return nil
}
// Add the payment attempt to the payments bucket.
return bucket.Put(paymentAttemptInfoKey, attemptBytes)
})
if err != nil {
return err
}
return updateErr
}
// Success transitions a payment into the Succeeded state. After invoking this
// method, InitPayment should always return an error to prevent us from making
// duplicate payments to the same payment hash. The provided preimage is
// atomically saved to the DB for record keeping.
func (p *PaymentControl) Success(paymentHash lntypes.Hash,
preimage lntypes.Preimage) (*route.Route, error) {
var (
updateErr error
route *route.Route
)
err := p.db.Batch(func(tx *bbolt.Tx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err == ErrPaymentNotInitiated {
updateErr = ErrPaymentNotInitiated
return nil
} else if err != nil {
return err
}
// We can only mark in-flight payments as succeeded.
if err := ensureInFlight(bucket); err != nil {
updateErr = err
return nil
}
// Record the successful payment info atomically to the
// payments record.
err = bucket.Put(paymentSettleInfoKey, preimage[:])
if err != nil {
return err
}
// Retrieve attempt info for the notification.
attempt, err := fetchPaymentAttempt(bucket)
if err != nil {
return err
}
route = &attempt.Route
return nil
})
if err != nil {
return nil, err
}
return route, updateErr
}
// Fail transitions a payment into the Failed state, and records the reason the
// payment failed. After invoking this method, InitPayment should return nil on
// its next call for this payment hash, allowing the switch to make a
// subsequent payment.
func (p *PaymentControl) Fail(paymentHash lntypes.Hash,
reason FailureReason) (*route.Route, error) {
var (
updateErr error
route *route.Route
)
err := p.db.Batch(func(tx *bbolt.Tx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err == ErrPaymentNotInitiated {
updateErr = ErrPaymentNotInitiated
return nil
} else if err != nil {
return err
}
// We can only mark in-flight payments as failed.
if err := ensureInFlight(bucket); err != nil {
updateErr = err
return nil
}
// Put the failure reason in the bucket for record keeping.
v := []byte{byte(reason)}
err = bucket.Put(paymentFailInfoKey, v)
if err != nil {
return err
}
// Retrieve attempt info for the notification, if available.
attempt, err := fetchPaymentAttempt(bucket)
if err != nil && err != errNoAttemptInfo {
return err
}
if err != errNoAttemptInfo {
route = &attempt.Route
}
return nil
})
if err != nil {
return nil, err
}
return route, updateErr
}
// FetchPayment returns information about a payment from the database.
func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
*Payment, error) {
var payment *Payment
err := p.db.View(func(tx *bbolt.Tx) error {
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return err
}
payment, err = fetchPayment(bucket)
return err
})
if err != nil {
return nil, err
}
return payment, nil
}
// createPaymentBucket creates or fetches the sub-bucket assigned to this
// payment hash.
func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
*bbolt.Bucket, error) {
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
if err != nil {
return nil, err
}
return payments.CreateBucketIfNotExists(paymentHash[:])
}
// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If
// the bucket does not exist, it returns ErrPaymentNotInitiated.
func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
*bbolt.Bucket, error) {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil {
return nil, ErrPaymentNotInitiated
}
bucket := payments.Bucket(paymentHash[:])
if bucket == nil {
return nil, ErrPaymentNotInitiated
}
return bucket, nil
}
// nextPaymentSequence returns the next sequence number to store for a new
// payment.
func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) {
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
if err != nil {
return nil, err
}
seq, err := payments.NextSequence()
if err != nil {
return nil, err
}
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, seq)
return b, nil
}
// fetchPaymentStatus fetches the payment status of the payment. If the payment
// isn't found, it will default to "StatusUnknown".
func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
@ -385,113 +21,3 @@ func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
return StatusUnknown
}
// ensureInFlight checks whether the payment found in the given bucket has
// status InFlight, and returns an error otherwise. This should be used to
// ensure we only mark in-flight payments as succeeded or failed.
func ensureInFlight(bucket *bbolt.Bucket) error {
paymentStatus := fetchPaymentStatus(bucket)
switch {
// The payment was indeed InFlight, return.
case paymentStatus == StatusInFlight:
return nil
// Our records show the payment as unknown, meaning it never
// should have left the switch.
case paymentStatus == StatusUnknown:
return ErrPaymentNotInitiated
// The payment succeeded previously.
case paymentStatus == StatusSucceeded:
return ErrPaymentAlreadySucceeded
// The payment was already failed.
case paymentStatus == StatusFailed:
return ErrPaymentAlreadyFailed
default:
return ErrUnknownPaymentStatus
}
}
// fetchPaymentAttempt fetches the payment attempt from the bucket.
func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) {
attemptData := bucket.Get(paymentAttemptInfoKey)
if attemptData == nil {
return nil, errNoAttemptInfo
}
r := bytes.NewReader(attemptData)
return deserializePaymentAttemptInfo(r)
}
// InFlightPayment is a wrapper around a payment that has status InFlight.
type InFlightPayment struct {
// Info is the PaymentCreationInfo of the in-flight payment.
Info *PaymentCreationInfo
// Attempt contains information about the last payment attempt that was
// made to this payment hash.
//
// NOTE: Might be nil.
Attempt *PaymentAttemptInfo
}
// FetchInFlightPayments returns all payments with status InFlight.
func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
var inFlights []*InFlightPayment
err := p.db.View(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil {
return nil
}
return payments.ForEach(func(k, _ []byte) error {
bucket := payments.Bucket(k)
if bucket == nil {
return fmt.Errorf("non bucket element")
}
// If the status is not InFlight, we can return early.
paymentStatus := fetchPaymentStatus(bucket)
if paymentStatus != StatusInFlight {
return nil
}
var (
inFlight = &InFlightPayment{}
err error
)
// Get the CreationInfo.
b := bucket.Get(paymentCreationInfoKey)
if b == nil {
return fmt.Errorf("unable to find creation " +
"info for inflight payment")
}
r := bytes.NewReader(b)
inFlight.Info, err = deserializePaymentCreationInfo(r)
if err != nil {
return err
}
// Now get the attempt info. It could be that there is
// no attempt info yet.
inFlight.Attempt, err = fetchPaymentAttempt(bucket)
if err != nil && err != errNoAttemptInfo {
return err
}
inFlights = append(inFlights, inFlight)
return nil
})
})
if err != nil {
return nil, err
}
return inFlights, nil
}

View file

@ -1,550 +0,0 @@
package migration_01_to_11
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/routing/route"
)
func initDB() (*DB, error) {
tempPath, err := ioutil.TempDir("", "switchdb")
if err != nil {
return nil, err
}
db, err := Open(tempPath)
if err != nil {
return nil, err
}
return db, err
}
func genPreimage() ([32]byte, error) {
var preimage [32]byte
if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
return preimage, err
}
return preimage, nil
}
func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo,
lntypes.Preimage, error) {
preimage, err := genPreimage()
if err != nil {
return nil, nil, preimage, fmt.Errorf("unable to "+
"generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
return &PaymentCreationInfo{
PaymentHash: rhash,
Value: 1,
CreationDate: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte("hola"),
},
&PaymentAttemptInfo{
PaymentID: 1,
SessionKey: priv,
Route: testRoute,
}, preimage, nil
}
// TestPaymentControlSwitchFail checks that payment status returns to Failed
// status after failing, and that InitPayment allows another HTLC for the
// same payment hash.
func TestPaymentControlSwitchFail(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
info, attempt, preimg, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate StatusInFlight.
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
// Fail the payment, which should moved it to Failed.
failReason := FailureReasonNoRoute
_, err = pControl.Fail(info.PaymentHash, failReason)
if err != nil {
t.Fatalf("unable to fail payment hash: %v", err)
}
// Verify the status is indeed Failed.
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
&failReason,
)
// Sends the htlc again, which should succeed since the prior payment
// failed.
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
// Record a new attempt.
attempt.PaymentID = 2
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
nil,
)
// Verifies that status was changed to StatusSucceeded.
var route *route.Route
route, err = pControl.Success(info.PaymentHash, preimg)
if err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
err = assertRouteEqual(route, &attempt.Route)
if err != nil {
t.Fatalf("unexpected route returned: %v vs %v: %v",
spew.Sdump(attempt.Route), spew.Sdump(*route), err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
// Attempt a final payment, which should now fail since the prior
// payment succeed.
err = pControl.InitPayment(info.PaymentHash, info)
if err != ErrAlreadyPaid {
t.Fatalf("unable to send htlc message: %v", err)
}
}
// TestPaymentControlSwitchDoubleSend checks the ability of payment control to
// prevent double sending of htlc message, when message is in StatusInFlight.
func TestPaymentControlSwitchDoubleSend(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
info, attempt, preimg, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate base status and move it to
// StatusInFlight and verifies that it was changed.
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
// Try to initiate double sending of htlc message with the same
// payment hash, should result in error indicating that payment has
// already been sent.
err = pControl.InitPayment(info.PaymentHash, info)
if err != ErrPaymentInFlight {
t.Fatalf("payment control wrong behaviour: " +
"double sending must trigger ErrPaymentInFlight error")
}
// Record an attempt.
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
nil,
)
// Sends base htlc message which initiate StatusInFlight.
err = pControl.InitPayment(info.PaymentHash, info)
if err != ErrPaymentInFlight {
t.Fatalf("payment control wrong behaviour: " +
"double sending must trigger ErrPaymentInFlight error")
}
// After settling, the error should be ErrAlreadyPaid.
if _, err := pControl.Success(info.PaymentHash, preimg); err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
err = pControl.InitPayment(info.PaymentHash, info)
if err != ErrAlreadyPaid {
t.Fatalf("unable to send htlc message: %v", err)
}
}
// TestPaymentControlSuccessesWithoutInFlight checks that the payment
// control will disallow calls to Success when no payment is in flight.
func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
info, _, preimg, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Attempt to complete the payment should fail.
_, err = pControl.Success(info.PaymentHash, preimg)
if err != ErrPaymentNotInitiated {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
assertPaymentInfo(
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{},
nil,
)
}
// TestPaymentControlFailsWithoutInFlight checks that a strict payment
// control will disallow calls to Fail when no payment is in flight.
func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
info, _, _, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Calling Fail should return an error.
_, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute)
if err != ErrPaymentNotInitiated {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
assertPaymentInfo(
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil,
)
}
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only
// deletes payments from the database that are not in-flight.
func TestPaymentControlDeleteNonInFligt(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
payments := []struct {
failed bool
success bool
}{
{
failed: true,
success: false,
},
{
failed: false,
success: true,
},
{
failed: false,
success: false,
},
}
for _, p := range payments {
info, attempt, preimg, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate StatusInFlight.
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
if p.failed {
// Fail the payment, which should moved it to Failed.
failReason := FailureReasonNoRoute
_, err = pControl.Fail(info.PaymentHash, failReason)
if err != nil {
t.Fatalf("unable to fail payment hash: %v", err)
}
// Verify the status is indeed Failed.
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt,
lntypes.Preimage{}, &failReason,
)
} else if p.success {
// Verifies that status was changed to StatusSucceeded.
_, err := pControl.Success(info.PaymentHash, preimg)
if err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, preimg, nil,
)
} else {
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt,
lntypes.Preimage{}, nil,
)
}
}
// Delete payments.
if err := db.DeletePayments(); err != nil {
t.Fatal(err)
}
// This should leave the in-flight payment.
dbPayments, err := db.FetchPayments()
if err != nil {
t.Fatal(err)
}
if len(dbPayments) != 1 {
t.Fatalf("expected one payment, got %d", len(dbPayments))
}
status := dbPayments[0].Status
if status != StatusInFlight {
t.Fatalf("expected in-fligth status, got %v", status)
}
}
func assertPaymentStatus(t *testing.T, db *DB,
hash [32]byte, expStatus PaymentStatus) {
t.Helper()
var paymentStatus = StatusUnknown
err := db.View(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil {
return nil
}
bucket := payments.Bucket(hash[:])
if bucket == nil {
return nil
}
// Get the existing status of this payment, if any.
paymentStatus = fetchPaymentStatus(bucket)
return nil
})
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
if paymentStatus != expStatus {
t.Fatalf("payment status mismatch: expected %v, got %v",
expStatus, paymentStatus)
}
}
func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error {
b := bucket.Get(paymentCreationInfoKey)
switch {
case b == nil && c == nil:
return nil
case b == nil:
return fmt.Errorf("expected creation info not found")
case c == nil:
return fmt.Errorf("unexpected creation info found")
}
r := bytes.NewReader(b)
c2, err := deserializePaymentCreationInfo(r)
if err != nil {
return err
}
if !reflect.DeepEqual(c, c2) {
return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v",
spew.Sdump(c), spew.Sdump(c2))
}
return nil
}
func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error {
b := bucket.Get(paymentAttemptInfoKey)
switch {
case b == nil && a == nil:
return nil
case b == nil:
return fmt.Errorf("expected attempt info not found")
case a == nil:
return fmt.Errorf("unexpected attempt info found")
}
r := bytes.NewReader(b)
a2, err := deserializePaymentAttemptInfo(r)
if err != nil {
return err
}
return assertRouteEqual(&a.Route, &a2.Route)
}
func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error {
zero := lntypes.Preimage{}
b := bucket.Get(paymentSettleInfoKey)
switch {
case b == nil && preimg == zero:
return nil
case b == nil:
return fmt.Errorf("expected preimage not found")
case preimg == zero:
return fmt.Errorf("unexpected preimage found")
}
var pre2 lntypes.Preimage
copy(pre2[:], b[:])
if preimg != pre2 {
return fmt.Errorf("Preimages don't match: %x vs %x",
preimg, pre2)
}
return nil
}
func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error {
b := bucket.Get(paymentFailInfoKey)
switch {
case b == nil && failReason == nil:
return nil
case b == nil:
return fmt.Errorf("expected fail info not found")
case failReason == nil:
return fmt.Errorf("unexpected fail info found")
}
failReason2 := FailureReason(b[0])
if *failReason != failReason2 {
return fmt.Errorf("Failure infos don't match: %v vs %v",
*failReason, failReason2)
}
return nil
}
func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash,
c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage,
f *FailureReason) {
t.Helper()
err := db.View(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil && c == nil {
return nil
}
if payments == nil {
return fmt.Errorf("sent payments not found")
}
bucket := payments.Bucket(hash[:])
if bucket == nil && c == nil {
return nil
}
if bucket == nil {
return fmt.Errorf("payment not found")
}
if err := checkPaymentCreationInfo(bucket, c); err != nil {
return err
}
if err := checkPaymentAttemptInfo(bucket, a); err != nil {
return err
}
if err := checkSettleInfo(bucket, s); err != nil {
return err
}
if err := checkFailInfo(bucket, f); err != nil {
return err
}
return nil
})
if err != nil {
t.Fatalf("assert payment info failed: %v", err)
}
}

View file

@ -375,48 +375,6 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) {
return p, nil
}
// DeletePayments deletes all completed and failed payments from the DB.
func (db *DB) DeletePayments() error {
return db.Update(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil {
return nil
}
var deleteBuckets [][]byte
err := payments.ForEach(func(k, _ []byte) error {
bucket := payments.Bucket(k)
if bucket == nil {
// We only expect sub-buckets to be found in
// this top-level bucket.
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
// If the status is InFlight, we cannot safely delete
// the payment information, so we return early.
paymentStatus := fetchPaymentStatus(bucket)
if paymentStatus == StatusInFlight {
return nil
}
deleteBuckets = append(deleteBuckets, k)
return nil
})
if err != nil {
return err
}
for _, k := range deleteBuckets {
if err := payments.DeleteBucket(k); err != nil {
return err
}
}
return nil
})
}
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
var scratch [8]byte

View file

@ -2,55 +2,17 @@ package migration_01_to_11
import (
"bytes"
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
var (
priv, _ = btcec.NewPrivateKey(btcec.S256())
pub = priv.PubKey()
tlvBytes = []byte{1, 2, 3}
tlvEncoder = tlv.StubEncoder(tlvBytes)
testHop1 = &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
AmtToForward: 555,
TLVRecords: []tlv.Record{
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
},
}
testHop2 = &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
AmtToForward: 555,
LegacyPayload: true,
}
testRoute = route.Route{
TotalTimeLock: 123,
TotalAmount: 1234567,
SourcePubKey: route.NewVertex(pub),
Hops: []*route.Hop{
testHop1,
testHop2,
},
}
)
func makeFakePayment() *outgoingPayment {
@ -81,27 +43,6 @@ func makeFakePayment() *outgoingPayment {
return fakePayment
}
func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) {
var preimg lntypes.Preimage
copy(preimg[:], rev[:])
c := &PaymentCreationInfo{
PaymentHash: preimg.Hash(),
Value: 1000,
// Use single second precision to avoid false positive test
// failures due to the monotonic time component.
CreationDate: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte(""),
}
a := &PaymentAttemptInfo{
PaymentID: 44,
SessionKey: priv,
Route: testRoute,
}
return c, a
}
// randomBytes creates random []byte with length in range [minLen, maxLen)
func randomBytes(minLen, maxLen int) ([]byte, error) {
randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
@ -165,160 +106,3 @@ func makeRandomFakePayment() (*outgoingPayment, error) {
return fakePayment, nil
}
func TestSentPaymentSerialization(t *testing.T) {
t.Parallel()
c, s := makeFakeInfo()
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, c); err != nil {
t.Fatalf("unable to serialize creation info: %v", err)
}
newCreationInfo, err := deserializePaymentCreationInfo(&b)
if err != nil {
t.Fatalf("unable to deserialize creation info: %v", err)
}
if !reflect.DeepEqual(c, newCreationInfo) {
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(c), spew.Sdump(newCreationInfo),
)
}
b.Reset()
if err := serializePaymentAttemptInfo(&b, s); err != nil {
t.Fatalf("unable to serialize info: %v", err)
}
newAttemptInfo, err := deserializePaymentAttemptInfo(&b)
if err != nil {
t.Fatalf("unable to deserialize info: %v", err)
}
// First we verify all the records match up porperly, as they aren't
// able to be properly compared using reflect.DeepEqual.
err = assertRouteEqual(&s.Route, &newAttemptInfo.Route)
if err != nil {
t.Fatalf("Routes do not match after "+
"serialization/deserialization: %v", err)
}
// Clear routes to allow DeepEqual to compare the remaining fields.
newAttemptInfo.Route = route.Route{}
s.Route = route.Route{}
if !reflect.DeepEqual(s, newAttemptInfo) {
s.SessionKey.Curve = nil
newAttemptInfo.SessionKey.Curve = nil
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(s), spew.Sdump(newAttemptInfo),
)
}
}
// assertRouteEquals compares to routes for equality and returns an error if
// they are not equal.
func assertRouteEqual(a, b *route.Route) error {
err := assertRouteHopRecordsEqual(a, b)
if err != nil {
return err
}
// TLV records have already been compared and need to be cleared to
// properly compare the remaining fields using DeepEqual.
copyRouteNoHops := func(r *route.Route) *route.Route {
copy := *r
copy.Hops = make([]*route.Hop, len(r.Hops))
for i, hop := range r.Hops {
hopCopy := *hop
hopCopy.TLVRecords = nil
copy.Hops[i] = &hopCopy
}
return &copy
}
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
spew.Sdump(a), spew.Sdump(b))
}
return nil
}
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
if len(r1.Hops) != len(r2.Hops) {
return errors.New("route hop count mismatch")
}
for i := 0; i < len(r1.Hops); i++ {
records1 := r1.Hops[i].TLVRecords
records2 := r2.Hops[i].TLVRecords
if len(records1) != len(records2) {
return fmt.Errorf("route record count for hop %v "+
"mismatch", i)
}
for j := 0; j < len(records1); j++ {
expectedRecord := records1[j]
newRecord := records2[j]
err := assertHopRecordsEqual(expectedRecord, newRecord)
if err != nil {
return fmt.Errorf("route record mismatch: %v", err)
}
}
}
return nil
}
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
if h1.Type() != h2.Type() {
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
h2.Type())
}
var b bytes.Buffer
if err := h2.Encode(&b); err != nil {
return fmt.Errorf("unable to encode record: %v", err)
}
if !bytes.Equal(b.Bytes(), tlvBytes) {
return fmt.Errorf("wrong raw record: expected %x, got %x",
tlvBytes, b.Bytes())
}
if h1.Size() != h2.Size() {
return fmt.Errorf("wrong size: expected %v, "+
"got %v", h1.Size(), h2.Size())
}
return nil
}
func TestRouteSerialization(t *testing.T) {
t.Parallel()
var b bytes.Buffer
if err := SerializeRoute(&b, testRoute); err != nil {
t.Fatal(err)
}
r := bytes.NewReader(b.Bytes())
route2, err := DeserializeRoute(r)
if err != nil {
t.Fatal(err)
}
// First we verify all the records match up porperly, as they aren't
// able to be properly compared using reflect.DeepEqual.
err = assertRouteEqual(&testRoute, &route2)
if err != nil {
t.Fatalf("routes not equal: \n%v vs \n%v",
spew.Sdump(testRoute), spew.Sdump(route2))
}
}

View file

@ -1,95 +0,0 @@
package migration_01_to_11
// rejectFlags is a compact representation of various metadata stored by the
// reject cache about a particular channel.
type rejectFlags uint8
const (
// rejectFlagExists is a flag indicating whether the channel exists,
// i.e. the channel is open and has a recent channel update. If this
// flag is not set, the channel is either a zombie or unknown.
rejectFlagExists rejectFlags = 1 << iota
// rejectFlagZombie is a flag indicating whether the channel is a
// zombie, i.e. the channel is open but has no recent channel updates.
rejectFlagZombie
)
// packRejectFlags computes the rejectFlags corresponding to the passed boolean
// values indicating whether the edge exists or is a zombie.
func packRejectFlags(exists, isZombie bool) rejectFlags {
var flags rejectFlags
if exists {
flags |= rejectFlagExists
}
if isZombie {
flags |= rejectFlagZombie
}
return flags
}
// unpack returns the booleans packed into the rejectFlags. The first indicates
// if the edge exists in our graph, the second indicates if the edge is a
// zombie.
func (f rejectFlags) unpack() (bool, bool) {
return f&rejectFlagExists == rejectFlagExists,
f&rejectFlagZombie == rejectFlagZombie
}
// rejectCacheEntry caches frequently accessed information about a channel,
// including the timestamps of its latest edge policies and whether or not the
// channel exists in the graph.
type rejectCacheEntry struct {
upd1Time int64
upd2Time int64
flags rejectFlags
}
// rejectCache is an in-memory cache used to improve the performance of
// HasChannelEdge. It caches information about the whether or channel exists, as
// well as the most recent timestamps for each policy (if they exists).
type rejectCache struct {
n int
edges map[uint64]rejectCacheEntry
}
// newRejectCache creates a new rejectCache with maximum capacity of n entries.
func newRejectCache(n int) *rejectCache {
return &rejectCache{
n: n,
edges: make(map[uint64]rejectCacheEntry, n),
}
}
// get returns the entry from the cache for chanid, if it exists.
func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) {
entry, ok := c.edges[chanid]
return entry, ok
}
// insert adds the entry to the reject cache. If an entry for chanid already
// exists, it will be replaced with the new entry. If the entry doesn't exists,
// it will be inserted to the cache, performing a random eviction if the cache
// is at capacity.
func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) {
// If entry exists, replace it.
if _, ok := c.edges[chanid]; ok {
c.edges[chanid] = entry
return
}
// Otherwise, evict an entry at random and insert.
if len(c.edges) == c.n {
for id := range c.edges {
delete(c.edges, id)
break
}
}
c.edges[chanid] = entry
}
// remove deletes an entry for chanid from the cache, if it exists.
func (c *rejectCache) remove(chanid uint64) {
delete(c.edges, chanid)
}

View file

@ -1,107 +0,0 @@
package migration_01_to_11
import (
"reflect"
"testing"
)
// TestRejectCache checks the behavior of the rejectCache with respect to insertion,
// eviction, and removal of cache entries.
func TestRejectCache(t *testing.T) {
const cacheSize = 100
// Create a new reject cache with the configured max size.
c := newRejectCache(cacheSize)
// As a sanity check, assert that querying the empty cache does not
// return an entry.
_, ok := c.get(0)
if ok {
t.Fatalf("reject cache should be empty")
}
// Now, fill up the cache entirely.
for i := uint64(0); i < cacheSize; i++ {
c.insert(i, entryForInt(i))
}
// Assert that the cache has all of the entries just inserted, since no
// eviction should occur until we try to surpass the max size.
assertHasEntries(t, c, 0, cacheSize)
// Now, insert a new element that causes the cache to evict an element.
c.insert(cacheSize, entryForInt(cacheSize))
// Assert that the cache has this last entry, as the cache should evict
// some prior element and not the newly inserted one.
assertHasEntries(t, c, cacheSize, cacheSize)
// Iterate over all inserted elements and construct a set of the evicted
// elements.
evicted := make(map[uint64]struct{})
for i := uint64(0); i < cacheSize+1; i++ {
_, ok := c.get(i)
if !ok {
evicted[i] = struct{}{}
}
}
// Assert that exactly one element has been evicted.
numEvicted := len(evicted)
if numEvicted != 1 {
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
}
// Remove the highest item which initially caused the eviction and
// reinsert the element that was evicted prior.
c.remove(cacheSize)
for i := range evicted {
c.insert(i, entryForInt(i))
}
// Since the removal created an extra slot, the last insertion should
// not have caused an eviction and the entries for all channels in the
// original set that filled the cache should be present.
assertHasEntries(t, c, 0, cacheSize)
// Finally, reinsert the existing set back into the cache and test that
// the cache still has all the entries. If the randomized eviction were
// happening on inserts for existing cache items, we expect this to fail
// with high probability.
for i := uint64(0); i < cacheSize; i++ {
c.insert(i, entryForInt(i))
}
assertHasEntries(t, c, 0, cacheSize)
}
// assertHasEntries queries the reject cache for all channels in the range [start,
// end), asserting that they exist and their value matches the entry produced by
// entryForInt.
func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) {
t.Helper()
for i := start; i < end; i++ {
entry, ok := c.get(i)
if !ok {
t.Fatalf("reject cache should contain chan %d", i)
}
expEntry := entryForInt(i)
if !reflect.DeepEqual(entry, expEntry) {
t.Fatalf("entry mismatch, want: %v, got: %v",
expEntry, entry)
}
}
}
// entryForInt generates a unique rejectCacheEntry given an integer.
func entryForInt(i uint64) rejectCacheEntry {
exists := i%2 == 0
isZombie := i%3 == 0
return rejectCacheEntry{
upd1Time: int64(2 * i),
upd2Time: int64(2*i + 1),
flags: packRejectFlags(exists, isZombie),
}
}

View file

@ -1,251 +0,0 @@
package migration_01_to_11
import (
"encoding/binary"
"sync"
"io"
"bytes"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// waitingProofsBucketKey byte string name of the waiting proofs store.
waitingProofsBucketKey = []byte("waitingproofs")
// ErrWaitingProofNotFound is returned if waiting proofs haven't been
// found by db.
ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " +
"found")
// ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been
// found by db.
ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " +
"key already exist")
)
// WaitingProofStore is the bold db map-like storage for half announcement
// signatures. The one responsibility of this storage is to be able to
// retrieve waiting proofs after client restart.
type WaitingProofStore struct {
// cache is used in order to reduce the number of redundant get
// calls, when object isn't stored in it.
cache map[WaitingProofKey]struct{}
db *DB
mu sync.RWMutex
}
// NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
s := &WaitingProofStore{
db: db,
cache: make(map[WaitingProofKey]struct{}),
}
if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{}
return nil
}); err != nil && err != ErrWaitingProofNotFound {
return nil, err
}
return s, nil
}
// Add adds new waiting proof in the storage.
func (s *WaitingProofStore) Add(proof *WaitingProof) error {
s.mu.Lock()
defer s.mu.Unlock()
err := s.db.Update(func(tx *bbolt.Tx) error {
var err error
var b bytes.Buffer
// Get or create the bucket.
bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey)
if err != nil {
return err
}
// Encode the objects and place it in the bucket.
if err := proof.Encode(&b); err != nil {
return err
}
key := proof.Key()
return bucket.Put(key[:], b.Bytes())
})
if err != nil {
return err
}
// Knowing that the write succeeded, we can now update the in-memory
// cache with the proof's key.
s.cache[proof.Key()] = struct{}{}
return nil
}
// Remove removes the proof from storage by its key.
func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.cache[key]; !ok {
return ErrWaitingProofNotFound
}
err := s.db.Update(func(tx *bbolt.Tx) error {
// Get or create the top bucket.
bucket := tx.Bucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
return bucket.Delete(key[:])
})
if err != nil {
return err
}
// Since the proof was successfully deleted from the store, we can now
// remove it from the in-memory cache.
delete(s.cache, key)
return nil
}
// ForAll iterates thought all waiting proofs and passing the waiting proof
// in the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
return s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
// Iterate over objects buckets.
return bucket.ForEach(func(k, v []byte) error {
// Skip buckets fields.
if v == nil {
return nil
}
r := bytes.NewReader(v)
proof := &WaitingProof{}
if err := proof.Decode(r); err != nil {
return err
}
return cb(proof)
})
})
}
// Get returns the object which corresponds to the given index.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
proof := &WaitingProof{}
s.mu.RLock()
defer s.mu.RUnlock()
if _, ok := s.cache[key]; !ok {
return nil, ErrWaitingProofNotFound
}
err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
// Iterate over objects buckets.
v := bucket.Get(key[:])
if v == nil {
return ErrWaitingProofNotFound
}
r := bytes.NewReader(v)
return proof.Decode(r)
})
return proof, err
}
// WaitingProofKey is the proof key which uniquely identifies the waiting
// proof object. The goal of this key is distinguish the local and remote
// proof for the same channel id.
type WaitingProofKey [9]byte
// WaitingProof is the storable object, which encapsulate the half proof and
// the information about from which side this proof came. This structure is
// needed to make channel proof exchange persistent, so that after client
// restart we may receive remote/local half proof and process it.
type WaitingProof struct {
*lnwire.AnnounceSignatures
isRemote bool
}
// NewWaitingProof constructs a new waiting prof instance.
func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof {
return &WaitingProof{
AnnounceSignatures: proof,
isRemote: isRemote,
}
}
// OppositeKey returns the key which uniquely identifies opposite waiting proof.
func (p *WaitingProof) OppositeKey() WaitingProofKey {
var key [9]byte
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
if !p.isRemote {
key[8] = 1
}
return key
}
// Key returns the key which uniquely identifies waiting proof.
func (p *WaitingProof) Key() WaitingProofKey {
var key [9]byte
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
if p.isRemote {
key[8] = 1
}
return key
}
// Encode writes the internal representation of waiting proof in byte stream.
func (p *WaitingProof) Encode(w io.Writer) error {
if err := binary.Write(w, byteOrder, p.isRemote); err != nil {
return err
}
if err := p.AnnounceSignatures.Encode(w, 0); err != nil {
return err
}
return nil
}
// Decode reads the data from the byte stream and initializes the
// waiting proof object with it.
func (p *WaitingProof) Decode(r io.Reader) error {
if err := binary.Read(r, byteOrder, &p.isRemote); err != nil {
return err
}
msg := &lnwire.AnnounceSignatures{}
if err := msg.Decode(r, 0); err != nil {
return err
}
(*p).AnnounceSignatures = msg
return nil
}

View file

@ -1,59 +0,0 @@
package migration_01_to_11
import (
"testing"
"reflect"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
)
// TestWaitingProofStore tests add/get/remove functions of the waiting proof
// storage.
func TestWaitingProofStore(t *testing.T) {
t.Parallel()
db, cleanup, err := makeTestDB()
if err != nil {
t.Fatalf("failed to make test database: %s", err)
}
defer cleanup()
proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{
NodeSignature: wireSig,
BitcoinSignature: wireSig,
})
store, err := NewWaitingProofStore(db)
if err != nil {
t.Fatalf("unable to create the waiting proofs storage: %v",
err)
}
if err := store.Add(proof1); err != nil {
t.Fatalf("unable add proof to storage: %v", err)
}
proof2, err := store.Get(proof1.Key())
if err != nil {
t.Fatalf("unable retrieve proof from storage: %v", err)
}
if !reflect.DeepEqual(proof1, proof2) {
t.Fatal("wrong proof retrieved")
}
if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound {
t.Fatalf("proof shouldn't be found: %v", err)
}
if err := store.Remove(proof1.Key()); err != nil {
t.Fatalf("unable remove proof from storage: %v", err)
}
if err := store.ForAll(func(proof *WaitingProof) error {
return errors.New("storage should be empty")
}); err != nil && err != ErrWaitingProofNotFound {
t.Fatal(err)
}
}

View file

@ -1,229 +0,0 @@
package migration_01_to_11
import (
"fmt"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lntypes"
)
var (
// ErrNoWitnesses is an error that's returned when no new witnesses have
// been added to the WitnessCache.
ErrNoWitnesses = fmt.Errorf("no witnesses")
// ErrUnknownWitnessType is returned if a caller attempts to
ErrUnknownWitnessType = fmt.Errorf("unknown witness type")
)
// WitnessType is enum that denotes what "type" of witness is being
// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce
// any structure on added witnesses, we use this type to partition the
// witnesses on disk, and also to know how to map a witness to its look up key.
type WitnessType uint8
var (
// Sha256HashWitness is a witness that is simply the pre image to a
// hash image. In order to map to its key, we'll use sha256.
Sha256HashWitness WitnessType = 1
)
// toDBKey is a helper method that maps a witness type to the key that we'll
// use to store it within the database.
func (w WitnessType) toDBKey() ([]byte, error) {
switch w {
case Sha256HashWitness:
return []byte{byte(w)}, nil
default:
return nil, ErrUnknownWitnessType
}
}
var (
// witnessBucketKey is the name of the bucket that we use to store all
// witnesses encountered. Within this bucket, we'll create a sub-bucket for
// each witness type.
witnessBucketKey = []byte("byte")
)
// WitnessCache is a persistent cache of all witnesses we've encountered on the
// network. In the case of multi-hop, multi-step contracts, a cache of all
// witnesses can be useful in the case of partial contract resolution. If
// negotiations break down, we may be forced to locate the witness for a
// portion of the contract on-chain. In this case, we'll then add that witness
// to the cache so the incoming contract can fully resolve witness.
// Additionally, as one MUST always use a unique witness on the network, we may
// use this cache to detect duplicate witnesses.
//
// TODO(roasbeef): need expiry policy?
// * encrypt?
type WitnessCache struct {
db *DB
}
// NewWitnessCache returns a new instance of the witness cache.
func (d *DB) NewWitnessCache() *WitnessCache {
return &WitnessCache{
db: d,
}
}
// witnessEntry is a key-value struct that holds each key -> witness pair, used
// when inserting records into the cache.
type witnessEntry struct {
key []byte
witness []byte
}
// AddSha256Witnesses adds a batch of new sha256 preimages into the witness
// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the
// preimages' witness type.
func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error {
// Optimistically compute the preimages' hashes before attempting to
// start the db transaction.
entries := make([]witnessEntry, 0, len(preimages))
for i := range preimages {
hash := preimages[i].Hash()
entries = append(entries, witnessEntry{
key: hash[:],
witness: preimages[i][:],
})
}
return w.addWitnessEntries(Sha256HashWitness, entries)
}
// addWitnessEntries inserts the witnessEntry key-value pairs into the cache,
// using the appropriate witness type to segment the namespace of possible
// witness types.
func (w *WitnessCache) addWitnessEntries(wType WitnessType,
entries []witnessEntry) error {
// Exit early if there are no witnesses to add.
if len(entries) == 0 {
return nil
}
return w.db.Batch(func(tx *bbolt.Tx) error {
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
witnessTypeBucketKey,
)
if err != nil {
return err
}
for _, entry := range entries {
err = witnessTypeBucket.Put(entry.key, entry.witness)
if err != nil {
return err
}
}
return nil
})
}
// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If
// the witness isn't found, ErrNoWitnesses will be returned.
func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) {
witness, err := w.lookupWitness(Sha256HashWitness, hash[:])
if err != nil {
return lntypes.Preimage{}, err
}
return lntypes.MakePreimage(witness)
}
// lookupWitness attempts to lookup a witness according to its type and also
// its witness key. In the case that the witness isn't found, ErrNoWitnesses
// will be returned.
func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) {
var witness []byte
err := w.db.View(func(tx *bbolt.Tx) error {
witnessBucket := tx.Bucket(witnessBucketKey)
if witnessBucket == nil {
return ErrNoWitnesses
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey)
if witnessTypeBucket == nil {
return ErrNoWitnesses
}
dbWitness := witnessTypeBucket.Get(witnessKey)
if dbWitness == nil {
return ErrNoWitnesses
}
witness = make([]byte, len(dbWitness))
copy(witness[:], dbWitness)
return nil
})
if err != nil {
return nil, err
}
return witness, nil
}
// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash.
func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error {
return w.deleteWitness(Sha256HashWitness, hash[:])
}
// deleteWitness attempts to delete a particular witness from the database.
func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error {
return w.db.Batch(func(tx *bbolt.Tx) error {
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
witnessTypeBucketKey,
)
if err != nil {
return err
}
return witnessTypeBucket.Delete(witnessKey)
})
}
// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After
// this function return with a non-nil error,
func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error {
return w.db.Batch(func(tx *bbolt.Tx) error {
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
return witnessBucket.DeleteBucket(witnessTypeBucketKey)
})
}

View file

@ -1,238 +0,0 @@
package migration_01_to_11
import (
"crypto/sha256"
"testing"
"github.com/lightningnetwork/lnd/lntypes"
)
// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new
// sha256 preimages to the witness cache.
func TestWitnessCacheSha256Retrieval(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
wCache := cdb.NewWitnessCache()
// We'll be attempting to add then lookup two simple sha256 preimages
// within this test.
preimage1 := lntypes.Preimage(rev)
preimage2 := lntypes.Preimage(key)
preimages := []lntypes.Preimage{preimage1, preimage2}
hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()}
// First, we'll attempt to add the preimages to the database.
err = wCache.AddSha256Witnesses(preimages...)
if err != nil {
t.Fatalf("unable to add witness: %v", err)
}
// With the preimages stored, we'll now attempt to look them up.
for i, hash := range hashes {
preimage := preimages[i]
// We should get back the *exact* same preimage as we originally
// stored.
dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to look up witness: %v", err)
}
if preimage != dbPreimage {
t.Fatalf("witnesses don't match: expected %x, got %x",
preimage[:], dbPreimage[:])
}
}
}
// TestWitnessCacheSha256Deletion tests that we're able to delete a single
// sha256 preimage, and also a class of witnesses from the cache.
func TestWitnessCacheSha256Deletion(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
wCache := cdb.NewWitnessCache()
// We'll start by adding two preimages to the cache.
preimage1 := lntypes.Preimage(key)
hash1 := preimage1.Hash()
preimage2 := lntypes.Preimage(rev)
hash2 := preimage2.Hash()
if err := wCache.AddSha256Witnesses(preimage1); err != nil {
t.Fatalf("unable to add witness: %v", err)
}
if err := wCache.AddSha256Witnesses(preimage2); err != nil {
t.Fatalf("unable to add witness: %v", err)
}
// We'll now delete the first preimage. If we attempt to look it up, we
// should get ErrNoWitnesses.
err = wCache.DeleteSha256Witness(hash1)
if err != nil {
t.Fatalf("unable to delete witness: %v", err)
}
_, err = wCache.LookupSha256Witness(hash1)
if err != ErrNoWitnesses {
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
}
// Next, we'll attempt to delete the entire witness class itself. When
// we try to lookup the second preimage, we should again get
// ErrNoWitnesses.
if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil {
t.Fatalf("unable to delete witness class: %v", err)
}
_, err = wCache.LookupSha256Witness(hash2)
if err != ErrNoWitnesses {
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
}
}
// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to
// query/add/delete an unknown witness.
func TestWitnessCacheUnknownWitness(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
wCache := cdb.NewWitnessCache()
// We'll attempt to add a new, undefined witness type to the database.
// We should get an error.
err = wCache.legacyAddWitnesses(234, key[:])
if err != ErrUnknownWitnessType {
t.Fatalf("expected ErrUnknownWitnessType, got %v", err)
}
}
// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves
// identically to the insertion via the generalized interface.
func TestAddSha256Witnesses(t *testing.T) {
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
wCache := cdb.NewWitnessCache()
// We'll start by adding a witnesses to the cache using the generic
// AddWitnesses method.
witness1 := rev[:]
preimage1 := lntypes.Preimage(rev)
hash1 := preimage1.Hash()
witness2 := key[:]
preimage2 := lntypes.Preimage(key)
hash2 := preimage2.Hash()
var (
witnesses = [][]byte{witness1, witness2}
preimages = []lntypes.Preimage{preimage1, preimage2}
hashes = []lntypes.Hash{hash1, hash2}
)
err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...)
if err != nil {
t.Fatalf("unable to add witness: %v", err)
}
for i, hash := range hashes {
preimage := preimages[i]
dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to lookup witness: %v", err)
}
// Assert that the retrieved witness matches the original.
if dbPreimage != preimage {
t.Fatalf("retrieved witness mismatch, want: %x, "+
"got: %x", preimage, dbPreimage)
}
// We'll now delete the witness, as we'll be reinserting it
// using the specialized AddSha256Witnesses method.
err = wCache.DeleteSha256Witness(hash)
if err != nil {
t.Fatalf("unable to delete witness: %v", err)
}
}
// Now, add the same witnesses using the type-safe interface for
// lntypes.Preimages..
err = wCache.AddSha256Witnesses(preimages...)
if err != nil {
t.Fatalf("unable to add sha256 preimage: %v", err)
}
// Finally, iterate over the keys and assert that the returned witnesses
// match the original witnesses. This asserts that the specialized
// insertion method behaves identically to the generalized interface.
for i, hash := range hashes {
preimage := preimages[i]
dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to lookup witness: %v", err)
}
// Assert that the retrieved witness matches the original.
if dbPreimage != preimage {
t.Fatalf("retrieved witness mismatch, want: %x, "+
"got: %x", preimage, dbPreimage)
}
}
}
// legacyAddWitnesses adds a batch of new witnesses of wType to the witness
// cache. The type of the witness will be used to map each witness to the key
// that will be used to look it up. All witnesses should be of the same
// WitnessType.
//
// NOTE: Previously this method exposed a generic interface for adding
// witnesses, which has since been deprecated in favor of a strongly typed
// interface for each witness class. We keep this method around to assert the
// correctness of specialized witness adding methods.
func (w *WitnessCache) legacyAddWitnesses(wType WitnessType,
witnesses ...[]byte) error {
// Optimistically compute the witness keys before attempting to start
// the db transaction.
entries := make([]witnessEntry, 0, len(witnesses))
for _, witness := range witnesses {
// Map each witness to its key by applying the appropriate
// transformation for the given witness type.
switch wType {
case Sha256HashWitness:
key := sha256.Sum256(witness)
entries = append(entries, witnessEntry{
key: key[:],
witness: witness,
})
default:
return ErrUnknownWitnessType
}
}
return w.addWitnessEntries(wType, entries)
}