mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-02-23 06:35:07 +01:00
channeldb/migration_01_to_11: remove unused code
This commit is contained in:
parent
f5191440c5
commit
60503d6c44
33 changed files with 1 additions and 17048 deletions
|
@ -1,24 +0,0 @@
|
|||
channeldb
|
||||
==========
|
||||
|
||||
[](https://travis-ci.org/lightningnetwork/lnd)
|
||||
[](https://github.com/lightningnetwork/lnd/blob/master/LICENSE)
|
||||
[](http://godoc.org/github.com/lightningnetwork/lnd/channeldb)
|
||||
|
||||
The channeldb implements the persistent storage engine for `lnd` and
|
||||
generically a data storage layer for the required state within the Lightning
|
||||
Network. The backing storage engine is
|
||||
[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store
|
||||
based off of LMDB.
|
||||
|
||||
The package implements an object-oriented storage model with queries and
|
||||
mutations flowing through a particular object instance rather than the database
|
||||
itself. The storage implemented by the objects includes: open channels, past
|
||||
commitment revocation states, the channel graph which includes authenticated
|
||||
node and channel announcements, outgoing payments, and invoices
|
||||
|
||||
## Installation and Updating
|
||||
|
||||
```bash
|
||||
$ go get -u github.com/lightningnetwork/lnd/channeldb
|
||||
```
|
|
@ -1,149 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
)
|
||||
|
||||
type unknownAddrType struct{}
|
||||
|
||||
func (t unknownAddrType) Network() string { return "unknown" }
|
||||
func (t unknownAddrType) String() string { return "unknown" }
|
||||
|
||||
var testIP4 = net.ParseIP("192.168.1.1")
|
||||
var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
|
||||
|
||||
var addrTests = []struct {
|
||||
expAddr net.Addr
|
||||
serErr string
|
||||
}{
|
||||
// Valid addresses.
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
IP: testIP4,
|
||||
Port: 12345,
|
||||
},
|
||||
},
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
IP: testIP6,
|
||||
Port: 65535,
|
||||
},
|
||||
},
|
||||
{
|
||||
expAddr: &tor.OnionAddr{
|
||||
OnionService: "3g2upl4pq6kufc4m.onion",
|
||||
Port: 9735,
|
||||
},
|
||||
},
|
||||
{
|
||||
expAddr: &tor.OnionAddr{
|
||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
|
||||
// Invalid addresses.
|
||||
{
|
||||
expAddr: unknownAddrType{},
|
||||
serErr: ErrUnknownAddressType.Error(),
|
||||
},
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
// Remove last byte of IPv4 address.
|
||||
IP: testIP4[:len(testIP4)-1],
|
||||
Port: 12345,
|
||||
},
|
||||
serErr: "unable to encode",
|
||||
},
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
// Add an extra byte of IPv4 address.
|
||||
IP: append(testIP4, 0xff),
|
||||
Port: 12345,
|
||||
},
|
||||
serErr: "unable to encode",
|
||||
},
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
// Remove last byte of IPv6 address.
|
||||
IP: testIP6[:len(testIP6)-1],
|
||||
Port: 65535,
|
||||
},
|
||||
serErr: "unable to encode",
|
||||
},
|
||||
{
|
||||
expAddr: &net.TCPAddr{
|
||||
// Add an extra byte to the IPv6 address.
|
||||
IP: append(testIP6, 0xff),
|
||||
Port: 65535,
|
||||
},
|
||||
serErr: "unable to encode",
|
||||
},
|
||||
{
|
||||
expAddr: &tor.OnionAddr{
|
||||
// Invalid suffix.
|
||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion",
|
||||
Port: 80,
|
||||
},
|
||||
serErr: "invalid suffix",
|
||||
},
|
||||
{
|
||||
expAddr: &tor.OnionAddr{
|
||||
// Invalid length.
|
||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion",
|
||||
Port: 80,
|
||||
},
|
||||
serErr: "unknown onion service length",
|
||||
},
|
||||
{
|
||||
expAddr: &tor.OnionAddr{
|
||||
// Invalid encoding.
|
||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion",
|
||||
Port: 80,
|
||||
},
|
||||
serErr: "illegal base32",
|
||||
},
|
||||
}
|
||||
|
||||
// TestAddrSerialization tests that the serialization method used by channeldb
|
||||
// for net.Addr's works as intended.
|
||||
func TestAddrSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b bytes.Buffer
|
||||
for _, test := range addrTests {
|
||||
err := serializeAddr(&b, test.expAddr)
|
||||
switch {
|
||||
case err == nil && test.serErr != "":
|
||||
t.Fatalf("expected serialization err for addr %v",
|
||||
test.expAddr)
|
||||
|
||||
case err != nil && test.serErr == "":
|
||||
t.Fatalf("unexpected serialization err for addr %v: %v",
|
||||
test.expAddr, err)
|
||||
|
||||
case err != nil && !strings.Contains(err.Error(), test.serErr):
|
||||
t.Fatalf("unexpected serialization err for addr %v, "+
|
||||
"want: %v, got %v", test.expAddr, test.serErr,
|
||||
err)
|
||||
|
||||
case err != nil:
|
||||
continue
|
||||
}
|
||||
|
||||
addr, err := deserializeAddr(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to deserialize address: %v", err)
|
||||
}
|
||||
|
||||
if addr.String() != test.expAddr.String() {
|
||||
t.Fatalf("expected address %v after serialization, "+
|
||||
"got %v", addr, test.expAddr)
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,50 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
// channelCache is an in-memory cache used to improve the performance of
|
||||
// ChanUpdatesInHorizon. It caches the chan info and edge policies for a
|
||||
// particular channel.
|
||||
type channelCache struct {
|
||||
n int
|
||||
channels map[uint64]ChannelEdge
|
||||
}
|
||||
|
||||
// newChannelCache creates a new channelCache with maximum capacity of n
|
||||
// channels.
|
||||
func newChannelCache(n int) *channelCache {
|
||||
return &channelCache{
|
||||
n: n,
|
||||
channels: make(map[uint64]ChannelEdge),
|
||||
}
|
||||
}
|
||||
|
||||
// get returns the channel from the cache, if it exists.
|
||||
func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) {
|
||||
channel, ok := c.channels[chanid]
|
||||
return channel, ok
|
||||
}
|
||||
|
||||
// insert adds the entry to the channel cache. If an entry for chanid already
|
||||
// exists, it will be replaced with the new entry. If the entry doesn't exist,
|
||||
// it will be inserted to the cache, performing a random eviction if the cache
|
||||
// is at capacity.
|
||||
func (c *channelCache) insert(chanid uint64, channel ChannelEdge) {
|
||||
// If entry exists, replace it.
|
||||
if _, ok := c.channels[chanid]; ok {
|
||||
c.channels[chanid] = channel
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, evict an entry at random and insert.
|
||||
if len(c.channels) == c.n {
|
||||
for id := range c.channels {
|
||||
delete(c.channels, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
c.channels[chanid] = channel
|
||||
}
|
||||
|
||||
// remove deletes an edge for chanid from the cache, if it exists.
|
||||
func (c *channelCache) remove(chanid uint64) {
|
||||
delete(c.channels, chanid)
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestChannelCache checks the behavior of the channelCache with respect to
|
||||
// insertion, eviction, and removal of cache entries.
|
||||
func TestChannelCache(t *testing.T) {
|
||||
const cacheSize = 100
|
||||
|
||||
// Create a new channel cache with the configured max size.
|
||||
c := newChannelCache(cacheSize)
|
||||
|
||||
// As a sanity check, assert that querying the empty cache does not
|
||||
// return an entry.
|
||||
_, ok := c.get(0)
|
||||
if ok {
|
||||
t.Fatalf("channel cache should be empty")
|
||||
}
|
||||
|
||||
// Now, fill up the cache entirely.
|
||||
for i := uint64(0); i < cacheSize; i++ {
|
||||
c.insert(i, channelForInt(i))
|
||||
}
|
||||
|
||||
// Assert that the cache has all of the entries just inserted, since no
|
||||
// eviction should occur until we try to surpass the max size.
|
||||
assertHasChanEntries(t, c, 0, cacheSize)
|
||||
|
||||
// Now, insert a new element that causes the cache to evict an element.
|
||||
c.insert(cacheSize, channelForInt(cacheSize))
|
||||
|
||||
// Assert that the cache has this last entry, as the cache should evict
|
||||
// some prior element and not the newly inserted one.
|
||||
assertHasChanEntries(t, c, cacheSize, cacheSize)
|
||||
|
||||
// Iterate over all inserted elements and construct a set of the evicted
|
||||
// elements.
|
||||
evicted := make(map[uint64]struct{})
|
||||
for i := uint64(0); i < cacheSize+1; i++ {
|
||||
_, ok := c.get(i)
|
||||
if !ok {
|
||||
evicted[i] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that exactly one element has been evicted.
|
||||
numEvicted := len(evicted)
|
||||
if numEvicted != 1 {
|
||||
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
|
||||
}
|
||||
|
||||
// Remove the highest item which initially caused the eviction and
|
||||
// reinsert the element that was evicted prior.
|
||||
c.remove(cacheSize)
|
||||
for i := range evicted {
|
||||
c.insert(i, channelForInt(i))
|
||||
}
|
||||
|
||||
// Since the removal created an extra slot, the last insertion should
|
||||
// not have caused an eviction and the entries for all channels in the
|
||||
// original set that filled the cache should be present.
|
||||
assertHasChanEntries(t, c, 0, cacheSize)
|
||||
|
||||
// Finally, reinsert the existing set back into the cache and test that
|
||||
// the cache still has all the entries. If the randomized eviction were
|
||||
// happening on inserts for existing cache items, we expect this to fail
|
||||
// with high probability.
|
||||
for i := uint64(0); i < cacheSize; i++ {
|
||||
c.insert(i, channelForInt(i))
|
||||
}
|
||||
assertHasChanEntries(t, c, 0, cacheSize)
|
||||
|
||||
}
|
||||
|
||||
// assertHasEntries queries the edge cache for all channels in the range [start,
|
||||
// end), asserting that they exist and their value matches the entry produced by
|
||||
// entryForInt.
|
||||
func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) {
|
||||
t.Helper()
|
||||
|
||||
for i := start; i < end; i++ {
|
||||
entry, ok := c.get(i)
|
||||
if !ok {
|
||||
t.Fatalf("channel cache should contain chan %d", i)
|
||||
}
|
||||
|
||||
expEntry := channelForInt(i)
|
||||
if !reflect.DeepEqual(entry, expEntry) {
|
||||
t.Fatalf("entry mismatch, want: %v, got: %v",
|
||||
expEntry, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// channelForInt generates a unique ChannelEdge given an integer.
|
||||
func channelForInt(i uint64) ChannelEdge {
|
||||
return ChannelEdge{
|
||||
Info: &ChannelEdgeInfo{
|
||||
ChannelID: i,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -4,18 +4,13 @@ import (
|
|||
"bytes"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/shachain"
|
||||
|
@ -66,8 +61,6 @@ var (
|
|||
LockTime: 5,
|
||||
}
|
||||
privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
||||
|
||||
wireSig, _ = lnwire.NewSigFromSignature(testSig)
|
||||
)
|
||||
|
||||
// makeTestDB creates a new instance of the ChannelDB for testing purposes. A
|
||||
|
@ -223,819 +216,6 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) {
|
|||
RevocationProducer: producer,
|
||||
RevocationStore: store,
|
||||
Db: cdb,
|
||||
Packager: NewChannelPackager(chanID),
|
||||
FundingTxn: testTx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestOpenChannelPutGetDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// Create the test channel state, then add an additional fake HTLC
|
||||
// before syncing to disk.
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
state.LocalCommitment.Htlcs = []HTLC{
|
||||
{
|
||||
Signature: testSig.Serialize(),
|
||||
Incoming: true,
|
||||
Amt: 10,
|
||||
RHash: key,
|
||||
RefundTimeout: 1,
|
||||
OnionBlob: []byte("onionblob"),
|
||||
},
|
||||
}
|
||||
state.RemoteCommitment.Htlcs = []HTLC{
|
||||
{
|
||||
Signature: testSig.Serialize(),
|
||||
Incoming: false,
|
||||
Amt: 10,
|
||||
RHash: key,
|
||||
RefundTimeout: 1,
|
||||
OnionBlob: []byte("onionblob"),
|
||||
},
|
||||
}
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18556,
|
||||
}
|
||||
if err := state.SyncPending(addr, 101); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch open channel: %v", err)
|
||||
}
|
||||
|
||||
newState := openChannels[0]
|
||||
|
||||
// The decoded channel state should be identical to what we stored
|
||||
// above.
|
||||
if !reflect.DeepEqual(state, newState) {
|
||||
t.Fatalf("channel state doesn't match:: %v vs %v",
|
||||
spew.Sdump(state), spew.Sdump(newState))
|
||||
}
|
||||
|
||||
// We'll also test that the channel is properly able to hot swap the
|
||||
// next revocation for the state machine. This tests the initial
|
||||
// post-funding revocation exchange.
|
||||
nextRevKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new private key: %v", err)
|
||||
}
|
||||
if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil {
|
||||
t.Fatalf("unable to update revocation: %v", err)
|
||||
}
|
||||
|
||||
openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch open channel: %v", err)
|
||||
}
|
||||
updatedChan := openChannels[0]
|
||||
|
||||
// Ensure that the revocation was set properly.
|
||||
if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) {
|
||||
t.Fatalf("next revocation wasn't updated")
|
||||
}
|
||||
|
||||
// Finally to wrap up the test, delete the state of the channel within
|
||||
// the database. This involves "closing" the channel which removes all
|
||||
// written state, and creates a small "summary" elsewhere within the
|
||||
// database.
|
||||
closeSummary := &ChannelCloseSummary{
|
||||
ChanPoint: state.FundingOutpoint,
|
||||
RemotePub: state.IdentityPub,
|
||||
SettledBalance: btcutil.Amount(500),
|
||||
TimeLockedBalance: btcutil.Amount(10000),
|
||||
IsPending: false,
|
||||
CloseType: CooperativeClose,
|
||||
}
|
||||
if err := state.CloseChannel(closeSummary); err != nil {
|
||||
t.Fatalf("unable to close channel: %v", err)
|
||||
}
|
||||
|
||||
// As the channel is now closed, attempting to fetch all open channels
|
||||
// for our fake node ID should return an empty slice.
|
||||
openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch open channels: %v", err)
|
||||
}
|
||||
if len(openChans) != 0 {
|
||||
t.Fatalf("all channels not deleted, found %v", len(openChans))
|
||||
}
|
||||
|
||||
// Additionally, attempting to fetch all the open channels globally
|
||||
// should yield no results.
|
||||
openChans, err = cdb.FetchAllChannels()
|
||||
if err != nil {
|
||||
t.Fatal("unable to fetch all open chans")
|
||||
}
|
||||
if len(openChans) != 0 {
|
||||
t.Fatalf("all channels not deleted, found %v", len(openChans))
|
||||
}
|
||||
}
|
||||
|
||||
func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
_, _, line, _ := runtime.Caller(1)
|
||||
t.Fatalf("line %v: commitments don't match: %v vs %v",
|
||||
line, spew.Sdump(a), spew.Sdump(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelStateTransition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First create a minimal channel, then perform a full sync in order to
|
||||
// persist the data.
|
||||
channel, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18556,
|
||||
}
|
||||
if err := channel.SyncPending(addr, 101); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
// Add some HTLCs which were added during this new state transition.
|
||||
// Half of the HTLCs are incoming, while the other half are outgoing.
|
||||
var (
|
||||
htlcs []HTLC
|
||||
htlcAmt lnwire.MilliSatoshi
|
||||
)
|
||||
for i := uint32(0); i < 10; i++ {
|
||||
var incoming bool
|
||||
if i > 5 {
|
||||
incoming = true
|
||||
}
|
||||
htlc := HTLC{
|
||||
Signature: testSig.Serialize(),
|
||||
Incoming: incoming,
|
||||
Amt: 10,
|
||||
RHash: key,
|
||||
RefundTimeout: i,
|
||||
OutputIndex: int32(i * 3),
|
||||
LogIndex: uint64(i * 2),
|
||||
HtlcIndex: uint64(i),
|
||||
}
|
||||
htlc.OnionBlob = make([]byte, 10)
|
||||
copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10))
|
||||
htlcs = append(htlcs, htlc)
|
||||
htlcAmt += htlc.Amt
|
||||
}
|
||||
|
||||
// Create a new channel delta which includes the above HTLCs, some
|
||||
// balance updates, and an increment of the current commitment height.
|
||||
// Additionally, modify the signature and commitment transaction.
|
||||
newSequence := uint32(129498)
|
||||
newSig := bytes.Repeat([]byte{3}, 71)
|
||||
newTx := channel.LocalCommitment.CommitTx.Copy()
|
||||
newTx.TxIn[0].Sequence = newSequence
|
||||
commitment := ChannelCommitment{
|
||||
CommitHeight: 1,
|
||||
LocalLogIndex: 2,
|
||||
LocalHtlcIndex: 1,
|
||||
RemoteLogIndex: 2,
|
||||
RemoteHtlcIndex: 1,
|
||||
LocalBalance: lnwire.MilliSatoshi(1e8),
|
||||
RemoteBalance: lnwire.MilliSatoshi(1e8),
|
||||
CommitFee: 55,
|
||||
FeePerKw: 99,
|
||||
CommitTx: newTx,
|
||||
CommitSig: newSig,
|
||||
Htlcs: htlcs,
|
||||
}
|
||||
|
||||
// First update the local node's broadcastable state and also add a
|
||||
// CommitDiff remote node's as well in order to simulate a proper state
|
||||
// transition.
|
||||
if err := channel.UpdateCommitment(&commitment); err != nil {
|
||||
t.Fatalf("unable to update commitment: %v", err)
|
||||
}
|
||||
|
||||
// The balances, new update, the HTLCs and the changes to the fake
|
||||
// commitment transaction along with the modified signature should all
|
||||
// have been updated.
|
||||
updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch updated channel: %v", err)
|
||||
}
|
||||
assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
|
||||
numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read commitment height from disk: %v", err)
|
||||
}
|
||||
if numDiskUpdates != uint64(commitment.CommitHeight) {
|
||||
t.Fatalf("num disk updates doesn't match: %v vs %v",
|
||||
numDiskUpdates, commitment.CommitHeight)
|
||||
}
|
||||
|
||||
// Attempting to query for a commitment diff should return
|
||||
// ErrNoPendingCommit as we haven't yet created a new state for them.
|
||||
_, err = channel.RemoteCommitChainTip()
|
||||
if err != ErrNoPendingCommit {
|
||||
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
|
||||
}
|
||||
|
||||
// To simulate us extending a new state to the remote party, we'll also
|
||||
// create a new commit diff for them.
|
||||
remoteCommit := commitment
|
||||
remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8)
|
||||
remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8)
|
||||
remoteCommit.CommitHeight = 1
|
||||
commitDiff := &CommitDiff{
|
||||
Commitment: remoteCommit,
|
||||
CommitSig: &lnwire.CommitSig{
|
||||
ChanID: lnwire.ChannelID(key),
|
||||
CommitSig: wireSig,
|
||||
HtlcSigs: []lnwire.Sig{
|
||||
wireSig,
|
||||
wireSig,
|
||||
},
|
||||
},
|
||||
LogUpdates: []LogUpdate{
|
||||
{
|
||||
LogIndex: 1,
|
||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
||||
ID: 1,
|
||||
Amount: lnwire.NewMSatFromSatoshis(100),
|
||||
Expiry: 25,
|
||||
},
|
||||
},
|
||||
{
|
||||
LogIndex: 2,
|
||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
||||
ID: 2,
|
||||
Amount: lnwire.NewMSatFromSatoshis(200),
|
||||
Expiry: 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
OpenedCircuitKeys: []CircuitKey{},
|
||||
ClosedCircuitKeys: []CircuitKey{},
|
||||
}
|
||||
copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
|
||||
bytes.Repeat([]byte{1}, 32))
|
||||
copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
|
||||
bytes.Repeat([]byte{2}, 32))
|
||||
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
|
||||
t.Fatalf("unable to add to commit chain: %v", err)
|
||||
}
|
||||
|
||||
// The commitment tip should now match the commitment that we just
|
||||
// inserted.
|
||||
diskCommitDiff, err := channel.RemoteCommitChainTip()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch commit diff: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
|
||||
t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit),
|
||||
spew.Sdump(diskCommitDiff))
|
||||
}
|
||||
|
||||
// We'll save the old remote commitment as this will be added to the
|
||||
// revocation log shortly.
|
||||
oldRemoteCommit := channel.RemoteCommitment
|
||||
|
||||
// Next, write to the log which tracks the necessary revocation state
|
||||
// needed to rectify any fishy behavior by the remote party. Modify the
|
||||
// current uncollapsed revocation state to simulate a state transition
|
||||
// by the remote party.
|
||||
channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
|
||||
newPriv, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate key: %v", err)
|
||||
}
|
||||
channel.RemoteNextRevocation = newPriv.PubKey()
|
||||
|
||||
fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
|
||||
diskCommitDiff.LogUpdates, nil)
|
||||
|
||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to append to revocation log: %v", err)
|
||||
}
|
||||
|
||||
// At this point, the remote commit chain should be nil, and the posted
|
||||
// remote commitment should match the one we added as a diff above.
|
||||
if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
|
||||
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
|
||||
}
|
||||
|
||||
// We should be able to fetch the channel delta created above by its
|
||||
// update number with all the state properly reconstructed.
|
||||
diskPrevCommit, err := channel.FindPreviousState(
|
||||
oldRemoteCommit.CommitHeight,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch past delta: %v", err)
|
||||
}
|
||||
|
||||
// The two deltas (the original vs the on-disk version) should
|
||||
// identical, and all HTLC data should properly be retained.
|
||||
assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit)
|
||||
|
||||
// The state number recovered from the tail of the revocation log
|
||||
// should be identical to this current state.
|
||||
logTail, err := channel.RevocationLogTail()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to retrieve log: %v", err)
|
||||
}
|
||||
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
|
||||
t.Fatal("update number doesn't match")
|
||||
}
|
||||
|
||||
oldRemoteCommit = channel.RemoteCommitment
|
||||
|
||||
// Next modify the posted diff commitment slightly, then create a new
|
||||
// commitment diff and advance the tail.
|
||||
commitDiff.Commitment.CommitHeight = 2
|
||||
commitDiff.Commitment.LocalBalance -= htlcAmt
|
||||
commitDiff.Commitment.RemoteBalance += htlcAmt
|
||||
commitDiff.LogUpdates = []LogUpdate{}
|
||||
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
|
||||
t.Fatalf("unable to add to commit chain: %v", err)
|
||||
}
|
||||
|
||||
fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)
|
||||
|
||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to append to revocation log: %v", err)
|
||||
}
|
||||
|
||||
// Once again, fetch the state and ensure it has been properly updated.
|
||||
prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch past delta: %v", err)
|
||||
}
|
||||
assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)
|
||||
|
||||
// Once again, state number recovered from the tail of the revocation
|
||||
// log should be identical to this current state.
|
||||
logTail, err = channel.RevocationLogTail()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to retrieve log: %v", err)
|
||||
}
|
||||
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
|
||||
t.Fatal("update number doesn't match")
|
||||
}
|
||||
|
||||
// The revocation state stored on-disk should now also be identical.
|
||||
updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch updated channel: %v", err)
|
||||
}
|
||||
if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
|
||||
t.Fatalf("revocation state was not synced")
|
||||
}
|
||||
if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
|
||||
t.Fatalf("revocation state was not synced")
|
||||
}
|
||||
|
||||
// Now attempt to delete the channel from the database.
|
||||
closeSummary := &ChannelCloseSummary{
|
||||
ChanPoint: channel.FundingOutpoint,
|
||||
RemotePub: channel.IdentityPub,
|
||||
SettledBalance: btcutil.Amount(500),
|
||||
TimeLockedBalance: btcutil.Amount(10000),
|
||||
IsPending: false,
|
||||
CloseType: RemoteForceClose,
|
||||
}
|
||||
if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
|
||||
t.Fatalf("unable to delete updated channel: %v", err)
|
||||
}
|
||||
|
||||
// If we attempt to fetch the target channel again, it shouldn't be
|
||||
// found.
|
||||
channels, err := cdb.FetchOpenChannels(channel.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch updated channels: %v", err)
|
||||
}
|
||||
if len(channels) != 0 {
|
||||
t.Fatalf("%v channels, found, but none should be",
|
||||
len(channels))
|
||||
}
|
||||
|
||||
// Attempting to find previous states on the channel should fail as the
|
||||
// revocation log has been deleted.
|
||||
_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight)
|
||||
if err == nil {
|
||||
t.Fatal("revocation log search should have failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPendingChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// Create first test channel state
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
|
||||
const broadcastHeight = 99
|
||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
pendingChannels, err := cdb.FetchPendingChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to list pending channels: %v", err)
|
||||
}
|
||||
|
||||
if len(pendingChannels) != 1 {
|
||||
t.Fatalf("incorrect number of pending channels: expecting %v,"+
|
||||
"got %v", 1, len(pendingChannels))
|
||||
}
|
||||
|
||||
// The broadcast height of the pending channel should have been set
|
||||
// properly.
|
||||
if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
|
||||
t.Fatalf("broadcast height mismatch: expected %v, got %v",
|
||||
pendingChannels[0].FundingBroadcastHeight,
|
||||
broadcastHeight)
|
||||
}
|
||||
|
||||
chanOpenLoc := lnwire.ShortChannelID{
|
||||
BlockHeight: 5,
|
||||
TxIndex: 10,
|
||||
TxPosition: 15,
|
||||
}
|
||||
err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to mark channel as open: %v", err)
|
||||
}
|
||||
|
||||
if pendingChannels[0].IsPending {
|
||||
t.Fatalf("channel marked open should no longer be pending")
|
||||
}
|
||||
|
||||
if pendingChannels[0].ShortChanID() != chanOpenLoc {
|
||||
t.Fatalf("channel opening height not updated: expected %v, "+
|
||||
"got %v", spew.Sdump(pendingChannels[0].ShortChanID()),
|
||||
chanOpenLoc)
|
||||
}
|
||||
|
||||
// Next, we'll re-fetch the channel to ensure that the open height was
|
||||
// properly set.
|
||||
openChans, err := cdb.FetchAllChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch channels: %v", err)
|
||||
}
|
||||
if openChans[0].ShortChanID() != chanOpenLoc {
|
||||
t.Fatalf("channel opening heights don't match: expected %v, "+
|
||||
"got %v", spew.Sdump(openChans[0].ShortChanID()),
|
||||
chanOpenLoc)
|
||||
}
|
||||
if openChans[0].FundingBroadcastHeight != broadcastHeight {
|
||||
t.Fatalf("broadcast height mismatch: expected %v, got %v",
|
||||
openChans[0].FundingBroadcastHeight,
|
||||
broadcastHeight)
|
||||
}
|
||||
|
||||
pendingChannels, err = cdb.FetchPendingChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to list pending channels: %v", err)
|
||||
}
|
||||
|
||||
if len(pendingChannels) != 0 {
|
||||
t.Fatalf("incorrect number of pending channels: expecting %v,"+
|
||||
"got %v", 0, len(pendingChannels))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchClosedChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First create a test channel, that we'll be closing within this pull
|
||||
// request.
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
// Next sync the channel to disk, marking it as being in a pending open
|
||||
// state.
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
const broadcastHeight = 99
|
||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
// Next, simulate the confirmation of the channel by marking it as
|
||||
// pending within the database.
|
||||
chanOpenLoc := lnwire.ShortChannelID{
|
||||
BlockHeight: 5,
|
||||
TxIndex: 10,
|
||||
TxPosition: 15,
|
||||
}
|
||||
err = state.MarkAsOpen(chanOpenLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to mark channel as open: %v", err)
|
||||
}
|
||||
|
||||
// Next, close the channel by including a close channel summary in the
|
||||
// database.
|
||||
summary := &ChannelCloseSummary{
|
||||
ChanPoint: state.FundingOutpoint,
|
||||
ClosingTXID: rev,
|
||||
RemotePub: state.IdentityPub,
|
||||
Capacity: state.Capacity,
|
||||
SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(),
|
||||
TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000,
|
||||
CloseType: RemoteForceClose,
|
||||
IsPending: true,
|
||||
LocalChanConfig: state.LocalChanCfg,
|
||||
}
|
||||
if err := state.CloseChannel(summary); err != nil {
|
||||
t.Fatalf("unable to close channel: %v", err)
|
||||
}
|
||||
|
||||
// Query the database to ensure that the channel has now been properly
|
||||
// closed. We should get the same result whether querying for pending
|
||||
// channels only, or not.
|
||||
pendingClosed, err := cdb.FetchClosedChannels(true)
|
||||
if err != nil {
|
||||
t.Fatalf("failed fetching closed channels: %v", err)
|
||||
}
|
||||
if len(pendingClosed) != 1 {
|
||||
t.Fatalf("incorrect number of pending closed channels: expecting %v,"+
|
||||
"got %v", 1, len(pendingClosed))
|
||||
}
|
||||
if !reflect.DeepEqual(summary, pendingClosed[0]) {
|
||||
t.Fatalf("database summaries don't match: expected %v got %v",
|
||||
spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
|
||||
}
|
||||
closed, err := cdb.FetchClosedChannels(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed fetching all closed channels: %v", err)
|
||||
}
|
||||
if len(closed) != 1 {
|
||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
||||
"got %v", 1, len(closed))
|
||||
}
|
||||
if !reflect.DeepEqual(summary, closed[0]) {
|
||||
t.Fatalf("database summaries don't match: expected %v got %v",
|
||||
spew.Sdump(summary), spew.Sdump(closed[0]))
|
||||
}
|
||||
|
||||
// Mark the channel as fully closed.
|
||||
err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("failed fully closing channel: %v", err)
|
||||
}
|
||||
|
||||
// The channel should no longer be considered pending, but should still
|
||||
// be retrieved when fetching all the closed channels.
|
||||
closed, err = cdb.FetchClosedChannels(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed fetching closed channels: %v", err)
|
||||
}
|
||||
if len(closed) != 1 {
|
||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
||||
"got %v", 1, len(closed))
|
||||
}
|
||||
pendingClose, err := cdb.FetchClosedChannels(true)
|
||||
if err != nil {
|
||||
t.Fatalf("failed fetching channels pending close: %v", err)
|
||||
}
|
||||
if len(pendingClose) != 0 {
|
||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
||||
"got %v", 0, len(closed))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchWaitingCloseChannels ensures that the correct channels that are
|
||||
// waiting to be closed are returned.
|
||||
func TestFetchWaitingCloseChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChannels = 2
|
||||
const broadcastHeight = 99
|
||||
addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555}
|
||||
|
||||
// We'll start by creating two channels within our test database. One of
|
||||
// them will have their funding transaction confirmed on-chain, while
|
||||
// the other one will remain unconfirmed.
|
||||
db, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
channels := make([]*OpenChannel, numChannels)
|
||||
for i := 0; i < numChannels; i++ {
|
||||
channel, err := createTestChannelState(db)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel: %v", err)
|
||||
}
|
||||
err = channel.SyncPending(addr, broadcastHeight)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to sync channel: %v", err)
|
||||
}
|
||||
channels[i] = channel
|
||||
}
|
||||
|
||||
// We'll only confirm the first one.
|
||||
channelConf := lnwire.ShortChannelID{
|
||||
BlockHeight: broadcastHeight + 1,
|
||||
TxIndex: 10,
|
||||
TxPosition: 15,
|
||||
}
|
||||
if err := channels[0].MarkAsOpen(channelConf); err != nil {
|
||||
t.Fatalf("unable to mark channel as open: %v", err)
|
||||
}
|
||||
|
||||
// Then, we'll mark the channels as if their commitments were broadcast.
|
||||
// This would happen in the event of a force close and should make the
|
||||
// channels enter a state of waiting close.
|
||||
for _, channel := range channels {
|
||||
closeTx := wire.NewMsgTx(2)
|
||||
closeTx.AddTxIn(
|
||||
&wire.TxIn{
|
||||
PreviousOutPoint: channel.FundingOutpoint,
|
||||
},
|
||||
)
|
||||
if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil {
|
||||
t.Fatalf("unable to mark commitment broadcast: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we'll fetch all the channels waiting to be closed from the
|
||||
// database. We should expect to see both channels above, even if any of
|
||||
// them haven't had their funding transaction confirm on-chain.
|
||||
waitingCloseChannels, err := db.FetchWaitingCloseChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch all waiting close channels: %v", err)
|
||||
}
|
||||
if len(waitingCloseChannels) != 2 {
|
||||
t.Fatalf("expected %d channels waiting to be closed, got %d", 2,
|
||||
len(waitingCloseChannels))
|
||||
}
|
||||
expectedChannels := make(map[wire.OutPoint]struct{})
|
||||
for _, channel := range channels {
|
||||
expectedChannels[channel.FundingOutpoint] = struct{}{}
|
||||
}
|
||||
for _, channel := range waitingCloseChannels {
|
||||
if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
|
||||
t.Fatalf("expected channel %v to be waiting close",
|
||||
channel.FundingOutpoint)
|
||||
}
|
||||
|
||||
// Finally, make sure we can retrieve the closing tx for the
|
||||
// channel.
|
||||
closeTx, err := channel.BroadcastedCommitment()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to retrieve commitment: %v", err)
|
||||
}
|
||||
|
||||
if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint {
|
||||
t.Fatalf("expected outpoint %v, got %v",
|
||||
channel.FundingOutpoint,
|
||||
closeTx.TxIn[0].PreviousOutPoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
|
||||
// short channel ID of another OpenChannel to reflect a preceding call to
|
||||
// MarkOpen on a different OpenChannel.
|
||||
func TestRefreshShortChanID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First create a test channel.
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
|
||||
// Mark the channel as pending within the channeldb.
|
||||
const broadcastHeight = 99
|
||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
// Next, locate the pending channel with the database.
|
||||
pendingChannels, err := cdb.FetchPendingChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load pending channels; %v", err)
|
||||
}
|
||||
|
||||
var pendingChannel *OpenChannel
|
||||
for _, channel := range pendingChannels {
|
||||
if channel.FundingOutpoint == state.FundingOutpoint {
|
||||
pendingChannel = channel
|
||||
break
|
||||
}
|
||||
}
|
||||
if pendingChannel == nil {
|
||||
t.Fatalf("unable to find pending channel with funding "+
|
||||
"outpoint=%v: %v", state.FundingOutpoint, err)
|
||||
}
|
||||
|
||||
// Next, simulate the confirmation of the channel by marking it as
|
||||
// pending within the database.
|
||||
chanOpenLoc := lnwire.ShortChannelID{
|
||||
BlockHeight: 105,
|
||||
TxIndex: 10,
|
||||
TxPosition: 15,
|
||||
}
|
||||
|
||||
err = state.MarkAsOpen(chanOpenLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to mark channel open: %v", err)
|
||||
}
|
||||
|
||||
// The short_chan_id of the receiver to MarkAsOpen should reflect the
|
||||
// open location, but the other pending channel should remain unchanged.
|
||||
if state.ShortChanID() == pendingChannel.ShortChanID() {
|
||||
t.Fatalf("pending channel short_chan_ID should not have been " +
|
||||
"updated before refreshing short_chan_id")
|
||||
}
|
||||
|
||||
// Now that the receiver's short channel id has been updated, check to
|
||||
// ensure that the channel packager's source has been updated as well.
|
||||
// This ensures that the packager will read and write to buckets
|
||||
// corresponding to the new short chan id, instead of the prior.
|
||||
if state.Packager.(*ChannelPackager).source != chanOpenLoc {
|
||||
t.Fatalf("channel packager source was not updated: want %v, "+
|
||||
"got %v", chanOpenLoc,
|
||||
state.Packager.(*ChannelPackager).source)
|
||||
}
|
||||
|
||||
// Now, refresh the short channel ID of the pending channel.
|
||||
err = pendingChannel.RefreshShortChanID()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to refresh short_chan_id: %v", err)
|
||||
}
|
||||
|
||||
// This should result in both OpenChannel's now having the same
|
||||
// ShortChanID.
|
||||
if state.ShortChanID() != pendingChannel.ShortChanID() {
|
||||
t.Fatalf("expected pending channel short_chan_id to be "+
|
||||
"refreshed: want %v, got %v", state.ShortChanID(),
|
||||
pendingChannel.ShortChanID())
|
||||
}
|
||||
|
||||
// Check to ensure that the _other_ OpenChannel channel packager's
|
||||
// source has also been updated after the refresh. This ensures that the
|
||||
// other packagers will read and write to buckets corresponding to the
|
||||
// updated short chan id.
|
||||
if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
|
||||
t.Fatalf("channel packager source was not updated: want %v, "+
|
||||
"got %v", chanOpenLoc,
|
||||
pendingChannel.Packager.(*ChannelPackager).source)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,12 +48,6 @@ type UnknownElementType struct {
|
|||
element interface{}
|
||||
}
|
||||
|
||||
// NewUnknownElementType creates a new UnknownElementType error from the passed
|
||||
// method name and element.
|
||||
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
|
||||
return UnknownElementType{method: method, element: el}
|
||||
}
|
||||
|
||||
// Error returns the name of the method that encountered the error, as well as
|
||||
// the type that was unsupported.
|
||||
func (e UnknownElementType) Error() string {
|
||||
|
|
|
@ -4,16 +4,11 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -87,57 +82,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) {
|
|||
return chanDB, nil
|
||||
}
|
||||
|
||||
// Path returns the file path to the channel database.
|
||||
func (d *DB) Path() string {
|
||||
return d.dbPath
|
||||
}
|
||||
|
||||
// Wipe completely deletes all saved state within all used buckets within the
|
||||
// database. The deletion is done in a single transaction, therefore this
|
||||
// operation is fully atomic.
|
||||
func (d *DB) Wipe() error {
|
||||
return d.Update(func(tx *bbolt.Tx) error {
|
||||
err := tx.DeleteBucket(openChannelBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.DeleteBucket(closedChannelBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.DeleteBucket(invoiceBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.DeleteBucket(nodeInfoBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.DeleteBucket(nodeBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
err = tx.DeleteBucket(edgeBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
err = tx.DeleteBucket(edgeIndexBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
err = tx.DeleteBucket(graphMetaBucket)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// createChannelDB creates and initializes a fresh version of channeldb. In
|
||||
// the case that the target path has not yet been created or doesn't yet exist,
|
||||
// then the path is created. Additionally, all required top-level buckets used
|
||||
|
@ -163,14 +107,6 @@ func createChannelDB(dbPath string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket(forwardingLogBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket(fwdPackagesKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket(invoiceBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -179,10 +115,6 @@ func createChannelDB(dbPath string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket(nodeInfoBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes, err := tx.CreateBucket(nodeBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -249,359 +181,6 @@ func fileExists(path string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// FetchOpenChannels starts a new database transaction and returns all stored
|
||||
// currently active/open channels associated with the target nodeID. In the case
|
||||
// that no active channels are known to have been created with this node, then a
|
||||
// zero-length slice is returned.
|
||||
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
|
||||
var channels []*OpenChannel
|
||||
err := d.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
channels, err = d.fetchOpenChannels(tx, nodeID)
|
||||
return err
|
||||
})
|
||||
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// fetchOpenChannels uses and existing database transaction and returns all
|
||||
// stored currently active/open channels associated with the target nodeID. In
|
||||
// the case that no active channels are known to have been created with this
|
||||
// node, then a zero-length slice is returned.
|
||||
func (d *DB) fetchOpenChannels(tx *bbolt.Tx,
|
||||
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
|
||||
|
||||
// Get the bucket dedicated to storing the metadata for open channels.
|
||||
openChanBucket := tx.Bucket(openChannelBucket)
|
||||
if openChanBucket == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Within this top level bucket, fetch the bucket dedicated to storing
|
||||
// open channel data specific to the remote node.
|
||||
pub := nodeID.SerializeCompressed()
|
||||
nodeChanBucket := openChanBucket.Bucket(pub)
|
||||
if nodeChanBucket == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Next, we'll need to go down an additional layer in order to retrieve
|
||||
// the channels for each chain the node knows of.
|
||||
var channels []*OpenChannel
|
||||
err := nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
||||
// If there's a value, it's not a bucket so ignore it.
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we've found a valid chainhash bucket, then we'll retrieve
|
||||
// that so we can extract all the channels.
|
||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
||||
if chainBucket == nil {
|
||||
return fmt.Errorf("unable to read bucket for chain=%x",
|
||||
chainHash[:])
|
||||
}
|
||||
|
||||
// Finally, we both of the necessary buckets retrieved, fetch
|
||||
// all the active channels related to this node.
|
||||
nodeChannels, err := d.fetchNodeChannels(chainBucket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read channel for "+
|
||||
"chain_hash=%x, node_key=%x: %v",
|
||||
chainHash[:], pub, err)
|
||||
}
|
||||
|
||||
channels = append(channels, nodeChannels...)
|
||||
return nil
|
||||
})
|
||||
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// fetchNodeChannels retrieves all active channels from the target chainBucket
|
||||
// which is under a node's dedicated channel bucket. This function is typically
|
||||
// used to fetch all the active channels related to a particular node.
|
||||
func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) {
|
||||
|
||||
var channels []*OpenChannel
|
||||
|
||||
// A node may have channels on several chains, so for each known chain,
|
||||
// we'll extract all the channels.
|
||||
err := chainBucket.ForEach(func(chanPoint, v []byte) error {
|
||||
// If there's a value, it's not a bucket so ignore it.
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Once we've found a valid channel bucket, we'll extract it
|
||||
// from the node's chain bucket.
|
||||
chanBucket := chainBucket.Bucket(chanPoint)
|
||||
|
||||
var outPoint wire.OutPoint
|
||||
err := readOutpoint(bytes.NewReader(chanPoint), &outPoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oChannel, err := fetchOpenChannel(chanBucket, &outPoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read channel data for "+
|
||||
"chan_point=%v: %v", outPoint, err)
|
||||
}
|
||||
oChannel.Db = d
|
||||
|
||||
channels = append(channels, oChannel)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// FetchChannel attempts to locate a channel specified by the passed channel
|
||||
// point. If the channel cannot be found, then an error will be returned.
|
||||
func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
|
||||
var (
|
||||
targetChan *OpenChannel
|
||||
targetChanPoint bytes.Buffer
|
||||
)
|
||||
|
||||
if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// chanScan will traverse the following bucket structure:
|
||||
// * nodePub => chainHash => chanPoint
|
||||
//
|
||||
// At each level we go one further, ensuring that we're traversing the
|
||||
// proper key (that's actually a bucket). By only reading the bucket
|
||||
// structure and skipping fully decoding each channel, we save a good
|
||||
// bit of CPU as we don't need to do things like decompress public
|
||||
// keys.
|
||||
chanScan := func(tx *bbolt.Tx) error {
|
||||
// Get the bucket dedicated to storing the metadata for open
|
||||
// channels.
|
||||
openChanBucket := tx.Bucket(openChannelBucket)
|
||||
if openChanBucket == nil {
|
||||
return ErrNoActiveChannels
|
||||
}
|
||||
|
||||
// Within the node channel bucket, are the set of node pubkeys
|
||||
// we have channels with, we don't know the entire set, so
|
||||
// we'll check them all.
|
||||
return openChanBucket.ForEach(func(nodePub, v []byte) error {
|
||||
// Ensure that this is a key the same size as a pubkey,
|
||||
// and also that it leads directly to a bucket.
|
||||
if len(nodePub) != 33 || v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nodeChanBucket := openChanBucket.Bucket(nodePub)
|
||||
if nodeChanBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The next layer down is all the chains that this node
|
||||
// has channels on with us.
|
||||
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
||||
// If there's a value, it's not a bucket so
|
||||
// ignore it.
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
||||
if chainBucket == nil {
|
||||
return fmt.Errorf("unable to read "+
|
||||
"bucket for chain=%x", chainHash[:])
|
||||
}
|
||||
|
||||
// Finally we reach the leaf bucket that stores
|
||||
// all the chanPoints for this node.
|
||||
chanBucket := chainBucket.Bucket(
|
||||
targetChanPoint.Bytes(),
|
||||
)
|
||||
if chanBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
channel, err := fetchOpenChannel(
|
||||
chanBucket, &chanPoint,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetChan = channel
|
||||
targetChan.Db = d
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
err := d.View(chanScan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if targetChan != nil {
|
||||
return targetChan, nil
|
||||
}
|
||||
|
||||
// If we can't find the channel, then we return with an error, as we
|
||||
// have nothing to backup.
|
||||
return nil, ErrChannelNotFound
|
||||
}
|
||||
|
||||
// FetchAllChannels attempts to retrieve all open channels currently stored
|
||||
// within the database, including pending open, fully open and channels waiting
|
||||
// for a closing transaction to confirm.
|
||||
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) {
|
||||
var channels []*OpenChannel
|
||||
|
||||
// TODO(halseth): fetch all in one db tx.
|
||||
openChannels, err := d.FetchAllOpenChannels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channels = append(channels, openChannels...)
|
||||
|
||||
pendingChannels, err := d.FetchPendingChannels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channels = append(channels, pendingChannels...)
|
||||
|
||||
waitingClose, err := d.FetchWaitingCloseChannels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channels = append(channels, waitingClose...)
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// FetchAllOpenChannels will return all channels that have the funding
|
||||
// transaction confirmed, and is not waiting for a closing transaction to be
|
||||
// confirmed.
|
||||
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
|
||||
return fetchChannels(d, false, false)
|
||||
}
|
||||
|
||||
// FetchPendingChannels will return channels that have completed the process of
|
||||
// generating and broadcasting funding transactions, but whose funding
|
||||
// transactions have yet to be confirmed on the blockchain.
|
||||
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
|
||||
return fetchChannels(d, true, false)
|
||||
}
|
||||
|
||||
// FetchWaitingCloseChannels will return all channels that have been opened,
|
||||
// but are now waiting for a closing transaction to be confirmed.
|
||||
//
|
||||
// NOTE: This includes channels that are also pending to be opened.
|
||||
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
|
||||
waitingClose, err := fetchChannels(d, false, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pendingWaitingClose, err := fetchChannels(d, true, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(waitingClose, pendingWaitingClose...), nil
|
||||
}
|
||||
|
||||
// fetchChannels attempts to retrieve channels currently stored in the
|
||||
// database. The pending parameter determines whether only pending channels
|
||||
// will be returned, or only open channels will be returned. The waitingClose
|
||||
// parameter determines whether only channels waiting for a closing transaction
|
||||
// to be confirmed should be returned. If no active channels exist within the
|
||||
// network, then ErrNoActiveChannels is returned.
|
||||
func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) {
|
||||
var channels []*OpenChannel
|
||||
|
||||
err := d.View(func(tx *bbolt.Tx) error {
|
||||
// Get the bucket dedicated to storing the metadata for open
|
||||
// channels.
|
||||
openChanBucket := tx.Bucket(openChannelBucket)
|
||||
if openChanBucket == nil {
|
||||
return ErrNoActiveChannels
|
||||
}
|
||||
|
||||
// Next, fetch the bucket dedicated to storing metadata related
|
||||
// to all nodes. All keys within this bucket are the serialized
|
||||
// public keys of all our direct counterparties.
|
||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
||||
if nodeMetaBucket == nil {
|
||||
return fmt.Errorf("node bucket not created")
|
||||
}
|
||||
|
||||
// Finally for each node public key in the bucket, fetch all
|
||||
// the channels related to this particular node.
|
||||
return nodeMetaBucket.ForEach(func(k, v []byte) error {
|
||||
nodeChanBucket := openChanBucket.Bucket(k)
|
||||
if nodeChanBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
||||
// If there's a value, it's not a bucket so
|
||||
// ignore it.
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we've found a valid chainhash bucket,
|
||||
// then we'll retrieve that so we can extract
|
||||
// all the channels.
|
||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
||||
if chainBucket == nil {
|
||||
return fmt.Errorf("unable to read "+
|
||||
"bucket for chain=%x", chainHash[:])
|
||||
}
|
||||
|
||||
nodeChans, err := d.fetchNodeChannels(chainBucket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read "+
|
||||
"channel for chain_hash=%x, "+
|
||||
"node_key=%x: %v", chainHash[:], k, err)
|
||||
}
|
||||
for _, channel := range nodeChans {
|
||||
if channel.IsPending != pending {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the channel is in any other state
|
||||
// than Default, then it means it is
|
||||
// waiting to be closed.
|
||||
channelWaitingClose :=
|
||||
channel.ChanStatus() != ChanStatusDefault
|
||||
|
||||
// Only include it if we requested
|
||||
// channels with the same waitingClose
|
||||
// status.
|
||||
if channelWaitingClose != waitingClose {
|
||||
continue
|
||||
}
|
||||
|
||||
channels = append(channels, channel)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// FetchClosedChannels attempts to fetch all closed channels from the database.
|
||||
// The pendingOnly bool toggles if channels that aren't yet fully closed should
|
||||
// be returned in the response or not. When a channel was cooperatively closed,
|
||||
|
@ -641,371 +220,6 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
|
|||
return chanSummaries, nil
|
||||
}
|
||||
|
||||
// ErrClosedChannelNotFound signals that a closed channel could not be found in
|
||||
// the channeldb.
|
||||
var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary")
|
||||
|
||||
// FetchClosedChannel queries for a channel close summary using the channel
|
||||
// point of the channel in question.
|
||||
func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) {
|
||||
var chanSummary *ChannelCloseSummary
|
||||
if err := d.View(func(tx *bbolt.Tx) error {
|
||||
closeBucket := tx.Bucket(closedChannelBucket)
|
||||
if closeBucket == nil {
|
||||
return ErrClosedChannelNotFound
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
var err error
|
||||
if err = writeOutpoint(&b, chanID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
summaryBytes := closeBucket.Get(b.Bytes())
|
||||
if summaryBytes == nil {
|
||||
return ErrClosedChannelNotFound
|
||||
}
|
||||
|
||||
summaryReader := bytes.NewReader(summaryBytes)
|
||||
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
|
||||
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chanSummary, nil
|
||||
}
|
||||
|
||||
// FetchClosedChannelForID queries for a channel close summary using the
|
||||
// channel ID of the channel in question.
|
||||
func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
|
||||
*ChannelCloseSummary, error) {
|
||||
|
||||
var chanSummary *ChannelCloseSummary
|
||||
if err := d.View(func(tx *bbolt.Tx) error {
|
||||
closeBucket := tx.Bucket(closedChannelBucket)
|
||||
if closeBucket == nil {
|
||||
return ErrClosedChannelNotFound
|
||||
}
|
||||
|
||||
// The first 30 bytes of the channel ID and outpoint will be
|
||||
// equal.
|
||||
cursor := closeBucket.Cursor()
|
||||
op, c := cursor.Seek(cid[:30])
|
||||
|
||||
// We scan over all possible candidates for this channel ID.
|
||||
for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() {
|
||||
var outPoint wire.OutPoint
|
||||
err := readOutpoint(bytes.NewReader(op), &outPoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the found outpoint does not correspond to this
|
||||
// channel ID, we continue.
|
||||
if !cid.IsChanPoint(&outPoint) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Deserialize the close summary and return.
|
||||
r := bytes.NewReader(c)
|
||||
chanSummary, err = deserializeCloseChannelSummary(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return ErrClosedChannelNotFound
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chanSummary, nil
|
||||
}
|
||||
|
||||
// MarkChanFullyClosed marks a channel as fully closed within the database. A
|
||||
// channel should be marked as fully closed if the channel was initially
|
||||
// cooperatively closed and it's reached a single confirmation, or after all
|
||||
// the pending funds in a channel that has been forcibly closed have been
|
||||
// swept.
|
||||
func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
|
||||
return d.Update(func(tx *bbolt.Tx) error {
|
||||
var b bytes.Buffer
|
||||
if err := writeOutpoint(&b, chanPoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanID := b.Bytes()
|
||||
|
||||
closedChanBucket, err := tx.CreateBucketIfNotExists(
|
||||
closedChannelBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanSummaryBytes := closedChanBucket.Get(chanID)
|
||||
if chanSummaryBytes == nil {
|
||||
return fmt.Errorf("no closed channel for "+
|
||||
"chan_point=%v found", chanPoint)
|
||||
}
|
||||
|
||||
chanSummaryReader := bytes.NewReader(chanSummaryBytes)
|
||||
chanSummary, err := deserializeCloseChannelSummary(
|
||||
chanSummaryReader,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanSummary.IsPending = false
|
||||
|
||||
var newSummary bytes.Buffer
|
||||
err = serializeChannelCloseSummary(&newSummary, chanSummary)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = closedChanBucket.Put(chanID, newSummary.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now that the channel is closed, we'll check if we have any
|
||||
// other open channels with this peer. If we don't we'll
|
||||
// garbage collect it to ensure we don't establish persistent
|
||||
// connections to peers without open channels.
|
||||
return d.pruneLinkNode(tx, chanSummary.RemotePub)
|
||||
})
|
||||
}
|
||||
|
||||
// pruneLinkNode determines whether we should garbage collect a link node from
|
||||
// the database due to no longer having any open channels with it. If there are
|
||||
// any left, then this acts as a no-op.
|
||||
func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error {
|
||||
openChannels, err := d.fetchOpenChannels(tx, remotePub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch open channels for peer %x: "+
|
||||
"%v", remotePub.SerializeCompressed(), err)
|
||||
}
|
||||
|
||||
if len(openChannels) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Pruning link node %x with zero open channels from database",
|
||||
remotePub.SerializeCompressed())
|
||||
|
||||
return d.deleteLinkNode(tx, remotePub)
|
||||
}
|
||||
|
||||
// PruneLinkNodes attempts to prune all link nodes found within the databse with
|
||||
// whom we no longer have any open channels with.
|
||||
func (d *DB) PruneLinkNodes() error {
|
||||
return d.Update(func(tx *bbolt.Tx) error {
|
||||
linkNodes, err := d.fetchAllLinkNodes(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, linkNode := range linkNodes {
|
||||
err := d.pruneLinkNode(tx, linkNode.IdentityPub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ChannelShell is a shell of a channel that is meant to be used for channel
|
||||
// recovery purposes. It contains a minimal OpenChannel instance along with
|
||||
// addresses for that target node.
|
||||
type ChannelShell struct {
|
||||
// NodeAddrs the set of addresses that this node has known to be
|
||||
// reachable at in the past.
|
||||
NodeAddrs []net.Addr
|
||||
|
||||
// Chan is a shell of an OpenChannel, it contains only the items
|
||||
// required to restore the channel on disk.
|
||||
Chan *OpenChannel
|
||||
}
|
||||
|
||||
// RestoreChannelShells is a method that allows the caller to reconstruct the
|
||||
// state of an OpenChannel from the ChannelShell. We'll attempt to write the
|
||||
// new channel to disk, create a LinkNode instance with the passed node
|
||||
// addresses, and finally create an edge within the graph for the channel as
|
||||
// well. This method is idempotent, so repeated calls with the same set of
|
||||
// channel shells won't modify the database after the initial call.
|
||||
func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
|
||||
chanGraph := d.ChannelGraph()
|
||||
|
||||
// TODO(conner): find way to do this w/o accessing internal members?
|
||||
chanGraph.cacheMu.Lock()
|
||||
defer chanGraph.cacheMu.Unlock()
|
||||
|
||||
var chansRestored []uint64
|
||||
err := d.Update(func(tx *bbolt.Tx) error {
|
||||
for _, channelShell := range channelShells {
|
||||
channel := channelShell.Chan
|
||||
|
||||
// When we make a channel, we mark that the channel has
|
||||
// been restored, this will signal to other sub-systems
|
||||
// to not attempt to use the channel as if it was a
|
||||
// regular one.
|
||||
channel.chanStatus |= ChanStatusRestored
|
||||
|
||||
// First, we'll attempt to create a new open channel
|
||||
// and link node for this channel. If the channel
|
||||
// already exists, then in order to ensure this method
|
||||
// is idempotent, we'll continue to the next step.
|
||||
channel.Db = d
|
||||
err := syncNewChannel(
|
||||
tx, channel, channelShell.NodeAddrs,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Next, we'll create an active edge in the graph
|
||||
// database for this channel in order to restore our
|
||||
// partial view of the network.
|
||||
//
|
||||
// TODO(roasbeef): if we restore *after* the channel
|
||||
// has been closed on chain, then need to inform the
|
||||
// router that it should try and prune these values as
|
||||
// we can detect them
|
||||
edgeInfo := ChannelEdgeInfo{
|
||||
ChannelID: channel.ShortChannelID.ToUint64(),
|
||||
ChainHash: channel.ChainHash,
|
||||
ChannelPoint: channel.FundingOutpoint,
|
||||
Capacity: channel.Capacity,
|
||||
}
|
||||
|
||||
nodes := tx.Bucket(nodeBucket)
|
||||
if nodes == nil {
|
||||
return ErrGraphNotFound
|
||||
}
|
||||
selfNode, err := chanGraph.sourceNode(nodes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Depending on which pub key is smaller, we'll assign
|
||||
// our roles as "node1" and "node2".
|
||||
chanPeer := channel.IdentityPub.SerializeCompressed()
|
||||
selfIsSmaller := bytes.Compare(
|
||||
selfNode.PubKeyBytes[:], chanPeer,
|
||||
) == -1
|
||||
if selfIsSmaller {
|
||||
copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:])
|
||||
copy(edgeInfo.NodeKey2Bytes[:], chanPeer)
|
||||
} else {
|
||||
copy(edgeInfo.NodeKey1Bytes[:], chanPeer)
|
||||
copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:])
|
||||
}
|
||||
|
||||
// With the edge info shell constructed, we'll now add
|
||||
// it to the graph.
|
||||
err = chanGraph.addChannelEdge(tx, &edgeInfo)
|
||||
if err != nil && err != ErrEdgeAlreadyExist {
|
||||
return err
|
||||
}
|
||||
|
||||
// Similarly, we'll construct a channel edge shell and
|
||||
// add that itself to the graph.
|
||||
chanEdge := ChannelEdgePolicy{
|
||||
ChannelID: edgeInfo.ChannelID,
|
||||
LastUpdate: time.Now(),
|
||||
}
|
||||
|
||||
// If their pubkey is larger, then we'll flip the
|
||||
// direction bit to indicate that us, the "second" node
|
||||
// is updating their policy.
|
||||
if !selfIsSmaller {
|
||||
chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection
|
||||
}
|
||||
|
||||
_, err = updateEdgePolicy(tx, &chanEdge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chansRestored = append(chansRestored, edgeInfo.ChannelID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, chanid := range chansRestored {
|
||||
chanGraph.rejectCache.remove(chanid)
|
||||
chanGraph.chanCache.remove(chanid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddrsForNode consults the graph and channel database for all addresses known
|
||||
// to the passed node public key.
|
||||
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
|
||||
var (
|
||||
linkNode *LinkNode
|
||||
graphNode LightningNode
|
||||
)
|
||||
|
||||
dbErr := d.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
|
||||
linkNode, err = fetchLinkNode(tx, nodePub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We'll also query the graph for this peer to see if they have
|
||||
// any addresses that we don't currently have stored within the
|
||||
// link node database.
|
||||
nodes := tx.Bucket(nodeBucket)
|
||||
if nodes == nil {
|
||||
return ErrGraphNotFound
|
||||
}
|
||||
compressedPubKey := nodePub.SerializeCompressed()
|
||||
graphNode, err = fetchLightningNode(nodes, compressedPubKey)
|
||||
if err != nil && err != ErrGraphNodeNotFound {
|
||||
// If the node isn't found, then that's OK, as we still
|
||||
// have the link node data.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
// Now that we have both sources of addrs for this node, we'll use a
|
||||
// map to de-duplicate any addresses between the two sources, and
|
||||
// produce a final list of the combined addrs.
|
||||
addrs := make(map[string]net.Addr)
|
||||
for _, addr := range linkNode.Addresses {
|
||||
addrs[addr.String()] = addr
|
||||
}
|
||||
for _, addr := range graphNode.Addresses {
|
||||
addrs[addr.String()] = addr
|
||||
}
|
||||
dedupedAddrs := make([]net.Addr, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
dedupedAddrs = append(dedupedAddrs, addr)
|
||||
}
|
||||
|
||||
return dedupedAddrs, nil
|
||||
}
|
||||
|
||||
// syncVersions function is used for safe db version synchronization. It
|
||||
// applies migration functions to the current database and recovers the
|
||||
// previous state of db if at least one error/panic appeared during migration.
|
||||
|
|
|
@ -1,471 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/shachain"
|
||||
)
|
||||
|
||||
func TestOpenWithCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDirName)
|
||||
|
||||
// Next, open thereby creating channeldb for the first time.
|
||||
dbPath := filepath.Join(tempDirName, "cdb")
|
||||
cdb, err := Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channeldb: %v", err)
|
||||
}
|
||||
if err := cdb.Close(); err != nil {
|
||||
t.Fatalf("unable to close channeldb: %v", err)
|
||||
}
|
||||
|
||||
// The path should have been successfully created.
|
||||
if !fileExists(dbPath) {
|
||||
t.Fatalf("channeldb failed to create data directory")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWipe tests that the database wipe operation completes successfully
|
||||
// and that the buckets are deleted. It also checks that attempts to fetch
|
||||
// information while the buckets are not set return the correct errors.
|
||||
func TestWipe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, create a temporary directory to be used for the duration of
|
||||
// this test.
|
||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDirName)
|
||||
|
||||
// Next, open thereby creating channeldb for the first time.
|
||||
dbPath := filepath.Join(tempDirName, "cdb")
|
||||
cdb, err := Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channeldb: %v", err)
|
||||
}
|
||||
defer cdb.Close()
|
||||
|
||||
if err := cdb.Wipe(); err != nil {
|
||||
t.Fatalf("unable to wipe channeldb: %v", err)
|
||||
}
|
||||
// Check correct errors are returned
|
||||
_, err = cdb.FetchAllOpenChannels()
|
||||
if err != ErrNoActiveChannels {
|
||||
t.Fatalf("fetching open channels: expected '%v' instead got '%v'",
|
||||
ErrNoActiveChannels, err)
|
||||
}
|
||||
_, err = cdb.FetchClosedChannels(false)
|
||||
if err != ErrNoClosedChannels {
|
||||
t.Fatalf("fetching closed channels: expected '%v' instead got '%v'",
|
||||
ErrNoClosedChannels, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchClosedChannelForID tests that we are able to properly retrieve a
|
||||
// ChannelCloseSummary from the DB given a ChannelID.
|
||||
func TestFetchClosedChannelForID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChans = 101
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// Create the test channel state, that we will mutate the index of the
|
||||
// funding point.
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
// Now run through the number of channels, and modify the outpoint index
|
||||
// to create new channel IDs.
|
||||
for i := uint32(0); i < numChans; i++ {
|
||||
// Save the open channel to disk.
|
||||
state.FundingOutpoint.Index = i
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18556,
|
||||
}
|
||||
if err := state.SyncPending(addr, 101); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel "+
|
||||
"state: %v", err)
|
||||
}
|
||||
|
||||
// Close the channel. To make sure we retrieve the correct
|
||||
// summary later, we make them differ in the SettledBalance.
|
||||
closeSummary := &ChannelCloseSummary{
|
||||
ChanPoint: state.FundingOutpoint,
|
||||
RemotePub: state.IdentityPub,
|
||||
SettledBalance: btcutil.Amount(500 + i),
|
||||
}
|
||||
if err := state.CloseChannel(closeSummary); err != nil {
|
||||
t.Fatalf("unable to close channel: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now run though them all again and make sure we are able to retrieve
|
||||
// summaries from the DB.
|
||||
for i := uint32(0); i < numChans; i++ {
|
||||
state.FundingOutpoint.Index = i
|
||||
|
||||
// We calculate the ChannelID and use it to fetch the summary.
|
||||
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
|
||||
fetchedSummary, err := cdb.FetchClosedChannelForID(cid)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch close summary: %v", err)
|
||||
}
|
||||
|
||||
// Make sure we retrieved the correct one by checking the
|
||||
// SettledBalance.
|
||||
if fetchedSummary.SettledBalance != btcutil.Amount(500+i) {
|
||||
t.Fatalf("summaries don't match: expected %v got %v",
|
||||
btcutil.Amount(500+i),
|
||||
fetchedSummary.SettledBalance)
|
||||
}
|
||||
}
|
||||
|
||||
// As a final test we make sure that we get ErrClosedChannelNotFound
|
||||
// for a ChannelID we didn't add to the DB.
|
||||
state.FundingOutpoint.Index++
|
||||
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
|
||||
_, err = cdb.FetchClosedChannelForID(cid)
|
||||
if err != ErrClosedChannelNotFound {
|
||||
t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddrsForNode tests the we're able to properly obtain all the addresses
|
||||
// for a target node.
|
||||
func TestAddrsForNode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
graph := cdb.ChannelGraph()
|
||||
|
||||
// We'll make a test vertex to insert into the database, as the source
|
||||
// node, but this node will only have half the number of addresses it
|
||||
// usually does.
|
||||
testNode, err := createTestVertex(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test node: %v", err)
|
||||
}
|
||||
testNode.Addresses = []net.Addr{testAddr}
|
||||
if err := graph.SetSourceNode(testNode); err != nil {
|
||||
t.Fatalf("unable to set source node: %v", err)
|
||||
}
|
||||
|
||||
// Next, we'll make a link node with the same pubkey, but with an
|
||||
// additional address.
|
||||
nodePub, err := testNode.PubKey()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to recv node pub: %v", err)
|
||||
}
|
||||
linkNode := cdb.NewLinkNode(
|
||||
wire.MainNet, nodePub, anotherAddr,
|
||||
)
|
||||
if err := linkNode.Sync(); err != nil {
|
||||
t.Fatalf("unable to sync link node: %v", err)
|
||||
}
|
||||
|
||||
// Now that we've created a link node, as well as a vertex for the
|
||||
// node, we'll query for all its addresses.
|
||||
nodeAddrs, err := cdb.AddrsForNode(nodePub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to obtain node addrs: %v", err)
|
||||
}
|
||||
|
||||
expectedAddrs := make(map[string]struct{})
|
||||
expectedAddrs[testAddr.String()] = struct{}{}
|
||||
expectedAddrs[anotherAddr.String()] = struct{}{}
|
||||
|
||||
// Finally, ensure that all the expected addresses are found.
|
||||
if len(nodeAddrs) != len(expectedAddrs) {
|
||||
t.Fatalf("expected %v addrs, got %v",
|
||||
len(expectedAddrs), len(nodeAddrs))
|
||||
}
|
||||
for _, addr := range nodeAddrs {
|
||||
if _, ok := expectedAddrs[addr.String()]; !ok {
|
||||
t.Fatalf("unexpected addr: %v", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchChannel tests that we're able to fetch an arbitrary channel from
|
||||
// disk.
|
||||
func TestFetchChannel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// Create the test channel state that we'll sync to the database
|
||||
// shortly.
|
||||
channelState, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
// Mark the channel as pending, then immediately mark it as open to it
|
||||
// can be fully visible.
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
if err := channelState.SyncPending(addr, 9); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to mark channel open: %v", err)
|
||||
}
|
||||
|
||||
// Next, attempt to fetch the channel by its chan point.
|
||||
dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch channel: %v", err)
|
||||
}
|
||||
|
||||
// The decoded channel state should be identical to what we stored
|
||||
// above.
|
||||
if !reflect.DeepEqual(channelState, dbChannel) {
|
||||
t.Fatalf("channel state doesn't match:: %v vs %v",
|
||||
spew.Sdump(channelState), spew.Sdump(dbChannel))
|
||||
}
|
||||
|
||||
// If we attempt to query for a non-exist ante channel, then we should
|
||||
// get an error.
|
||||
channelState2, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
channelState2.FundingOutpoint.Index ^= 1
|
||||
|
||||
_, err = cdb.FetchChannel(channelState2.FundingOutpoint)
|
||||
if err == nil {
|
||||
t.Fatalf("expected query to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomChannelShell() (*ChannelShell, error) {
|
||||
var testPriv [32]byte
|
||||
if _, err := rand.Read(testPriv[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
|
||||
|
||||
var chanPoint wire.OutPoint
|
||||
if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pub.Curve = nil
|
||||
|
||||
chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
|
||||
|
||||
chanStatus := ChanStatusDefault | ChanStatusRestored
|
||||
|
||||
var shaChainPriv [32]byte
|
||||
if _, err := rand.Read(testPriv[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
revRoot, err := chainhash.NewHash(shaChainPriv[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shaChainProducer := shachain.NewRevocationProducer(*revRoot)
|
||||
|
||||
return &ChannelShell{
|
||||
NodeAddrs: []net.Addr{&net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}},
|
||||
Chan: &OpenChannel{
|
||||
chanStatus: chanStatus,
|
||||
ChainHash: rev,
|
||||
FundingOutpoint: chanPoint,
|
||||
ShortChannelID: lnwire.NewShortChanIDFromInt(
|
||||
uint64(rand.Int63()),
|
||||
),
|
||||
IdentityPub: pub,
|
||||
LocalChanCfg: ChannelConfig{
|
||||
ChannelConstraints: ChannelConstraints{
|
||||
CsvDelay: uint16(rand.Int63()),
|
||||
},
|
||||
PaymentBasePoint: keychain.KeyDescriptor{
|
||||
KeyLocator: keychain.KeyLocator{
|
||||
Family: keychain.KeyFamily(rand.Int63()),
|
||||
Index: uint32(rand.Int63()),
|
||||
},
|
||||
},
|
||||
},
|
||||
RemoteCurrentRevocation: pub,
|
||||
IsPending: false,
|
||||
RevocationStore: shachain.NewRevocationStore(),
|
||||
RevocationProducer: shaChainProducer,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestRestoreChannelShells tests that we're able to insert a partially channel
|
||||
// populated to disk. This is useful for channel recovery purposes. We should
|
||||
// find the new channel shell on disk, and also the db should be populated with
|
||||
// an edge for that channel.
|
||||
func TestRestoreChannelShells(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First, we'll make our channel shell, it will only have the minimal
|
||||
// amount of information required for us to initiate the data loss
|
||||
// protection feature.
|
||||
channelShell, err := genRandomChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen channel shell: %v", err)
|
||||
}
|
||||
|
||||
graph := cdb.ChannelGraph()
|
||||
|
||||
// Before we can restore the channel, we'll need to make a source node
|
||||
// in the graph as the channel edge we create will need to have a
|
||||
// origin.
|
||||
testNode, err := createTestVertex(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test node: %v", err)
|
||||
}
|
||||
if err := graph.SetSourceNode(testNode); err != nil {
|
||||
t.Fatalf("unable to set source node: %v", err)
|
||||
}
|
||||
|
||||
// With the channel shell constructed, we'll now insert it into the
|
||||
// database with the restoration method.
|
||||
if err := cdb.RestoreChannelShells(channelShell); err != nil {
|
||||
t.Fatalf("unable to restore channel shell: %v", err)
|
||||
}
|
||||
|
||||
// Now that the channel has been inserted, we'll attempt to query for
|
||||
// it to ensure we can properly locate it via various means.
|
||||
//
|
||||
// First, we'll attempt to query for all channels that we have with the
|
||||
// node public key that was restored.
|
||||
nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable find channel: %v", err)
|
||||
}
|
||||
|
||||
// We should now find a single channel from the database.
|
||||
if len(nodeChans) != 1 {
|
||||
t.Fatalf("unable to find restored channel by node "+
|
||||
"pubkey: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that it isn't possible to modify the commitment state machine
|
||||
// of this restored channel.
|
||||
channel := nodeChans[0]
|
||||
err = channel.UpdateCommitment(nil)
|
||||
if err != ErrNoRestoredChannelMutation {
|
||||
t.Fatalf("able to mutate restored channel")
|
||||
}
|
||||
err = channel.AppendRemoteCommitChain(nil)
|
||||
if err != ErrNoRestoredChannelMutation {
|
||||
t.Fatalf("able to mutate restored channel")
|
||||
}
|
||||
err = channel.AdvanceCommitChainTail(nil)
|
||||
if err != ErrNoRestoredChannelMutation {
|
||||
t.Fatalf("able to mutate restored channel")
|
||||
}
|
||||
|
||||
// That single channel should have the proper channel point, and also
|
||||
// the expected set of flags to indicate that it was a restored
|
||||
// channel.
|
||||
if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint {
|
||||
t.Fatalf("wrong funding outpoint: expected %v, got %v",
|
||||
nodeChans[0].FundingOutpoint,
|
||||
channelShell.Chan.FundingOutpoint)
|
||||
}
|
||||
if !nodeChans[0].HasChanStatus(ChanStatusRestored) {
|
||||
t.Fatalf("node has wrong status flags: %v",
|
||||
nodeChans[0].chanStatus)
|
||||
}
|
||||
|
||||
// We should also be able to find the channel if we query for it
|
||||
// directly.
|
||||
_, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch channel: %v", err)
|
||||
}
|
||||
|
||||
// We should also be able to find the link node that was inserted by
|
||||
// its public key.
|
||||
linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch link node: %v", err)
|
||||
}
|
||||
|
||||
// The node should have the same address, as specified in the channel
|
||||
// shell.
|
||||
if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) {
|
||||
t.Fatalf("addr mismach: expected %v, got %v",
|
||||
linkNode.Addresses, channelShell.NodeAddrs)
|
||||
}
|
||||
|
||||
// Finally, we'll ensure that the edge for the channel was properly
|
||||
// inserted.
|
||||
chanInfos, err := graph.FetchChanInfos(
|
||||
[]uint64{channelShell.Chan.ShortChannelID.ToUint64()},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to find edges: %v", err)
|
||||
}
|
||||
|
||||
if len(chanInfos) != 1 {
|
||||
t.Fatalf("wrong amount of chan infos: expected %v got %v",
|
||||
len(chanInfos), 1)
|
||||
}
|
||||
|
||||
// We should only find a single edge.
|
||||
if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil {
|
||||
t.Fatalf("only a single edge should be inserted: %v", err)
|
||||
}
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
package migration_01_to_11
|
|
@ -1,55 +1,23 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoChanDBExists is returned when a channel bucket hasn't been
|
||||
// created.
|
||||
ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created")
|
||||
|
||||
// ErrDBReversion is returned when detecting an attempt to revert to a
|
||||
// prior database version.
|
||||
ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version")
|
||||
|
||||
// ErrLinkNodesNotFound is returned when node info bucket hasn't been
|
||||
// created.
|
||||
ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist")
|
||||
|
||||
// ErrNoActiveChannels is returned when there is no active (open)
|
||||
// channels within the database.
|
||||
ErrNoActiveChannels = fmt.Errorf("no active channels exist")
|
||||
|
||||
// ErrNoPastDeltas is returned when the channel delta bucket hasn't been
|
||||
// created.
|
||||
ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas")
|
||||
|
||||
// ErrInvoiceNotFound is returned when a targeted invoice can't be
|
||||
// found.
|
||||
ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice")
|
||||
|
||||
// ErrNoInvoicesCreated is returned when we don't have invoices in
|
||||
// our database to return.
|
||||
ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices")
|
||||
|
||||
// ErrDuplicateInvoice is returned when an invoice with the target
|
||||
// payment hash already exists.
|
||||
ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists")
|
||||
|
||||
// ErrNoPaymentsCreated is returned when bucket of payments hasn't been
|
||||
// created.
|
||||
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")
|
||||
|
||||
// ErrNodeNotFound is returned when node bucket exists, but node with
|
||||
// specific identity can't be found.
|
||||
ErrNodeNotFound = fmt.Errorf("link node with target identity not found")
|
||||
|
||||
// ErrChannelNotFound is returned when we attempt to locate a channel
|
||||
// for a specific chain, but it is not found.
|
||||
ErrChannelNotFound = fmt.Errorf("channel not found")
|
||||
|
||||
// ErrMetaNotFound is returned when meta bucket hasn't been
|
||||
// created.
|
||||
ErrMetaNotFound = fmt.Errorf("unable to locate meta information")
|
||||
|
@ -58,22 +26,11 @@ var (
|
|||
// graph doesn't exist.
|
||||
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
|
||||
|
||||
// ErrGraphNeverPruned is returned when graph was never pruned.
|
||||
ErrGraphNeverPruned = fmt.Errorf("graph never pruned")
|
||||
|
||||
// ErrSourceNodeNotSet is returned if the source node of the graph
|
||||
// hasn't been added The source node is the center node within a
|
||||
// star-graph.
|
||||
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
|
||||
|
||||
// ErrGraphNodesNotFound is returned in case none of the nodes has
|
||||
// been added in graph node bucket.
|
||||
ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist")
|
||||
|
||||
// ErrGraphNoEdgesFound is returned in case of none of the channel/edges
|
||||
// has been added in graph edge bucket.
|
||||
ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist")
|
||||
|
||||
// ErrGraphNodeNotFound is returned when we're unable to find the target
|
||||
// node.
|
||||
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
|
||||
|
@ -82,17 +39,6 @@ var (
|
|||
// can't be found.
|
||||
ErrEdgeNotFound = fmt.Errorf("edge not found")
|
||||
|
||||
// ErrZombieEdge is an error returned when we attempt to look up an edge
|
||||
// but it is marked as a zombie within the zombie index.
|
||||
ErrZombieEdge = errors.New("edge marked as zombie")
|
||||
|
||||
// ErrEdgeAlreadyExist is returned when edge with specific
|
||||
// channel id can't be added because it already exist.
|
||||
ErrEdgeAlreadyExist = fmt.Errorf("edge already exist")
|
||||
|
||||
// ErrNodeAliasNotFound is returned when alias for node can't be found.
|
||||
ErrNodeAliasNotFound = fmt.Errorf("alias for node not found")
|
||||
|
||||
// ErrUnknownAddressType is returned when a node's addressType is not
|
||||
// an expected value.
|
||||
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
|
||||
|
@ -101,20 +47,11 @@ var (
|
|||
// channels it has closed, but it hasn't yet closed any channels.
|
||||
ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet")
|
||||
|
||||
// ErrNoForwardingEvents is returned in the case that a query fails due
|
||||
// to the log not having any recorded events.
|
||||
ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events")
|
||||
|
||||
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
|
||||
// policy field is not found in the db even though its message flags
|
||||
// indicate it should be.
|
||||
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
|
||||
"present")
|
||||
|
||||
// ErrChanAlreadyExists is return when the caller attempts to create a
|
||||
// channel with a channel point that is already present in the
|
||||
// database.
|
||||
ErrChanAlreadyExists = fmt.Errorf("channel already exists")
|
||||
)
|
||||
|
||||
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
package migration_01_to_11
|
|
@ -1,274 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// forwardingLogBucket is the bucket that we'll use to store the
|
||||
// forwarding log. The forwarding log contains a time series database
|
||||
// of the forwarding history of a lightning daemon. Each key within the
|
||||
// bucket is a timestamp (in nano seconds since the unix epoch), and
|
||||
// the value a slice of a forwarding event for that timestamp.
|
||||
forwardingLogBucket = []byte("circuit-fwd-log")
|
||||
)
|
||||
|
||||
const (
|
||||
// forwardingEventSize is the size of a forwarding event. The breakdown
|
||||
// is as follows:
|
||||
//
|
||||
// * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in
|
||||
// || 8 byte value out
|
||||
//
|
||||
// From the value in and value out, callers can easily compute the
|
||||
// total fee extract from a forwarding event.
|
||||
forwardingEventSize = 32
|
||||
|
||||
// MaxResponseEvents is the max number of forwarding events that will
|
||||
// be returned by a single query response. This size was selected to
|
||||
// safely remain under gRPC's 4MiB message size response limit. As each
|
||||
// full forwarding event (including the timestamp) is 40 bytes, we can
|
||||
// safely return 50k entries in a single response.
|
||||
MaxResponseEvents = 50000
|
||||
)
|
||||
|
||||
// ForwardingLog returns an instance of the ForwardingLog object backed by the
|
||||
// target database instance.
|
||||
func (d *DB) ForwardingLog() *ForwardingLog {
|
||||
return &ForwardingLog{
|
||||
db: d,
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardingLog is a time series database that logs the fulfilment of payment
|
||||
// circuits by a lightning network daemon. The log contains a series of
|
||||
// forwarding events which map a timestamp to a forwarding event. A forwarding
|
||||
// event describes which channels were used to create+settle a circuit, and the
|
||||
// amount involved. Subtracting the outgoing amount from the incoming amount
|
||||
// reveals the fee charged for the forwarding service.
|
||||
type ForwardingLog struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// ForwardingEvent is an event in the forwarding log's time series. Each
|
||||
// forwarding event logs the creation and tear-down of a payment circuit. A
|
||||
// circuit is created once an incoming HTLC has been fully forwarded, and
|
||||
// destroyed once the payment has been settled.
|
||||
type ForwardingEvent struct {
|
||||
// Timestamp is the settlement time of this payment circuit.
|
||||
Timestamp time.Time
|
||||
|
||||
// IncomingChanID is the incoming channel ID of the payment circuit.
|
||||
IncomingChanID lnwire.ShortChannelID
|
||||
|
||||
// OutgoingChanID is the outgoing channel ID of the payment circuit.
|
||||
OutgoingChanID lnwire.ShortChannelID
|
||||
|
||||
// AmtIn is the amount of the incoming HTLC. Subtracting this from the
|
||||
// outgoing amount gives the total fees of this payment circuit.
|
||||
AmtIn lnwire.MilliSatoshi
|
||||
|
||||
// AmtOut is the amount of the outgoing HTLC. Subtracting the incoming
|
||||
// amount from this gives the total fees for this payment circuit.
|
||||
AmtOut lnwire.MilliSatoshi
|
||||
}
|
||||
|
||||
// encodeForwardingEvent writes out the target forwarding event to the passed
|
||||
// io.Writer, using the expected DB format. Note that the timestamp isn't
|
||||
// serialized as this will be the key value within the bucket.
|
||||
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
|
||||
return WriteElements(
|
||||
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
|
||||
)
|
||||
}
|
||||
|
||||
// decodeForwardingEvent attempts to decode the raw bytes of a serialized
|
||||
// forwarding event into the target ForwardingEvent. Note that the timestamp
|
||||
// won't be decoded, as the caller is expected to set this due to the bucket
|
||||
// structure of the forwarding log.
|
||||
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
|
||||
return ReadElements(
|
||||
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
|
||||
)
|
||||
}
|
||||
|
||||
// AddForwardingEvents adds a series of forwarding events to the database.
|
||||
// Before inserting, the set of events will be sorted according to their
|
||||
// timestamp. This ensures that all writes to disk are sequential.
|
||||
func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error {
|
||||
// Before we create the database transaction, we'll ensure that the set
|
||||
// of forwarding events are properly sorted according to their
|
||||
// timestamp.
|
||||
sort.Slice(events, func(i, j int) bool {
|
||||
return events[i].Timestamp.Before(events[j].Timestamp)
|
||||
})
|
||||
|
||||
var timestamp [8]byte
|
||||
|
||||
return f.db.Batch(func(tx *bbolt.Tx) error {
|
||||
// First, we'll fetch the bucket that stores our time series
|
||||
// log.
|
||||
logBucket, err := tx.CreateBucketIfNotExists(
|
||||
forwardingLogBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// With the bucket obtained, we can now begin to write out the
|
||||
// series of events.
|
||||
for _, event := range events {
|
||||
var eventBytes [forwardingEventSize]byte
|
||||
eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize])
|
||||
|
||||
// First, we'll serialize this timestamp into our
|
||||
// timestamp buffer.
|
||||
byteOrder.PutUint64(
|
||||
timestamp[:], uint64(event.Timestamp.UnixNano()),
|
||||
)
|
||||
|
||||
// With the key encoded, we'll then encode the event
|
||||
// into our buffer, then write it out to disk.
|
||||
err := encodeForwardingEvent(eventBuf, &event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = logBucket.Put(timestamp[:], eventBuf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ForwardingEventQuery represents a query to the forwarding log payment
|
||||
// circuit time series database. The query allows a caller to retrieve all
|
||||
// records for a particular time slice, offset in that time slice, limiting the
|
||||
// total number of responses returned.
|
||||
type ForwardingEventQuery struct {
|
||||
// StartTime is the start time of the time slice.
|
||||
StartTime time.Time
|
||||
|
||||
// EndTime is the end time of the time slice.
|
||||
EndTime time.Time
|
||||
|
||||
// IndexOffset is the offset within the time slice to start at. This
|
||||
// can be used to start the response at a particular record.
|
||||
IndexOffset uint32
|
||||
|
||||
// NumMaxEvents is the max number of events to return.
|
||||
NumMaxEvents uint32
|
||||
}
|
||||
|
||||
// ForwardingLogTimeSlice is the response to a forwarding query. It includes
|
||||
// the original query, the set events that match the query, and an integer
|
||||
// which represents the offset index of the last item in the set of retuned
|
||||
// events. This integer allows callers to resume their query using this offset
|
||||
// in the event that the query's response exceeds the max number of returnable
|
||||
// events.
|
||||
type ForwardingLogTimeSlice struct {
|
||||
ForwardingEventQuery
|
||||
|
||||
// ForwardingEvents is the set of events in our time series that answer
|
||||
// the query embedded above.
|
||||
ForwardingEvents []ForwardingEvent
|
||||
|
||||
// LastIndexOffset is the index of the last element in the set of
|
||||
// returned ForwardingEvents above. Callers can use this to resume
|
||||
// their query in the event that the time slice has too many events to
|
||||
// fit into a single response.
|
||||
LastIndexOffset uint32
|
||||
}
|
||||
|
||||
// Query allows a caller to query the forwarding event time series for a
|
||||
// particular time slice. The caller can control the precise time as well as
|
||||
// the number of events to be returned.
|
||||
//
|
||||
// TODO(roasbeef): rename?
|
||||
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
|
||||
resp := ForwardingLogTimeSlice{
|
||||
ForwardingEventQuery: q,
|
||||
}
|
||||
|
||||
// If the user provided an index offset, then we'll not know how many
|
||||
// records we need to skip. We'll also keep track of the record offset
|
||||
// as that's part of the final return value.
|
||||
recordsToSkip := q.IndexOffset
|
||||
recordOffset := q.IndexOffset
|
||||
|
||||
err := f.db.View(func(tx *bbolt.Tx) error {
|
||||
// If the bucket wasn't found, then there aren't any events to
|
||||
// be returned.
|
||||
logBucket := tx.Bucket(forwardingLogBucket)
|
||||
if logBucket == nil {
|
||||
return ErrNoForwardingEvents
|
||||
}
|
||||
|
||||
// We'll be using a cursor to seek into the database, so we'll
|
||||
// populate byte slices that represent the start of the key
|
||||
// space we're interested in, and the end.
|
||||
var startTime, endTime [8]byte
|
||||
byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano()))
|
||||
byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano()))
|
||||
|
||||
// If we know that a set of log events exists, then we'll begin
|
||||
// our seek through the log in order to satisfy the query.
|
||||
// We'll continue until either we reach the end of the range,
|
||||
// or reach our max number of events.
|
||||
logCursor := logBucket.Cursor()
|
||||
timestamp, events := logCursor.Seek(startTime[:])
|
||||
for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() {
|
||||
// If our current return payload exceeds the max number
|
||||
// of events, then we'll exit now.
|
||||
if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we're not yet past the user defined offset, then
|
||||
// we'll continue to seek forward.
|
||||
if recordsToSkip > 0 {
|
||||
recordsToSkip--
|
||||
continue
|
||||
}
|
||||
|
||||
currentTime := time.Unix(
|
||||
0, int64(byteOrder.Uint64(timestamp)),
|
||||
)
|
||||
|
||||
// At this point, we've skipped enough records to start
|
||||
// to collate our query. For each record, we'll
|
||||
// increment the final record offset so the querier can
|
||||
// utilize pagination to seek further.
|
||||
readBuf := bytes.NewReader(events)
|
||||
for readBuf.Len() != 0 {
|
||||
var event ForwardingEvent
|
||||
err := decodeForwardingEvent(readBuf, &event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event.Timestamp = currentTime
|
||||
resp.ForwardingEvents = append(resp.ForwardingEvents, event)
|
||||
|
||||
recordOffset++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil && err != ErrNoForwardingEvents {
|
||||
return ForwardingLogTimeSlice{}, err
|
||||
}
|
||||
|
||||
resp.LastIndexOffset = recordOffset
|
||||
|
||||
return resp, nil
|
||||
}
|
|
@ -1,265 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestForwardingLogBasicStorageAndQuery tests that we're able to store and
|
||||
// then query for items that have previously been added to the event log.
|
||||
func TestForwardingLogBasicStorageAndQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll set up a test database, and use that to instantiate the
|
||||
// forwarding event log that we'll be using for the duration of the
|
||||
// test.
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
log := ForwardingLog{
|
||||
db: db,
|
||||
}
|
||||
|
||||
initialTime := time.Unix(1234, 0)
|
||||
timestamp := time.Unix(1234, 0)
|
||||
|
||||
// We'll create 100 random events, which each event being spaced 10
|
||||
// minutes after the prior event.
|
||||
numEvents := 100
|
||||
events := make([]ForwardingEvent, numEvents)
|
||||
for i := 0; i < numEvents; i++ {
|
||||
events[i] = ForwardingEvent{
|
||||
Timestamp: timestamp,
|
||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
||||
}
|
||||
|
||||
timestamp = timestamp.Add(time.Minute * 10)
|
||||
}
|
||||
|
||||
// Now that all of our set of events constructed, we'll add them to the
|
||||
// database in a batch manner.
|
||||
if err := log.AddForwardingEvents(events); err != nil {
|
||||
t.Fatalf("unable to add events: %v", err)
|
||||
}
|
||||
|
||||
// With our events added we'll now construct a basic query to retrieve
|
||||
// all of the events.
|
||||
eventQuery := ForwardingEventQuery{
|
||||
StartTime: initialTime,
|
||||
EndTime: timestamp,
|
||||
IndexOffset: 0,
|
||||
NumMaxEvents: 1000,
|
||||
}
|
||||
timeSlice, err := log.Query(eventQuery)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query for events: %v", err)
|
||||
}
|
||||
|
||||
// The set of returned events should match identically, as they should
|
||||
// be returned in sorted order.
|
||||
if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) {
|
||||
t.Fatalf("event mismatch: expected %v vs %v",
|
||||
spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The offset index of the final entry should be numEvents, so the
|
||||
// number of total events we've written.
|
||||
if timeSlice.LastIndexOffset != uint32(numEvents) {
|
||||
t.Fatalf("wrong final offset: expected %v, got %v",
|
||||
timeSlice.LastIndexOffset, numEvents)
|
||||
}
|
||||
}
|
||||
|
||||
// TestForwardingLogQueryOptions tests that the query offset works properly. So
|
||||
// if we add a series of events, then we should be able to seek within the
|
||||
// timeslice accordingly. This exercises the index offset and num max event
|
||||
// field in the query, and also the last index offset field int he response.
|
||||
func TestForwardingLogQueryOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll set up a test database, and use that to instantiate the
|
||||
// forwarding event log that we'll be using for the duration of the
|
||||
// test.
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
log := ForwardingLog{
|
||||
db: db,
|
||||
}
|
||||
|
||||
initialTime := time.Unix(1234, 0)
|
||||
endTime := time.Unix(1234, 0)
|
||||
|
||||
// We'll create 20 random events, which each event being spaced 10
|
||||
// minutes after the prior event.
|
||||
numEvents := 20
|
||||
events := make([]ForwardingEvent, numEvents)
|
||||
for i := 0; i < numEvents; i++ {
|
||||
events[i] = ForwardingEvent{
|
||||
Timestamp: endTime,
|
||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
||||
}
|
||||
|
||||
endTime = endTime.Add(time.Minute * 10)
|
||||
}
|
||||
|
||||
// Now that all of our set of events constructed, we'll add them to the
|
||||
// database in a batch manner.
|
||||
if err := log.AddForwardingEvents(events); err != nil {
|
||||
t.Fatalf("unable to add events: %v", err)
|
||||
}
|
||||
|
||||
// With all of our events added, we should be able to query for the
|
||||
// first 10 events using the max event query field.
|
||||
eventQuery := ForwardingEventQuery{
|
||||
StartTime: initialTime,
|
||||
EndTime: endTime,
|
||||
IndexOffset: 0,
|
||||
NumMaxEvents: 10,
|
||||
}
|
||||
timeSlice, err := log.Query(eventQuery)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query for events: %v", err)
|
||||
}
|
||||
|
||||
// We should get exactly 10 events back.
|
||||
if len(timeSlice.ForwardingEvents) != 10 {
|
||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
||||
len(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The set of events returned should be the first 10 events that we
|
||||
// added.
|
||||
if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) {
|
||||
t.Fatalf("wrong response: expected %v, got %v",
|
||||
spew.Sdump(events[:10]),
|
||||
spew.Sdump(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The final offset should be the exact number of events returned.
|
||||
if timeSlice.LastIndexOffset != 10 {
|
||||
t.Fatalf("wrong index offset: expected %v, got %v", 10,
|
||||
timeSlice.LastIndexOffset)
|
||||
}
|
||||
|
||||
// If we use the final offset to query again, then we should get 10
|
||||
// more events, that are the last 10 events we wrote.
|
||||
eventQuery.IndexOffset = 10
|
||||
timeSlice, err = log.Query(eventQuery)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query for events: %v", err)
|
||||
}
|
||||
|
||||
// We should get exactly 10 events back once again.
|
||||
if len(timeSlice.ForwardingEvents) != 10 {
|
||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
||||
len(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The events that we got back should be the last 10 events that we
|
||||
// wrote out.
|
||||
if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) {
|
||||
t.Fatalf("wrong response: expected %v, got %v",
|
||||
spew.Sdump(events[10:]),
|
||||
spew.Sdump(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// Finally, the last index offset should be 20, or the number of
|
||||
// records we've written out.
|
||||
if timeSlice.LastIndexOffset != 20 {
|
||||
t.Fatalf("wrong index offset: expected %v, got %v", 20,
|
||||
timeSlice.LastIndexOffset)
|
||||
}
|
||||
}
|
||||
|
||||
// TestForwardingLogQueryLimit tests that we're able to properly limit the
|
||||
// number of events that are returned as part of a query.
|
||||
func TestForwardingLogQueryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll set up a test database, and use that to instantiate the
|
||||
// forwarding event log that we'll be using for the duration of the
|
||||
// test.
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
log := ForwardingLog{
|
||||
db: db,
|
||||
}
|
||||
|
||||
initialTime := time.Unix(1234, 0)
|
||||
endTime := time.Unix(1234, 0)
|
||||
|
||||
// We'll create 200 random events, which each event being spaced 10
|
||||
// minutes after the prior event.
|
||||
numEvents := 200
|
||||
events := make([]ForwardingEvent, numEvents)
|
||||
for i := 0; i < numEvents; i++ {
|
||||
events[i] = ForwardingEvent{
|
||||
Timestamp: endTime,
|
||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
||||
}
|
||||
|
||||
endTime = endTime.Add(time.Minute * 10)
|
||||
}
|
||||
|
||||
// Now that all of our set of events constructed, we'll add them to the
|
||||
// database in a batch manner.
|
||||
if err := log.AddForwardingEvents(events); err != nil {
|
||||
t.Fatalf("unable to add events: %v", err)
|
||||
}
|
||||
|
||||
// Once the events have been written out, we'll issue a query over the
|
||||
// entire range, but restrict the number of events to the first 100.
|
||||
eventQuery := ForwardingEventQuery{
|
||||
StartTime: initialTime,
|
||||
EndTime: endTime,
|
||||
IndexOffset: 0,
|
||||
NumMaxEvents: 100,
|
||||
}
|
||||
timeSlice, err := log.Query(eventQuery)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query for events: %v", err)
|
||||
}
|
||||
|
||||
// We should get exactly 100 events back.
|
||||
if len(timeSlice.ForwardingEvents) != 100 {
|
||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
||||
len(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The set of events returned should be the first 100 events that we
|
||||
// added.
|
||||
if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) {
|
||||
t.Fatalf("wrong response: expected %v, got %v",
|
||||
spew.Sdump(events[:100]),
|
||||
spew.Sdump(timeSlice.ForwardingEvents))
|
||||
}
|
||||
|
||||
// The final offset should be the exact number of events returned.
|
||||
if timeSlice.LastIndexOffset != 100 {
|
||||
t.Fatalf("wrong index offset: expected %v, got %v", 100,
|
||||
timeSlice.LastIndexOffset)
|
||||
}
|
||||
}
|
|
@ -1,928 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
|
||||
// package has potentially been mangled.
|
||||
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
|
||||
|
||||
// FwdState is an enum used to describe the lifecycle of a FwdPkg.
|
||||
type FwdState byte
|
||||
|
||||
const (
|
||||
// FwdStateLockedIn is the starting state for all forwarding packages.
|
||||
// Packages in this state have not yet committed to the exact set of
|
||||
// Adds to forward to the switch.
|
||||
FwdStateLockedIn FwdState = iota
|
||||
|
||||
// FwdStateProcessed marks the state in which all Adds have been
|
||||
// locally processed and the forwarding decision to the switch has been
|
||||
// persisted.
|
||||
FwdStateProcessed
|
||||
|
||||
// FwdStateCompleted signals that all Adds have been acked, and that all
|
||||
// settles and fails have been delivered to their sources. Packages in
|
||||
// this state can be removed permanently.
|
||||
FwdStateCompleted
|
||||
)
|
||||
|
||||
var (
|
||||
// fwdPackagesKey is the root-level bucket that all forwarding packages
|
||||
// are written. This bucket is further subdivided based on the short
|
||||
// channel ID of each channel.
|
||||
fwdPackagesKey = []byte("fwd-packages")
|
||||
|
||||
// addBucketKey is the bucket to which all Add log updates are written.
|
||||
addBucketKey = []byte("add-updates")
|
||||
|
||||
// failSettleBucketKey is the bucket to which all Settle/Fail log
|
||||
// updates are written.
|
||||
failSettleBucketKey = []byte("fail-settle-updates")
|
||||
|
||||
// fwdFilterKey is a key used to write the set of Adds that passed
|
||||
// validation and are to be forwarded to the switch.
|
||||
// NOTE: The presence of this key within a forwarding package indicates
|
||||
// that the package has reached FwdStateProcessed.
|
||||
fwdFilterKey = []byte("fwd-filter-key")
|
||||
|
||||
// ackFilterKey is a key used to access the PkgFilter indicating which
|
||||
// Adds have received a Settle/Fail. This response may come from a
|
||||
// number of sources, including: exitHop settle/fails, switch failures,
|
||||
// chain arbiter interjections, as well as settle/fails from the
|
||||
// next hop in the route.
|
||||
ackFilterKey = []byte("ack-filter-key")
|
||||
|
||||
// settleFailFilterKey is a key used to access the PkgFilter indicating
|
||||
// which Settles/Fails in have been received and processed by the link
|
||||
// that originally received the Add.
|
||||
settleFailFilterKey = []byte("settle-fail-filter-key")
|
||||
)
|
||||
|
||||
// PkgFilter is used to compactly represent a particular subset of the Adds in a
|
||||
// forwarding package. Each filter is represented as a simple, statically-sized
|
||||
// bitvector, where the elements are intended to be the indices of the Adds as
|
||||
// they are written in the FwdPkg.
|
||||
type PkgFilter struct {
|
||||
count uint16
|
||||
filter []byte
|
||||
}
|
||||
|
||||
// NewPkgFilter initializes an empty PkgFilter supporting `count` elements.
|
||||
func NewPkgFilter(count uint16) *PkgFilter {
|
||||
// We add 7 to ensure that the integer division yields properly rounded
|
||||
// values.
|
||||
filterLen := (count + 7) / 8
|
||||
|
||||
return &PkgFilter{
|
||||
count: count,
|
||||
filter: make([]byte, filterLen),
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of elements represented by this PkgFilter.
|
||||
func (f *PkgFilter) Count() uint16 {
|
||||
return f.count
|
||||
}
|
||||
|
||||
// Set marks the `i`-th element as included by this filter.
|
||||
// NOTE: It is assumed that i is always less than count.
|
||||
func (f *PkgFilter) Set(i uint16) {
|
||||
byt := i / 8
|
||||
bit := i % 8
|
||||
|
||||
// Set the i-th bit in the filter.
|
||||
// TODO(conner): ignore if > count to prevent panic?
|
||||
f.filter[byt] |= byte(1 << (7 - bit))
|
||||
}
|
||||
|
||||
// Contains queries the filter for membership of index `i`.
|
||||
// NOTE: It is assumed that i is always less than count.
|
||||
func (f *PkgFilter) Contains(i uint16) bool {
|
||||
byt := i / 8
|
||||
bit := i % 8
|
||||
|
||||
// Read the i-th bit in the filter.
|
||||
// TODO(conner): ignore if > count to prevent panic?
|
||||
return f.filter[byt]&(1<<(7-bit)) != 0
|
||||
}
|
||||
|
||||
// Equal checks two PkgFilters for equality.
|
||||
func (f *PkgFilter) Equal(f2 *PkgFilter) bool {
|
||||
if f == f2 {
|
||||
return true
|
||||
}
|
||||
if f.count != f2.count {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(f.filter, f2.filter)
|
||||
}
|
||||
|
||||
// IsFull returns true if every element in the filter has been Set, and false
|
||||
// otherwise.
|
||||
func (f *PkgFilter) IsFull() bool {
|
||||
// Batch validate bytes that are fully used.
|
||||
for i := uint16(0); i < f.count/8; i++ {
|
||||
if f.filter[i] != 0xFF {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// If the count is not a multiple of 8, check that the filter contains
|
||||
// all remaining bits.
|
||||
rem := f.count % 8
|
||||
for idx := f.count - rem; idx < f.count; idx++ {
|
||||
if !f.Contains(idx) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Size returns number of bytes produced when the PkgFilter is serialized.
|
||||
func (f *PkgFilter) Size() uint16 {
|
||||
// 2 bytes for uint16 `count`, then round up number of bytes required to
|
||||
// represent `count` bits.
|
||||
return 2 + (f.count+7)/8
|
||||
}
|
||||
|
||||
// Encode writes the filter to the provided io.Writer.
|
||||
func (f *PkgFilter) Encode(w io.Writer) error {
|
||||
if err := binary.Write(w, binary.BigEndian, f.count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := w.Write(f.filter)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode reads the filter from the provided io.Reader.
|
||||
func (f *PkgFilter) Decode(r io.Reader) error {
|
||||
if err := binary.Read(r, binary.BigEndian, &f.count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.filter = make([]byte, f.Size()-2)
|
||||
_, err := io.ReadFull(r, f.filter)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// FwdPkg records all adds, settles, and fails that were locked in as a result
|
||||
// of the remote peer sending us a revocation. Each package is identified by
|
||||
// the short chanid and remote commitment height corresponding to the revocation
|
||||
// that locked in the HTLCs. For everything except a locally initiated payment,
|
||||
// settles and fails in a forwarding package must have a corresponding Add in
|
||||
// another package, and can be removed individually once the source link has
|
||||
// received the fail/settle.
|
||||
//
|
||||
// Adds cannot be removed, as we need to present the same batch of Adds to
|
||||
// properly handle replay protection. Instead, we use a PkgFilter to mark that
|
||||
// we have finished processing a particular Add. A FwdPkg should only be deleted
|
||||
// after the AckFilter is full and all settles and fails have been persistently
|
||||
// removed.
|
||||
type FwdPkg struct {
|
||||
// Source identifies the channel that wrote this forwarding package.
|
||||
Source lnwire.ShortChannelID
|
||||
|
||||
// Height is the height of the remote commitment chain that locked in
|
||||
// this forwarding package.
|
||||
Height uint64
|
||||
|
||||
// State signals the persistent condition of the package and directs how
|
||||
// to reprocess the package in the event of failures.
|
||||
State FwdState
|
||||
|
||||
// Adds contains all add messages which need to be processed and
|
||||
// forwarded to the switch. Adds does not change over the life of a
|
||||
// forwarding package.
|
||||
Adds []LogUpdate
|
||||
|
||||
// FwdFilter is a filter containing the indices of all Adds that were
|
||||
// forwarded to the switch.
|
||||
FwdFilter *PkgFilter
|
||||
|
||||
// AckFilter is a filter containing the indices of all Adds for which
|
||||
// the source has received a settle or fail and is reflected in the next
|
||||
// commitment txn. A package should not be removed until IsFull()
|
||||
// returns true.
|
||||
AckFilter *PkgFilter
|
||||
|
||||
// SettleFails contains all settle and fail messages that should be
|
||||
// forwarded to the switch.
|
||||
SettleFails []LogUpdate
|
||||
|
||||
// SettleFailFilter is a filter containing the indices of all Settle or
|
||||
// Fails originating in this package that have been received and locked
|
||||
// into the incoming link's commitment state.
|
||||
SettleFailFilter *PkgFilter
|
||||
}
|
||||
|
||||
// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This
|
||||
// should be used to create a package at the time we receive a revocation.
|
||||
func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
|
||||
addUpdates, settleFailUpdates []LogUpdate) *FwdPkg {
|
||||
|
||||
nAddUpdates := uint16(len(addUpdates))
|
||||
nSettleFailUpdates := uint16(len(settleFailUpdates))
|
||||
|
||||
return &FwdPkg{
|
||||
Source: source,
|
||||
Height: height,
|
||||
State: FwdStateLockedIn,
|
||||
Adds: addUpdates,
|
||||
FwdFilter: NewPkgFilter(nAddUpdates),
|
||||
AckFilter: NewPkgFilter(nAddUpdates),
|
||||
SettleFails: settleFailUpdates,
|
||||
SettleFailFilter: NewPkgFilter(nSettleFailUpdates),
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns an unique identifier for this package, used to ensure that sphinx
|
||||
// replay processing of this batch is idempotent.
|
||||
func (f *FwdPkg) ID() []byte {
|
||||
var id = make([]byte, 16)
|
||||
byteOrder.PutUint64(id[:8], f.Source.ToUint64())
|
||||
byteOrder.PutUint64(id[8:], f.Height)
|
||||
return id
|
||||
}
|
||||
|
||||
// String returns a human-readable description of the forwarding package.
|
||||
func (f *FwdPkg) String() string {
|
||||
return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)",
|
||||
f, f.Source, f.Height, len(f.Adds), len(f.SettleFails))
|
||||
}
|
||||
|
||||
// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID
|
||||
// is assumed to be that of the packager.
|
||||
type AddRef struct {
|
||||
// Height is the remote commitment height that locked in the Add.
|
||||
Height uint64
|
||||
|
||||
// Index is the index of the Add within the fwd pkg's Adds.
|
||||
//
|
||||
// NOTE: This index is static over the lifetime of a forwarding package.
|
||||
Index uint16
|
||||
}
|
||||
|
||||
// Encode serializes the AddRef to the given io.Writer.
|
||||
func (a *AddRef) Encode(w io.Writer) error {
|
||||
if err := binary.Write(w, binary.BigEndian, a.Height); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return binary.Write(w, binary.BigEndian, a.Index)
|
||||
}
|
||||
|
||||
// Decode deserializes the AddRef from the given io.Reader.
|
||||
func (a *AddRef) Decode(r io.Reader) error {
|
||||
if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return binary.Read(r, binary.BigEndian, &a.Index)
|
||||
}
|
||||
|
||||
// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A
|
||||
// channel does not remove its own Settle/Fail htlcs, so the source is provided
|
||||
// to locate a db bucket belonging to another channel.
|
||||
type SettleFailRef struct {
|
||||
// Source identifies the outgoing link that locked in the settle or
|
||||
// fail. This is then used by the *incoming* link to find the settle
|
||||
// fail in another link's forwarding packages.
|
||||
Source lnwire.ShortChannelID
|
||||
|
||||
// Height is the remote commitment height that locked in this
|
||||
// Settle/Fail.
|
||||
Height uint64
|
||||
|
||||
// Index is the index of the Add with the fwd pkg's SettleFails.
|
||||
//
|
||||
// NOTE: This index is static over the lifetime of a forwarding package.
|
||||
Index uint16
|
||||
}
|
||||
|
||||
// SettleFailAcker is a generic interface providing the ability to acknowledge
|
||||
// settle/fail HTLCs stored in forwarding packages.
|
||||
type SettleFailAcker interface {
|
||||
// AckSettleFails atomically updates the settle-fail filters in *other*
|
||||
// channels' forwarding packages.
|
||||
AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error
|
||||
}
|
||||
|
||||
// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages
|
||||
// of any active channel.
|
||||
type GlobalFwdPkgReader interface {
|
||||
// LoadChannelFwdPkgs loads all known forwarding packages for the given
|
||||
// channel.
|
||||
LoadChannelFwdPkgs(tx *bbolt.Tx,
|
||||
source lnwire.ShortChannelID) ([]*FwdPkg, error)
|
||||
}
|
||||
|
||||
// FwdOperator defines the interfaces for managing forwarding packages that are
|
||||
// external to a particular channel. This interface is used by the switch to
|
||||
// read forwarding packages from arbitrary channels, and acknowledge settles and
|
||||
// fails for locally-sourced payments.
|
||||
type FwdOperator interface {
|
||||
// GlobalFwdPkgReader provides read access to all known forwarding
|
||||
// packages
|
||||
GlobalFwdPkgReader
|
||||
|
||||
// SettleFailAcker grants the ability to acknowledge settles or fails
|
||||
// residing in arbitrary forwarding packages.
|
||||
SettleFailAcker
|
||||
}
|
||||
|
||||
// SwitchPackager is a concrete implementation of the FwdOperator interface.
|
||||
// A SwitchPackager offers the ability to read any forwarding package, and ack
|
||||
// arbitrary settle and fail HTLCs.
|
||||
type SwitchPackager struct{}
|
||||
|
||||
// NewSwitchPackager instantiates a new SwitchPackager.
|
||||
func NewSwitchPackager() *SwitchPackager {
|
||||
return &SwitchPackager{}
|
||||
}
|
||||
|
||||
// AckSettleFails atomically updates the settle-fail filters in *other*
|
||||
// channels' forwarding packages, to mark that the switch has received a settle
|
||||
// or fail residing in the forwarding package of a link.
|
||||
func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx,
|
||||
settleFailRefs ...SettleFailRef) error {
|
||||
|
||||
return ackSettleFails(tx, settleFailRefs)
|
||||
}
|
||||
|
||||
// LoadChannelFwdPkgs loads all forwarding packages for a particular channel.
|
||||
func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx,
|
||||
source lnwire.ShortChannelID) ([]*FwdPkg, error) {
|
||||
|
||||
return loadChannelFwdPkgs(tx, source)
|
||||
}
|
||||
|
||||
// FwdPackager supports all operations required to modify fwd packages, such as
|
||||
// creation, updates, reading, and removal. The interfaces are broken down in
|
||||
// this way to support future delegation of the subinterfaces.
|
||||
type FwdPackager interface {
|
||||
// AddFwdPkg serializes and writes a FwdPkg for this channel at the
|
||||
// remote commitment height included in the forwarding package.
|
||||
AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error
|
||||
|
||||
// SetFwdFilter looks up the forwarding package at the remote `height`
|
||||
// and sets the `fwdFilter`, marking the Adds for which:
|
||||
// 1) We are not the exit node
|
||||
// 2) Passed all validation
|
||||
// 3) Should be forwarded to the switch immediately after a failure
|
||||
SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error
|
||||
|
||||
// AckAddHtlcs atomically updates the add filters in this channel's
|
||||
// forwarding packages to mark the resolution of an Add that was
|
||||
// received from the remote party.
|
||||
AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error
|
||||
|
||||
// SettleFailAcker allows a link to acknowledge settle/fail HTLCs
|
||||
// belonging to other channels.
|
||||
SettleFailAcker
|
||||
|
||||
// LoadFwdPkgs loads all known forwarding packages owned by this
|
||||
// channel.
|
||||
LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error)
|
||||
|
||||
// RemovePkg deletes a forwarding package owned by this channel at
|
||||
// the provided remote `height`.
|
||||
RemovePkg(tx *bbolt.Tx, height uint64) error
|
||||
}
|
||||
|
||||
// ChannelPackager is used by a channel to manage the lifecycle of its forwarding
|
||||
// packages. The packager is tied to a particular source channel ID, allowing it
|
||||
// to create and edit its own packages. Each packager also has the ability to
|
||||
// remove fail/settle htlcs that correspond to an add contained in one of
|
||||
// source's packages.
|
||||
type ChannelPackager struct {
|
||||
source lnwire.ShortChannelID
|
||||
}
|
||||
|
||||
// NewChannelPackager creates a new packager for a single channel.
|
||||
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
|
||||
return &ChannelPackager{
|
||||
source: source,
|
||||
}
|
||||
}
|
||||
|
||||
// AddFwdPkg writes a newly locked in forwarding package to disk.
|
||||
func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error {
|
||||
fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
source := makeLogKey(fwdPkg.Source.ToUint64())
|
||||
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
heightKey := makeLogKey(fwdPkg.Height)
|
||||
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write ADD updates we received at this commit height.
|
||||
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write SETTLE/FAIL updates we received at this commit height.
|
||||
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range fwdPkg.Adds {
|
||||
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the initialized pkg filter, which will be used to determine
|
||||
// when we can remove this forwarding package from disk.
|
||||
var ackFilterBuf bytes.Buffer
|
||||
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range fwdPkg.SettleFails {
|
||||
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var settleFailFilterBuf bytes.Buffer
|
||||
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
|
||||
}
|
||||
|
||||
// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key.
|
||||
func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error {
|
||||
var b bytes.Buffer
|
||||
if err := htlc.Encode(&b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bkt.Put(uint16Key(idx), b.Bytes())
|
||||
}
|
||||
|
||||
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
|
||||
// processed, and returns their deserialized log updates in a map indexed by the
|
||||
// remote commitment height at which the updates were locked in.
|
||||
func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) {
|
||||
return loadChannelFwdPkgs(tx, p.source)
|
||||
}
|
||||
|
||||
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
|
||||
func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) {
|
||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
||||
if fwdPkgBkt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sourceKey := makeLogKey(source.ToUint64())
|
||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
||||
if sourceBkt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var heights []uint64
|
||||
if err := sourceBkt.ForEach(func(k, _ []byte) error {
|
||||
if len(k) != 8 {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
heights = append(heights, byteOrder.Uint64(k))
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load the forwarding package for each retrieved height.
|
||||
fwdPkgs := make([]*FwdPkg, 0, len(heights))
|
||||
for _, height := range heights {
|
||||
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fwdPkgs = append(fwdPkgs, fwdPkg)
|
||||
}
|
||||
|
||||
return fwdPkgs, nil
|
||||
}
|
||||
|
||||
// loadFwPkg reads the packager's fwd pkg at a given height, and determines the
|
||||
// appropriate FwdState.
|
||||
func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID,
|
||||
height uint64) (*FwdPkg, error) {
|
||||
|
||||
sourceKey := makeLogKey(source.ToUint64())
|
||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
||||
if sourceBkt == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
heightKey := makeLogKey(height)
|
||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
||||
if heightBkt == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
// Load ADDs from disk.
|
||||
addBkt := heightBkt.Bucket(addBucketKey)
|
||||
if addBkt == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
adds, err := loadHtlcs(addBkt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load ack filter from disk.
|
||||
ackFilterBytes := heightBkt.Get(ackFilterKey)
|
||||
if ackFilterBytes == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
ackFilterReader := bytes.NewReader(ackFilterBytes)
|
||||
|
||||
ackFilter := &PkgFilter{}
|
||||
if err := ackFilter.Decode(ackFilterReader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load SETTLE/FAILs from disk.
|
||||
failSettleBkt := heightBkt.Bucket(failSettleBucketKey)
|
||||
if failSettleBkt == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
failSettles, err := loadHtlcs(failSettleBkt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load settle fail filter from disk.
|
||||
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
|
||||
if settleFailFilterBytes == nil {
|
||||
return nil, ErrCorruptedFwdPkg
|
||||
}
|
||||
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
|
||||
|
||||
settleFailFilter := &PkgFilter{}
|
||||
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize the fwding package, which always starts in the
|
||||
// FwdStateLockedIn. We can determine what state the package was left in
|
||||
// by examining constraints on the information loaded from disk.
|
||||
fwdPkg := &FwdPkg{
|
||||
Source: source,
|
||||
State: FwdStateLockedIn,
|
||||
Height: height,
|
||||
Adds: adds,
|
||||
AckFilter: ackFilter,
|
||||
SettleFails: failSettles,
|
||||
SettleFailFilter: settleFailFilter,
|
||||
}
|
||||
|
||||
// Check to see if we have written the set exported filter adds to
|
||||
// disk. If we haven't, processing of this package was never started, or
|
||||
// failed during the last attempt.
|
||||
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
|
||||
if fwdFilterBytes == nil {
|
||||
nAdds := uint16(len(adds))
|
||||
fwdPkg.FwdFilter = NewPkgFilter(nAdds)
|
||||
return fwdPkg, nil
|
||||
}
|
||||
|
||||
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
|
||||
fwdPkg.FwdFilter = &PkgFilter{}
|
||||
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Otherwise, a complete round of processing was completed, and we
|
||||
// advance the package to FwdStateProcessed.
|
||||
fwdPkg.State = FwdStateProcessed
|
||||
|
||||
// If every add, settle, and fail has been fully acknowledged, we can
|
||||
// safely set the package's state to FwdStateCompleted, signalling that
|
||||
// it can be garbage collected.
|
||||
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
|
||||
fwdPkg.State = FwdStateCompleted
|
||||
}
|
||||
|
||||
return fwdPkg, nil
|
||||
}
|
||||
|
||||
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
|
||||
// them in order of the indexes they were written under.
|
||||
func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) {
|
||||
var htlcs []LogUpdate
|
||||
if err := bkt.ForEach(func(_, v []byte) error {
|
||||
var htlc LogUpdate
|
||||
if err := htlc.Decode(bytes.NewReader(v)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
htlcs = append(htlcs, htlc)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return htlcs, nil
|
||||
}
|
||||
|
||||
// SetFwdFilter writes the set of indexes corresponding to Adds at the
|
||||
// `height` that are to be forwarded to the switch. Calling this method causes
|
||||
// the forwarding package at `height` to be in FwdStateProcessed. We write this
|
||||
// forwarding decision so that we always arrive at the same behavior for HTLCs
|
||||
// leaving this channel. After a restart, we skip validation of these Adds,
|
||||
// since they are assumed to have already been validated, and make the switch or
|
||||
// outgoing link responsible for handling replays.
|
||||
func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64,
|
||||
fwdFilter *PkgFilter) error {
|
||||
|
||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
||||
if fwdPkgBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
source := makeLogKey(p.source.ToUint64())
|
||||
sourceBkt := fwdPkgBkt.Bucket(source[:])
|
||||
if sourceBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
heightKey := makeLogKey(height)
|
||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
||||
if heightBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
// If the fwd filter has already been written, we return early to avoid
|
||||
// modifying the persistent state.
|
||||
forwardedAddsBytes := heightBkt.Get(fwdFilterKey)
|
||||
if forwardedAddsBytes != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise we serialize and write the provided fwd filter.
|
||||
var b bytes.Buffer
|
||||
if err := fwdFilter.Encode(&b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return heightBkt.Put(fwdFilterKey, b.Bytes())
|
||||
}
|
||||
|
||||
// AckAddHtlcs accepts a list of references to add htlcs, and updates the
|
||||
// AckAddFilter of those forwarding packages to indicate that a settle or fail
|
||||
// has been received in response to the add.
|
||||
func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error {
|
||||
if len(addRefs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
||||
if fwdPkgBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
sourceKey := makeLogKey(p.source.ToUint64())
|
||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
||||
if sourceBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
// Organize the forward references such that we just get a single slice
|
||||
// of indexes for each unique height.
|
||||
heightDiffs := make(map[uint64][]uint16)
|
||||
for _, addRef := range addRefs {
|
||||
heightDiffs[addRef.Height] = append(
|
||||
heightDiffs[addRef.Height],
|
||||
addRef.Index,
|
||||
)
|
||||
}
|
||||
|
||||
// Load each height bucket once and remove all acked htlcs at that
|
||||
// height.
|
||||
for height, indexes := range heightDiffs {
|
||||
err := ackAddHtlcsAtHeight(sourceBkt, height, indexes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package
|
||||
// with a list of indexes, writing the resulting filter back in its place.
|
||||
func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64,
|
||||
indexes []uint16) error {
|
||||
|
||||
heightKey := makeLogKey(height)
|
||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
||||
if heightBkt == nil {
|
||||
// If the height bucket isn't found, this could be because the
|
||||
// forwarding package was already removed. We'll return nil to
|
||||
// signal that the operation is successful, as there is nothing
|
||||
// to ack.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load ack filter from disk.
|
||||
ackFilterBytes := heightBkt.Get(ackFilterKey)
|
||||
if ackFilterBytes == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
ackFilter := &PkgFilter{}
|
||||
ackFilterReader := bytes.NewReader(ackFilterBytes)
|
||||
if err := ackFilter.Decode(ackFilterReader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the ack filter for this height.
|
||||
for _, index := range indexes {
|
||||
ackFilter.Set(index)
|
||||
}
|
||||
|
||||
// Write the resulting filter to disk.
|
||||
var ackFilterBuf bytes.Buffer
|
||||
if err := ackFilter.Encode(&ackFilterBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes())
|
||||
}
|
||||
|
||||
// AckSettleFails persistently acknowledges settles or fails from a remote forwarding
|
||||
// package. This should only be called after the source of the Add has locked in
|
||||
// the settle/fail, or it becomes otherwise safe to forgo retransmitting the
|
||||
// settle/fail after a restart.
|
||||
func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error {
|
||||
return ackSettleFails(tx, settleFailRefs)
|
||||
}
|
||||
|
||||
// ackSettleFails persistently acknowledges a batch of settle fail references.
|
||||
func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error {
|
||||
if len(settleFailRefs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
||||
if fwdPkgBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
// Organize the forward references such that we just get a single slice
|
||||
// of indexes for each unique destination-height pair.
|
||||
destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16)
|
||||
for _, settleFailRef := range settleFailRefs {
|
||||
destHeights, ok := destHeightDiffs[settleFailRef.Source]
|
||||
if !ok {
|
||||
destHeights = make(map[uint64][]uint16)
|
||||
destHeightDiffs[settleFailRef.Source] = destHeights
|
||||
}
|
||||
|
||||
destHeights[settleFailRef.Height] = append(
|
||||
destHeights[settleFailRef.Height],
|
||||
settleFailRef.Index,
|
||||
)
|
||||
}
|
||||
|
||||
// With the references organized by destination and height, we now load
|
||||
// each remote bucket, and update the settle fail filter for any
|
||||
// settle/fail htlcs.
|
||||
for dest, destHeights := range destHeightDiffs {
|
||||
destKey := makeLogKey(dest.ToUint64())
|
||||
destBkt := fwdPkgBkt.Bucket(destKey[:])
|
||||
if destBkt == nil {
|
||||
// If the destination bucket is not found, this is
|
||||
// likely the result of the destination channel being
|
||||
// closed and having it's forwarding packages wiped. We
|
||||
// won't treat this as an error, because the response
|
||||
// will no longer be retransmitted internally.
|
||||
continue
|
||||
}
|
||||
|
||||
for height, indexes := range destHeights {
|
||||
err := ackSettleFailsAtHeight(destBkt, height, indexes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes
|
||||
// at particular a height by updating the settle fail filter.
|
||||
func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64,
|
||||
indexes []uint16) error {
|
||||
|
||||
heightKey := makeLogKey(height)
|
||||
heightBkt := destBkt.Bucket(heightKey[:])
|
||||
if heightBkt == nil {
|
||||
// If the height bucket isn't found, this could be because the
|
||||
// forwarding package was already removed. We'll return nil to
|
||||
// signal that the operation is as there is nothing to ack.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load ack filter from disk.
|
||||
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
|
||||
if settleFailFilterBytes == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
settleFailFilter := &PkgFilter{}
|
||||
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
|
||||
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the ack filter for this height.
|
||||
for _, index := range indexes {
|
||||
settleFailFilter.Set(index)
|
||||
}
|
||||
|
||||
// Write the resulting filter to disk.
|
||||
var settleFailFilterBuf bytes.Buffer
|
||||
if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
|
||||
}
|
||||
|
||||
// RemovePkg deletes the forwarding package at the given height from the
|
||||
// packager's source bucket.
|
||||
func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error {
|
||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
||||
if fwdPkgBkt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sourceBytes := makeLogKey(p.source.ToUint64())
|
||||
sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:])
|
||||
if sourceBkt == nil {
|
||||
return ErrCorruptedFwdPkg
|
||||
}
|
||||
|
||||
heightKey := makeLogKey(height)
|
||||
|
||||
return sourceBkt.DeleteBucket(heightKey[:])
|
||||
}
|
||||
|
||||
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
|
||||
func uint16Key(i uint16) []byte {
|
||||
key := make([]byte, 2)
|
||||
byteOrder.PutUint16(key, i)
|
||||
return key
|
||||
}
|
||||
|
||||
// Compile-time constraint to ensure that ChannelPackager implements the public
|
||||
// FwdPackager interface.
|
||||
var _ FwdPackager = (*ChannelPackager)(nil)
|
||||
|
||||
// Compile-time constraint to ensure that SwitchPackager implements the public
|
||||
// FwdOperator interface.
|
||||
var _ FwdOperator = (*SwitchPackager)(nil)
|
|
@ -1,815 +0,0 @@
|
|||
package migration_01_to_11_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000,
|
||||
// which is greater than the number of HTLCs we permit on a commitment txn.
|
||||
// This should encapsulate every potential filter used in practice.
|
||||
func TestPkgFilterBruteForce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkPkgFilterRange(t, 1000)
|
||||
}
|
||||
|
||||
// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear
|
||||
// insertion of `high` elements. This is primarily to test that IsFull functions
|
||||
// properly for all relevant sizes of `high`.
|
||||
func checkPkgFilterRange(t *testing.T, high int) {
|
||||
for i := uint16(0); i < uint16(high); i++ {
|
||||
f := channeldb.NewPkgFilter(i)
|
||||
|
||||
if f.Count() != i {
|
||||
t.Fatalf("pkg filter count=%d is actually %d",
|
||||
i, f.Count())
|
||||
}
|
||||
checkPkgFilterEncodeDecode(t, i, f)
|
||||
|
||||
for j := uint16(0); j < i; j++ {
|
||||
if f.Contains(j) {
|
||||
t.Fatalf("pkg filter count=%d contains %d "+
|
||||
"before being added", i, j)
|
||||
}
|
||||
|
||||
f.Set(j)
|
||||
checkPkgFilterEncodeDecode(t, i, f)
|
||||
|
||||
if !f.Contains(j) {
|
||||
t.Fatalf("pkg filter count=%d missing %d "+
|
||||
"after being added", i, j)
|
||||
}
|
||||
|
||||
if j < i-1 && f.IsFull() {
|
||||
t.Fatalf("pkg filter count=%d already full", i)
|
||||
}
|
||||
}
|
||||
|
||||
if !f.IsFull() {
|
||||
t.Fatalf("pkg filter count=%d not full", i)
|
||||
}
|
||||
checkPkgFilterEncodeDecode(t, i, f)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPkgFilterRand uses a random permutation to verify the proper behavior of
|
||||
// the pkg filter if the entries are not inserted in-order.
|
||||
func TestPkgFilterRand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkPkgFilterRand(t, 3, 17)
|
||||
}
|
||||
|
||||
// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting
|
||||
// indices and asserting the invariants. The order in which indices are inserted
|
||||
// is parameterized by a base `b` coprime to `p`, and using modular
|
||||
// exponentiation to generate all elements in [1,p).
|
||||
func checkPkgFilterRand(t *testing.T, b, p uint16) {
|
||||
f := channeldb.NewPkgFilter(p)
|
||||
var j = b
|
||||
for i := uint16(1); i < p; i++ {
|
||||
if f.Contains(j) {
|
||||
t.Fatalf("pkg filter contains %d-%d "+
|
||||
"before being added", i, j)
|
||||
}
|
||||
|
||||
f.Set(j)
|
||||
checkPkgFilterEncodeDecode(t, i, f)
|
||||
|
||||
if !f.Contains(j) {
|
||||
t.Fatalf("pkg filter missing %d-%d "+
|
||||
"after being added", i, j)
|
||||
}
|
||||
|
||||
if i < p-1 && f.IsFull() {
|
||||
t.Fatalf("pkg filter %d already full", i)
|
||||
}
|
||||
checkPkgFilterEncodeDecode(t, i, f)
|
||||
|
||||
j = (b * j) % p
|
||||
}
|
||||
|
||||
// Set 0 independently, since it will never be emitted by the generator.
|
||||
f.Set(0)
|
||||
checkPkgFilterEncodeDecode(t, p, f)
|
||||
|
||||
if !f.IsFull() {
|
||||
t.Fatalf("pkg filter count=%d not full", p)
|
||||
}
|
||||
checkPkgFilterEncodeDecode(t, p, f)
|
||||
}
|
||||
|
||||
// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by:
|
||||
// 1) writing it to a buffer
|
||||
// 2) verifying the number of bytes written matches the filter's Size()
|
||||
// 3) reconstructing the filter decoding the bytes
|
||||
// 4) checking that the two filters are the same according to Equal
|
||||
func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) {
|
||||
var b bytes.Buffer
|
||||
if err := f.Encode(&b); err != nil {
|
||||
t.Fatalf("unable to serialize pkg filter: %v", err)
|
||||
}
|
||||
|
||||
// +2 for uint16 length
|
||||
size := uint16(len(b.Bytes()))
|
||||
if size != f.Size() {
|
||||
t.Fatalf("pkg filter count=%d serialized size differs, "+
|
||||
"Size(): %d, len(bytes): %v", i, f.Size(), size)
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(b.Bytes())
|
||||
|
||||
f2 := &channeldb.PkgFilter{}
|
||||
if err := f2.Decode(reader); err != nil {
|
||||
t.Fatalf("unable to deserialize pkg filter: %v", err)
|
||||
}
|
||||
|
||||
if !f.Equal(f2) {
|
||||
t.Fatalf("pkg filter count=%v does is not equal "+
|
||||
"after deserialization, want: %v, got %v",
|
||||
i, f, f2)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{})
|
||||
|
||||
adds = []channeldb.LogUpdate{
|
||||
{
|
||||
LogIndex: 0,
|
||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
||||
ChanID: chanID,
|
||||
ID: 1,
|
||||
Amount: 100,
|
||||
Expiry: 1000,
|
||||
PaymentHash: [32]byte{0},
|
||||
},
|
||||
},
|
||||
{
|
||||
LogIndex: 1,
|
||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
||||
ChanID: chanID,
|
||||
ID: 1,
|
||||
Amount: 101,
|
||||
Expiry: 1001,
|
||||
PaymentHash: [32]byte{1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
settleFails = []channeldb.LogUpdate{
|
||||
{
|
||||
LogIndex: 2,
|
||||
UpdateMsg: &lnwire.UpdateFulfillHTLC{
|
||||
ChanID: chanID,
|
||||
ID: 0,
|
||||
PaymentPreimage: [32]byte{0},
|
||||
},
|
||||
},
|
||||
{
|
||||
LogIndex: 3,
|
||||
UpdateMsg: &lnwire.UpdateFailHTLC{
|
||||
ChanID: chanID,
|
||||
ID: 1,
|
||||
Reason: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a
|
||||
// forwarding package that contains no adds, fails or settles. We expect that
|
||||
// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding
|
||||
// decision via SetFwdFilter.
|
||||
func TestPackagerEmptyFwdPkg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := makeFwdPkgDB(t, "")
|
||||
|
||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
||||
packager := channeldb.NewChannelPackager(shortChanID)
|
||||
|
||||
// To begin, there should be no forwarding packages on disk.
|
||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
|
||||
// Next, create and write a new forwarding package with no htlcs.
|
||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil)
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AddFwdPkg(tx, fwdPkg)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
||||
}
|
||||
|
||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
||||
// has been written, we expect it to be FwdStateLockedIn. With no HTLCs,
|
||||
// the ack filter will have no elements, and should always return true.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Now, write the forwarding decision. In this case, its just an empty
|
||||
// fwd filter.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
||||
}
|
||||
|
||||
// We should still have one package on disk. Since the forwarding
|
||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
||||
// However with no htlcs, it should leap frog to FwdStateCompleted.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Lastly, remove the completed forwarding package from disk.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
||||
}
|
||||
|
||||
// Check that the fwd package was actually removed.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted
|
||||
// as soon as all the adds in the package have been acked using AckAddHtlcs.
|
||||
func TestPackagerOnlyAdds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := makeFwdPkgDB(t, "")
|
||||
|
||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
||||
packager := channeldb.NewChannelPackager(shortChanID)
|
||||
|
||||
// To begin, there should be no forwarding packages on disk.
|
||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
|
||||
// Next, create and write a new forwarding package that only has add
|
||||
// htlcs.
|
||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil)
|
||||
|
||||
nAdds := len(adds)
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AddFwdPkg(tx, fwdPkg)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
||||
}
|
||||
|
||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
||||
// has unacked add HTLCs, so the ack filter should not be full.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
// Now, write the forwarding decision. Since we have not explicitly
|
||||
// added any adds to the fwdfilter, this would indicate that all of the
|
||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
||||
// was failed locally.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
||||
}
|
||||
|
||||
for i := range adds {
|
||||
// We should still have one package on disk. Since the forwarding
|
||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
||||
// However not allf of the HTLCs have been acked, so should not
|
||||
// have advanced further.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
addRef := channeldb.AddRef{
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckAddHtlcs(tx, addRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to ack add htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We should still have one package on disk. Now that all adds have been
|
||||
// acked, the ack filter should return true and the package should be
|
||||
// FwdStateCompleted since there are no other settle/fail packets.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Lastly, remove the completed forwarding package from disk.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
||||
}
|
||||
|
||||
// Check that the fwd package was actually removed.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackagerOnlySettleFails asserts that the fwdpkg remains in
|
||||
// FwdStateProcessed after writing the forwarding decision when there are no
|
||||
// adds in the fwdpkg. We expect this because an empty FwdFilter will always
|
||||
// return true, but we are still waiting for the remaining fails and settles to
|
||||
// be deleted.
|
||||
func TestPackagerOnlySettleFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := makeFwdPkgDB(t, "")
|
||||
|
||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
||||
packager := channeldb.NewChannelPackager(shortChanID)
|
||||
|
||||
// To begin, there should be no forwarding packages on disk.
|
||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
|
||||
// Next, create and write a new forwarding package that only has add
|
||||
// htlcs.
|
||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails)
|
||||
|
||||
nSettleFails := len(settleFails)
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AddFwdPkg(tx, fwdPkg)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
||||
}
|
||||
|
||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
||||
// has unacked add HTLCs, so the ack filter should not be full.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Now, write the forwarding decision. Since we have not explicitly
|
||||
// added any adds to the fwdfilter, this would indicate that all of the
|
||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
||||
// was failed locally.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
||||
}
|
||||
|
||||
for i := range settleFails {
|
||||
// We should still have one package on disk. Since the
|
||||
// forwarding decision has been written, it will minimally be in
|
||||
// FwdStateProcessed. However, not all of the HTLCs have been
|
||||
// acked, so should not have advanced further.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
failSettleRef := channeldb.SettleFailRef{
|
||||
Source: shortChanID,
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckSettleFails(tx, failSettleRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to ack add htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We should still have one package on disk. Now that all settles and
|
||||
// fails have been removed, package should be FwdStateCompleted since
|
||||
// there are no other add packets.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Lastly, remove the completed forwarding package from disk.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
||||
}
|
||||
|
||||
// Check that the fwd package was actually removed.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and
|
||||
// settle/fails, then checks the behavior when the adds are acked before any of
|
||||
// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the
|
||||
// remainder of the fail/settles are being deleted.
|
||||
func TestPackagerAddsThenSettleFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := makeFwdPkgDB(t, "")
|
||||
|
||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
||||
packager := channeldb.NewChannelPackager(shortChanID)
|
||||
|
||||
// To begin, there should be no forwarding packages on disk.
|
||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
|
||||
// Next, create and write a new forwarding package that only has add
|
||||
// htlcs.
|
||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
|
||||
|
||||
nAdds := len(adds)
|
||||
nSettleFails := len(settleFails)
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AddFwdPkg(tx, fwdPkg)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
||||
}
|
||||
|
||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
||||
// has unacked add HTLCs, so the ack filter should not be full.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
// Now, write the forwarding decision. Since we have not explicitly
|
||||
// added any adds to the fwdfilter, this would indicate that all of the
|
||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
||||
// was failed locally.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
||||
}
|
||||
|
||||
for i := range adds {
|
||||
// We should still have one package on disk. Since the forwarding
|
||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
||||
// However not allf of the HTLCs have been acked, so should not
|
||||
// have advanced further.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
addRef := channeldb.AddRef{
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckAddHtlcs(tx, addRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to ack add htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range settleFails {
|
||||
// We should still have one package on disk. Since the
|
||||
// forwarding decision has been written, it will minimally be in
|
||||
// FwdStateProcessed. However not allf of the HTLCs have been
|
||||
// acked, so should not have advanced further.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
failSettleRef := channeldb.SettleFailRef{
|
||||
Source: shortChanID,
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckSettleFails(tx, failSettleRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove settle/fail htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We should still have one package on disk. Now that all settles and
|
||||
// fails have been removed, package should be FwdStateCompleted since
|
||||
// there are no other add packets.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Lastly, remove the completed forwarding package from disk.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
||||
}
|
||||
|
||||
// Check that the fwd package was actually removed.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and
|
||||
// settle/fails, then checks the behavior when the settle/fails are removed
|
||||
// before any of the adds have been acked. This should cause the fwdpkg to
|
||||
// remain in FwdStateProcessed until the final ack is recorded, at which point
|
||||
// it should be promoted directly to FwdStateCompleted.since all adds have been
|
||||
// removed.
|
||||
func TestPackagerSettleFailsThenAdds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := makeFwdPkgDB(t, "")
|
||||
|
||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
||||
packager := channeldb.NewChannelPackager(shortChanID)
|
||||
|
||||
// To begin, there should be no forwarding packages on disk.
|
||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
|
||||
// Next, create and write a new forwarding package that has both add
|
||||
// and settle/fail htlcs.
|
||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
|
||||
|
||||
nAdds := len(adds)
|
||||
nSettleFails := len(settleFails)
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AddFwdPkg(tx, fwdPkg)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
||||
}
|
||||
|
||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
||||
// has unacked add HTLCs, so the ack filter should not be full.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
// Now, write the forwarding decision. Since we have not explicitly
|
||||
// added any adds to the fwdfilter, this would indicate that all of the
|
||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
||||
// was failed locally.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
||||
}
|
||||
|
||||
// Simulate another channel deleting the settle/fails it received from
|
||||
// the original fwd pkg.
|
||||
// TODO(conner): use different packager/s?
|
||||
for i := range settleFails {
|
||||
// We should still have one package on disk. Since the
|
||||
// forwarding decision has been written, it will minimally be in
|
||||
// FwdStateProcessed. However none all of the add HTLCs have
|
||||
// been acked, so should not have advanced further.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
failSettleRef := channeldb.SettleFailRef{
|
||||
Source: shortChanID,
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckSettleFails(tx, failSettleRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove settle/fail htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now simulate this channel receiving a fail/settle for the adds in the
|
||||
// fwdpkg.
|
||||
for i := range adds {
|
||||
// Again, we should still have one package on disk and be in
|
||||
// FwdStateProcessed. This should not change until all of the
|
||||
// add htlcs have been acked.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
||||
|
||||
addRef := channeldb.AddRef{
|
||||
Height: fwdPkg.Height,
|
||||
Index: uint16(i),
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.AckAddHtlcs(tx, addRef)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to ack add htlc: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We should still have one package on disk. Now that all settles and
|
||||
// fails have been removed, package should be FwdStateCompleted since
|
||||
// there are no other add packets.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 1 {
|
||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
||||
}
|
||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
||||
|
||||
// Lastly, remove the completed forwarding package from disk.
|
||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
||||
}
|
||||
|
||||
// Check that the fwd package was actually removed.
|
||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
||||
if len(fwdPkgs) != 0 {
|
||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
||||
}
|
||||
}
|
||||
|
||||
// assertFwdPkgState checks the current state of a fwdpkg meets our
|
||||
// expectations.
|
||||
func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg,
|
||||
state channeldb.FwdState) {
|
||||
_, _, line, _ := runtime.Caller(1)
|
||||
if fwdPkg.State != state {
|
||||
t.Fatalf("line %d: expected fwdpkg in state %v, found %v",
|
||||
line, state, fwdPkg.State)
|
||||
}
|
||||
}
|
||||
|
||||
// assertFwdPkgNumAddsSettleFails checks that the number of adds and
|
||||
// settle/fail log updates are correct.
|
||||
func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg,
|
||||
expectedNumAdds, expectedNumSettleFails int) {
|
||||
_, _, line, _ := runtime.Caller(1)
|
||||
if len(fwdPkg.Adds) != expectedNumAdds {
|
||||
t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d",
|
||||
line, expectedNumAdds, len(fwdPkg.Adds))
|
||||
}
|
||||
|
||||
if len(fwdPkg.SettleFails) != expectedNumSettleFails {
|
||||
t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d",
|
||||
line, expectedNumSettleFails, len(fwdPkg.SettleFails))
|
||||
}
|
||||
}
|
||||
|
||||
// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our
|
||||
// expected full-ness.
|
||||
func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
|
||||
_, _, line, _ := runtime.Caller(1)
|
||||
if fwdPkg.AckFilter.IsFull() != expected {
|
||||
t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+
|
||||
"found %v", line, expected, fwdPkg.AckFilter.IsFull())
|
||||
}
|
||||
}
|
||||
|
||||
// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail
|
||||
// filter matches our expected full-ness.
|
||||
func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
|
||||
_, _, line, _ := runtime.Caller(1)
|
||||
if fwdPkg.SettleFailFilter.IsFull() != expected {
|
||||
t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+
|
||||
"found %v", line, expected, fwdPkg.SettleFailFilter.IsFull())
|
||||
}
|
||||
}
|
||||
|
||||
// loadFwdPkgs is a helper method that reads all forwarding packages for a
|
||||
// particular packager.
|
||||
func loadFwdPkgs(t *testing.T, db *bbolt.DB,
|
||||
packager channeldb.FwdPackager) []*channeldb.FwdPkg {
|
||||
|
||||
var fwdPkgs []*channeldb.FwdPkg
|
||||
if err := db.View(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
fwdPkgs, err = packager.LoadFwdPkgs(tx)
|
||||
return err
|
||||
}); err != nil {
|
||||
t.Fatalf("unable to load fwd pkgs: %v", err)
|
||||
}
|
||||
|
||||
return fwdPkgs
|
||||
}
|
||||
|
||||
// makeFwdPkgDB initializes a test database for forwarding packages. If the
|
||||
// provided path is an empty, it will create a temp dir/file to use.
|
||||
func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB {
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = ioutil.TempDir("", "fwdpkgdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp path: %v", err)
|
||||
}
|
||||
|
||||
path = filepath.Join(path, "fwdpkg.db")
|
||||
}
|
||||
|
||||
db, err := bbolt.Open(path, 0600, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open boltdb: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -1,694 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
|
||||
var pre [32]byte
|
||||
if _, err := rand.Read(pre[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i := &Invoice{
|
||||
// Use single second precision to avoid false positive test
|
||||
// failures due to the monotonic time component.
|
||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
||||
Terms: ContractTerm{
|
||||
PaymentPreimage: pre,
|
||||
Value: value,
|
||||
},
|
||||
Htlcs: map[CircuitKey]*InvoiceHTLC{},
|
||||
Expiry: 4000,
|
||||
}
|
||||
i.Memo = []byte("memo")
|
||||
i.Receipt = []byte("receipt")
|
||||
|
||||
// Create a random byte slice of MaxPaymentRequestSize bytes to be used
|
||||
// as a dummy paymentrequest, and determine if it should be set based
|
||||
// on one of the random bytes.
|
||||
var r [MaxPaymentRequestSize]byte
|
||||
if _, err := rand.Read(r[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r[0]&1 == 0 {
|
||||
i.PaymentRequest = r[:]
|
||||
} else {
|
||||
i.PaymentRequest = []byte("")
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func TestInvoiceWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
|
||||
// Create a fake invoice which we'll use several times in the tests
|
||||
// below.
|
||||
fakeInvoice := &Invoice{
|
||||
// Use single second precision to avoid false positive test
|
||||
// failures due to the monotonic time component.
|
||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
||||
Htlcs: map[CircuitKey]*InvoiceHTLC{},
|
||||
}
|
||||
fakeInvoice.Memo = []byte("memo")
|
||||
fakeInvoice.Receipt = []byte("receipt")
|
||||
fakeInvoice.PaymentRequest = []byte("")
|
||||
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
|
||||
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
|
||||
|
||||
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
// Add the invoice to the database, this should succeed as there aren't
|
||||
// any existing invoices within the database with the same payment
|
||||
// hash.
|
||||
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil {
|
||||
t.Fatalf("unable to find invoice: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to retrieve the invoice which was just added to the
|
||||
// database. It should be found, and the invoice returned should be
|
||||
// identical to the one created above.
|
||||
dbInvoice, err := db.LookupInvoice(paymentHash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to find invoice: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(*fakeInvoice, dbInvoice) {
|
||||
t.Fatalf("invoice fetched from db doesn't match original %v vs %v",
|
||||
spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice))
|
||||
}
|
||||
|
||||
// The add index of the invoice retrieved from the database should now
|
||||
// be fully populated. As this is the first index written to the DB,
|
||||
// the addIndex should be 1.
|
||||
if dbInvoice.AddIndex != 1 {
|
||||
t.Fatalf("wrong add index: expected %v, got %v", 1,
|
||||
dbInvoice.AddIndex)
|
||||
}
|
||||
|
||||
// Settle the invoice, the version retrieved from the database should
|
||||
// now have the settled bit toggle to true and a non-default
|
||||
// SettledDate
|
||||
payAmt := fakeInvoice.Terms.Value * 2
|
||||
_, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to settle invoice: %v", err)
|
||||
}
|
||||
dbInvoice2, err := db.LookupInvoice(paymentHash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch invoice: %v", err)
|
||||
}
|
||||
if dbInvoice2.Terms.State != ContractSettled {
|
||||
t.Fatalf("invoice should now be settled but isn't")
|
||||
}
|
||||
if dbInvoice2.SettleDate.IsZero() {
|
||||
t.Fatalf("invoice should have non-zero SettledDate but isn't")
|
||||
}
|
||||
|
||||
// Our 2x payment should be reflected, and also the settle index of 1
|
||||
// should also have been committed for this index.
|
||||
if dbInvoice2.AmtPaid != payAmt {
|
||||
t.Fatalf("wrong amt paid: expected %v, got %v", payAmt,
|
||||
dbInvoice2.AmtPaid)
|
||||
}
|
||||
if dbInvoice2.SettleIndex != 1 {
|
||||
t.Fatalf("wrong settle index: expected %v, got %v", 1,
|
||||
dbInvoice2.SettleIndex)
|
||||
}
|
||||
|
||||
// Attempt to insert generated above again, this should fail as
|
||||
// duplicates are rejected by the processing logic.
|
||||
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice {
|
||||
t.Fatalf("invoice insertion should fail due to duplication, "+
|
||||
"instead %v", err)
|
||||
}
|
||||
|
||||
// Attempt to look up a non-existent invoice, this should also fail but
|
||||
// with a "not found" error.
|
||||
var fakeHash [32]byte
|
||||
if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound {
|
||||
t.Fatalf("lookup should have failed, instead %v", err)
|
||||
}
|
||||
|
||||
// Add 10 random invoices.
|
||||
const numInvoices = 10
|
||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
||||
invoices := make([]*Invoice, numInvoices+1)
|
||||
invoices[0] = &dbInvoice2
|
||||
for i := 1; i < len(invoices)-1; i++ {
|
||||
invoice, err := randInvoice(amt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create invoice: %v", err)
|
||||
}
|
||||
|
||||
hash := invoice.Terms.PaymentPreimage.Hash()
|
||||
if _, err := db.AddInvoice(invoice, hash); err != nil {
|
||||
t.Fatalf("unable to add invoice %v", err)
|
||||
}
|
||||
|
||||
invoices[i] = invoice
|
||||
}
|
||||
|
||||
// Perform a scan to collect all the active invoices.
|
||||
dbInvoices, err := db.FetchAllInvoices(false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch all invoices: %v", err)
|
||||
}
|
||||
|
||||
// The retrieve list of invoices should be identical as since we're
|
||||
// using big endian, the invoices should be retrieved in ascending
|
||||
// order (and the primary key should be incremented with each
|
||||
// insertion).
|
||||
for i := 0; i < len(invoices)-1; i++ {
|
||||
if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) {
|
||||
t.Fatalf("retrieved invoices don't match %v vs %v",
|
||||
spew.Sdump(invoices[i]),
|
||||
spew.Sdump(dbInvoices[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as
|
||||
// settled invoices are added to the database are properly placed in the add
|
||||
// add or settle index which serves as an event time series.
|
||||
func TestInvoiceAddTimeSeries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
|
||||
// We'll start off by creating 20 random invoices, and inserting them
|
||||
// into the database.
|
||||
const numInvoices = 20
|
||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
||||
invoices := make([]Invoice, numInvoices)
|
||||
for i := 0; i < len(invoices); i++ {
|
||||
invoice, err := randInvoice(amt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create invoice: %v", err)
|
||||
}
|
||||
|
||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
||||
t.Fatalf("unable to add invoice %v", err)
|
||||
}
|
||||
|
||||
invoices[i] = *invoice
|
||||
}
|
||||
|
||||
// With the invoices constructed, we'll now create a series of queries
|
||||
// that we'll use to assert expected return values of
|
||||
// InvoicesAddedSince.
|
||||
addQueries := []struct {
|
||||
sinceAddIndex uint64
|
||||
|
||||
resp []Invoice
|
||||
}{
|
||||
// If we specify a value of zero, we shouldn't get any invoices
|
||||
// back.
|
||||
{
|
||||
sinceAddIndex: 0,
|
||||
},
|
||||
|
||||
// If we specify a value well beyond the number of inserted
|
||||
// invoices, we shouldn't get any invoices back.
|
||||
{
|
||||
sinceAddIndex: 99999999,
|
||||
},
|
||||
|
||||
// Using an index of 1 should result in all values, but the
|
||||
// first one being returned.
|
||||
{
|
||||
sinceAddIndex: 1,
|
||||
resp: invoices[1:],
|
||||
},
|
||||
|
||||
// If we use an index of 10, then we should retrieve the
|
||||
// reaming 10 invoices.
|
||||
{
|
||||
sinceAddIndex: 10,
|
||||
resp: invoices[10:],
|
||||
},
|
||||
}
|
||||
|
||||
for i, query := range addQueries {
|
||||
resp, err := db.InvoicesAddedSince(query.sinceAddIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(query.resp, resp) {
|
||||
t.Fatalf("test #%v: expected %v, got %v", i,
|
||||
spew.Sdump(query.resp), spew.Sdump(resp))
|
||||
}
|
||||
}
|
||||
|
||||
// We'll now only settle the latter half of each of those invoices.
|
||||
for i := 10; i < len(invoices); i++ {
|
||||
invoice := &invoices[i]
|
||||
|
||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
_, err := db.UpdateInvoice(
|
||||
paymentHash, getUpdateInvoice(0),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to settle invoice: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
invoices, err = db.FetchAllInvoices(false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch invoices: %v", err)
|
||||
}
|
||||
|
||||
// We'll slice off the first 10 invoices, as we only settled the last
|
||||
// 10.
|
||||
invoices = invoices[10:]
|
||||
|
||||
// We'll now prepare an additional set of queries to ensure the settle
|
||||
// time series has properly been maintained in the database.
|
||||
settleQueries := []struct {
|
||||
sinceSettleIndex uint64
|
||||
|
||||
resp []Invoice
|
||||
}{
|
||||
// If we specify a value of zero, we shouldn't get any settled
|
||||
// invoices back.
|
||||
{
|
||||
sinceSettleIndex: 0,
|
||||
},
|
||||
|
||||
// If we specify a value well beyond the number of settled
|
||||
// invoices, we shouldn't get any invoices back.
|
||||
{
|
||||
sinceSettleIndex: 99999999,
|
||||
},
|
||||
|
||||
// Using an index of 1 should result in the final 10 invoices
|
||||
// being returned, as we only settled those.
|
||||
{
|
||||
sinceSettleIndex: 1,
|
||||
resp: invoices[1:],
|
||||
},
|
||||
}
|
||||
|
||||
for i, query := range settleQueries {
|
||||
resp, err := db.InvoicesSettledSince(query.sinceSettleIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(query.resp, resp) {
|
||||
t.Fatalf("test #%v: expected %v, got %v", i,
|
||||
spew.Sdump(query.resp), spew.Sdump(resp))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
|
||||
// twice, then the second time we also receive the invoice that we settled as a
|
||||
// return argument.
|
||||
func TestDuplicateSettleInvoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
db.now = func() time.Time { return time.Unix(1, 0) }
|
||||
|
||||
// We'll start out by creating an invoice and writing it to the DB.
|
||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
||||
invoice, err := randInvoice(amt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create invoice: %v", err)
|
||||
}
|
||||
|
||||
payHash := invoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
if _, err := db.AddInvoice(invoice, payHash); err != nil {
|
||||
t.Fatalf("unable to add invoice %v", err)
|
||||
}
|
||||
|
||||
// With the invoice in the DB, we'll now attempt to settle the invoice.
|
||||
dbInvoice, err := db.UpdateInvoice(
|
||||
payHash, getUpdateInvoice(amt),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to settle invoice: %v", err)
|
||||
}
|
||||
|
||||
// We'll update what we expect the settle invoice to be so that our
|
||||
// comparison below has the correct assumption.
|
||||
invoice.SettleIndex = 1
|
||||
invoice.Terms.State = ContractSettled
|
||||
invoice.AmtPaid = amt
|
||||
invoice.SettleDate = dbInvoice.SettleDate
|
||||
invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{
|
||||
{}: {
|
||||
Amt: amt,
|
||||
AcceptTime: time.Unix(1, 0),
|
||||
ResolveTime: time.Unix(1, 0),
|
||||
State: HtlcStateSettled,
|
||||
},
|
||||
}
|
||||
|
||||
// We should get back the exact same invoice that we just inserted.
|
||||
if !reflect.DeepEqual(dbInvoice, invoice) {
|
||||
t.Fatalf("wrong invoice after settle, expected %v got %v",
|
||||
spew.Sdump(invoice), spew.Sdump(dbInvoice))
|
||||
}
|
||||
|
||||
// If we try to settle the invoice again, then we should get the very
|
||||
// same invoice back, but with an error this time.
|
||||
dbInvoice, err = db.UpdateInvoice(
|
||||
payHash, getUpdateInvoice(amt),
|
||||
)
|
||||
if err != ErrInvoiceAlreadySettled {
|
||||
t.Fatalf("expected ErrInvoiceAlreadySettled")
|
||||
}
|
||||
|
||||
if dbInvoice == nil {
|
||||
t.Fatalf("invoice from db is nil after settle!")
|
||||
}
|
||||
|
||||
invoice.SettleDate = dbInvoice.SettleDate
|
||||
if !reflect.DeepEqual(dbInvoice, invoice) {
|
||||
t.Fatalf("wrong invoice after second settle, expected %v got %v",
|
||||
spew.Sdump(invoice), spew.Sdump(dbInvoice))
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryInvoices ensures that we can properly query the invoice database for
|
||||
// invoices using different types of queries.
|
||||
func TestQueryInvoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanUp, err := makeTestDB()
|
||||
defer cleanUp()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test db: %v", err)
|
||||
}
|
||||
|
||||
// To begin the test, we'll add 50 invoices to the database. We'll
|
||||
// assume that the index of the invoice within the database is the same
|
||||
// as the amount of the invoice itself.
|
||||
const numInvoices = 50
|
||||
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
|
||||
invoice, err := randInvoice(i)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create invoice: %v", err)
|
||||
}
|
||||
|
||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
||||
|
||||
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
||||
t.Fatalf("unable to add invoice: %v", err)
|
||||
}
|
||||
|
||||
// We'll only settle half of all invoices created.
|
||||
if i%2 == 0 {
|
||||
_, err := db.UpdateInvoice(
|
||||
paymentHash, getUpdateInvoice(i),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to settle invoice: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We'll then retrieve the set of all invoices and pending invoices.
|
||||
// This will serve useful when comparing the expected responses of the
|
||||
// query with the actual ones.
|
||||
invoices, err := db.FetchAllInvoices(false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to retrieve invoices: %v", err)
|
||||
}
|
||||
pendingInvoices, err := db.FetchAllInvoices(true)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to retrieve pending invoices: %v", err)
|
||||
}
|
||||
|
||||
// The test will consist of several queries along with their respective
|
||||
// expected response. Each query response should match its expected one.
|
||||
testCases := []struct {
|
||||
query InvoiceQuery
|
||||
expected []Invoice
|
||||
}{
|
||||
// Fetch all invoices with a single query.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices,
|
||||
},
|
||||
// Fetch all invoices with a single query, reversed.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
Reversed: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices,
|
||||
},
|
||||
// Fetch the first 25 invoices.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
NumMaxInvoices: numInvoices / 2,
|
||||
},
|
||||
expected: invoices[:numInvoices/2],
|
||||
},
|
||||
// Fetch the first 10 invoices, but this time iterating
|
||||
// backwards.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 11,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices[:10],
|
||||
},
|
||||
// Fetch the last 40 invoices.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 10,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices[10:],
|
||||
},
|
||||
// Fetch all but the first invoice.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 1,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices[1:],
|
||||
},
|
||||
// Fetch one invoice, reversed, with index offset 3. This
|
||||
// should give us the second invoice in the array.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 3,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[1:2],
|
||||
},
|
||||
// Same as above, at index 2.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 2,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[0:1],
|
||||
},
|
||||
// Fetch one invoice, at index 1, reversed. Since invoice#1 is
|
||||
// the very first, there won't be any left in a reverse search,
|
||||
// so we expect no invoices to be returned.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 1,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
// Same as above, but don't restrict the number of invoices to
|
||||
// 1.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 1,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
// Fetch one invoice, reversed, with no offset set. We expect
|
||||
// the last invoice in the response.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[numInvoices-1:],
|
||||
},
|
||||
// Fetch one invoice, reversed, the offset set at numInvoices+1.
|
||||
// We expect this to return the last invoice.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: numInvoices + 1,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[numInvoices-1:],
|
||||
},
|
||||
// Same as above, at offset numInvoices.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: numInvoices,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[numInvoices-2 : numInvoices-1],
|
||||
},
|
||||
// Fetch one invoice, at no offset (same as offset 0). We
|
||||
// expect the first invoice only in the response.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[:1],
|
||||
},
|
||||
// Same as above, at offset 1.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 1,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[1:2],
|
||||
},
|
||||
// Same as above, at offset 2.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 2,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[2:3],
|
||||
},
|
||||
// Same as above, at offset numInvoices-1. Expect the last
|
||||
// invoice to be returned.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: numInvoices - 1,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: invoices[numInvoices-1:],
|
||||
},
|
||||
// Same as above, at offset numInvoices. No invoices should be
|
||||
// returned, as there are no invoices after this offset.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: numInvoices,
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
// Fetch all pending invoices with a single query.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
PendingOnly: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: pendingInvoices,
|
||||
},
|
||||
// Fetch the first 12 pending invoices.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
PendingOnly: true,
|
||||
NumMaxInvoices: numInvoices / 4,
|
||||
},
|
||||
expected: pendingInvoices[:len(pendingInvoices)/2],
|
||||
},
|
||||
// Fetch the first 5 pending invoices, but this time iterating
|
||||
// backwards.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 10,
|
||||
PendingOnly: true,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
// Since we seek to the invoice with index 10 and
|
||||
// iterate backwards, there should only be 5 pending
|
||||
// invoices before it as every other invoice within the
|
||||
// index is settled.
|
||||
expected: pendingInvoices[:5],
|
||||
},
|
||||
// Fetch the last 15 invoices.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: 20,
|
||||
PendingOnly: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
// Since we seek to the invoice with index 20, there are
|
||||
// 30 invoices left. From these 30, only 15 of them are
|
||||
// still pending.
|
||||
expected: pendingInvoices[len(pendingInvoices)-15:],
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
response, err := db.QueryInvoices(testCase.query)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query invoice database: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(response.Invoices, testCase.expected) {
|
||||
t.Fatalf("test #%d: query returned incorrect set of "+
|
||||
"invoices: expcted %v, got %v", i,
|
||||
spew.Sdump(response.Invoices),
|
||||
spew.Sdump(testCase.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUpdateInvoice returns an invoice update callback that, when called,
|
||||
// settles the invoice with the given amount.
|
||||
func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
|
||||
return func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
|
||||
if invoice.Terms.State == ContractSettled {
|
||||
return nil, ErrInvoiceAlreadySettled
|
||||
}
|
||||
|
||||
update := &InvoiceUpdateDesc{
|
||||
Preimage: invoice.Terms.PaymentPreimage,
|
||||
State: ContractSettled,
|
||||
Htlcs: map[CircuitKey]*HtlcAcceptDesc{
|
||||
{}: {
|
||||
Amt: amt,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return update, nil
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ package migration_01_to_11
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
@ -16,9 +15,6 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
// UnknownPreimage is an all-zeroes preimage that indicates that the
|
||||
// preimage for this invoice is not yet known.
|
||||
UnknownPreimage lntypes.Preimage
|
||||
|
||||
// invoiceBucket is the name of the bucket within the database that
|
||||
// stores all data related to invoices no matter their final state.
|
||||
|
@ -26,23 +22,6 @@ var (
|
|||
// which is a monotonically increasing uint32.
|
||||
invoiceBucket = []byte("invoices")
|
||||
|
||||
// paymentHashIndexBucket is the name of the sub-bucket within the
|
||||
// invoiceBucket which indexes all invoices by their payment hash. The
|
||||
// payment hash is the sha256 of the invoice's payment preimage. This
|
||||
// index is used to detect duplicates, and also to provide a fast path
|
||||
// for looking up incoming HTLCs to determine if we're able to settle
|
||||
// them fully.
|
||||
//
|
||||
// maps: payHash => invoiceKey
|
||||
invoiceIndexBucket = []byte("paymenthashes")
|
||||
|
||||
// numInvoicesKey is the name of key which houses the auto-incrementing
|
||||
// invoice ID which is essentially used as a primary key. With each
|
||||
// invoice inserted, the primary key is incremented by one. This key is
|
||||
// stored within the invoiceIndexBucket. Within the invoiceBucket
|
||||
// invoices are uniquely identified by the invoice ID.
|
||||
numInvoicesKey = []byte("nik")
|
||||
|
||||
// addIndexBucket is an index bucket that we'll use to create a
|
||||
// monotonically increasing set of add indexes. Each time we add a new
|
||||
// invoice, this sequence number will be incremented and then populated
|
||||
|
@ -62,21 +41,6 @@ var (
|
|||
//
|
||||
// settleIndexNo => invoiceKey
|
||||
settleIndexBucket = []byte("invoice-settle-index")
|
||||
|
||||
// ErrInvoiceAlreadySettled is returned when the invoice is already
|
||||
// settled.
|
||||
ErrInvoiceAlreadySettled = errors.New("invoice already settled")
|
||||
|
||||
// ErrInvoiceAlreadyCanceled is returned when the invoice is already
|
||||
// canceled.
|
||||
ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled")
|
||||
|
||||
// ErrInvoiceAlreadyAccepted is returned when the invoice is already
|
||||
// accepted.
|
||||
ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted")
|
||||
|
||||
// ErrInvoiceStillOpen is returned when the invoice is still open.
|
||||
ErrInvoiceStillOpen = errors.New("invoice still open")
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -237,18 +201,6 @@ type Invoice struct {
|
|||
// HtlcState defines the states an htlc paying to an invoice can be in.
|
||||
type HtlcState uint8
|
||||
|
||||
const (
|
||||
// HtlcStateAccepted indicates the htlc is locked-in, but not resolved.
|
||||
HtlcStateAccepted HtlcState = iota
|
||||
|
||||
// HtlcStateCanceled indicates the htlc is canceled back to the
|
||||
// sender.
|
||||
HtlcStateCanceled
|
||||
|
||||
// HtlcStateSettled indicates the htlc is settled.
|
||||
HtlcStateSettled
|
||||
)
|
||||
|
||||
// InvoiceHTLC contains details about an htlc paying to this invoice.
|
||||
type InvoiceHTLC struct {
|
||||
// Amt is the amount that is carried by this htlc.
|
||||
|
@ -276,37 +228,6 @@ type InvoiceHTLC struct {
|
|||
State HtlcState
|
||||
}
|
||||
|
||||
// HtlcAcceptDesc describes the details of a newly accepted htlc.
|
||||
type HtlcAcceptDesc struct {
|
||||
// AcceptHeight is the block height at which this htlc was accepted.
|
||||
AcceptHeight int32
|
||||
|
||||
// Amt is the amount that is carried by this htlc.
|
||||
Amt lnwire.MilliSatoshi
|
||||
|
||||
// Expiry is the expiry height of this htlc.
|
||||
Expiry uint32
|
||||
}
|
||||
|
||||
// InvoiceUpdateDesc describes the changes that should be applied to the
|
||||
// invoice.
|
||||
type InvoiceUpdateDesc struct {
|
||||
// State is the new state that this invoice should progress to.
|
||||
State ContractState
|
||||
|
||||
// Htlcs describes the changes that need to be made to the invoice htlcs
|
||||
// in the database. Htlc map entries with their value set should be
|
||||
// added. If the map value is nil, the htlc should be canceled.
|
||||
Htlcs map[CircuitKey]*HtlcAcceptDesc
|
||||
|
||||
// Preimage must be set to the preimage when state is settled.
|
||||
Preimage lntypes.Preimage
|
||||
}
|
||||
|
||||
// InvoiceUpdateCallback is a callback used in the db transaction to update the
|
||||
// invoice.
|
||||
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
|
||||
|
||||
func validateInvoice(i *Invoice) error {
|
||||
if len(i.Memo) > MaxMemoSize {
|
||||
return fmt.Errorf("max length a memo is %v, and invoice "+
|
||||
|
@ -325,186 +246,6 @@ func validateInvoice(i *Invoice) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// AddInvoice inserts the targeted invoice into the database. If the invoice has
|
||||
// *any* payment hashes which already exists within the database, then the
|
||||
// insertion will be aborted and rejected due to the strict policy banning any
|
||||
// duplicate payment hashes. A side effect of this function is that it sets
|
||||
// AddIndex on newInvoice.
|
||||
func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
|
||||
uint64, error) {
|
||||
|
||||
if err := validateInvoice(newInvoice); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var invoiceAddIndex uint64
|
||||
err := d.Update(func(tx *bbolt.Tx) error {
|
||||
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invoiceIndex, err := invoices.CreateBucketIfNotExists(
|
||||
invoiceIndexBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addIndex, err := invoices.CreateBucketIfNotExists(
|
||||
addIndexBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that an invoice an identical payment hash doesn't
|
||||
// already exist within the index.
|
||||
if invoiceIndex.Get(paymentHash[:]) != nil {
|
||||
return ErrDuplicateInvoice
|
||||
}
|
||||
|
||||
// If the current running payment ID counter hasn't yet been
|
||||
// created, then create it now.
|
||||
var invoiceNum uint32
|
||||
invoiceCounter := invoiceIndex.Get(numInvoicesKey)
|
||||
if invoiceCounter == nil {
|
||||
var scratch [4]byte
|
||||
byteOrder.PutUint32(scratch[:], invoiceNum)
|
||||
err := invoiceIndex.Put(numInvoicesKey, scratch[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
invoiceNum = byteOrder.Uint32(invoiceCounter)
|
||||
}
|
||||
|
||||
newIndex, err := putInvoice(
|
||||
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum,
|
||||
paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invoiceAddIndex = newIndex
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return invoiceAddIndex, err
|
||||
}
|
||||
|
||||
// InvoicesAddedSince can be used by callers to seek into the event time series
|
||||
// of all the invoices added in the database. The specified sinceAddIndex
|
||||
// should be the highest add index that the caller knows of. This method will
|
||||
// return all invoices with an add index greater than the specified
|
||||
// sinceAddIndex.
|
||||
//
|
||||
// NOTE: The index starts from 1, as a result. We enforce that specifying a
|
||||
// value below the starting index value is a noop.
|
||||
func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
|
||||
var newInvoices []Invoice
|
||||
|
||||
// If an index of zero was specified, then in order to maintain
|
||||
// backwards compat, we won't send out any new invoices.
|
||||
if sinceAddIndex == 0 {
|
||||
return newInvoices, nil
|
||||
}
|
||||
|
||||
var startIndex [8]byte
|
||||
byteOrder.PutUint64(startIndex[:], sinceAddIndex)
|
||||
|
||||
err := d.DB.View(func(tx *bbolt.Tx) error {
|
||||
invoices := tx.Bucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
addIndex := invoices.Bucket(addIndexBucket)
|
||||
if addIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// We'll now run through each entry in the add index starting
|
||||
// at our starting index. We'll continue until we reach the
|
||||
// very end of the current key space.
|
||||
invoiceCursor := addIndex.Cursor()
|
||||
|
||||
// We'll seek to the starting index, then manually advance the
|
||||
// cursor in order to skip the entry with the since add index.
|
||||
invoiceCursor.Seek(startIndex[:])
|
||||
addSeqNo, invoiceKey := invoiceCursor.Next()
|
||||
|
||||
for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
|
||||
|
||||
// For each key found, we'll look up the actual
|
||||
// invoice, then accumulate it into our return value.
|
||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newInvoices = append(newInvoices, invoice)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
switch {
|
||||
// If no invoices have been created, then we'll return the empty set of
|
||||
// invoices.
|
||||
case err == ErrNoInvoicesCreated:
|
||||
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newInvoices, nil
|
||||
}
|
||||
|
||||
// LookupInvoice attempts to look up an invoice according to its 32 byte
|
||||
// payment hash. If an invoice which can settle the HTLC identified by the
|
||||
// passed payment hash isn't found, then an error is returned. Otherwise, the
|
||||
// full invoice is returned. Before setting the incoming HTLC, the values
|
||||
// SHOULD be checked to ensure the payer meets the agreed upon contractual
|
||||
// terms of the payment.
|
||||
func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
|
||||
var invoice Invoice
|
||||
err := d.View(func(tx *bbolt.Tx) error {
|
||||
invoices := tx.Bucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
|
||||
if invoiceIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// Check the invoice index to see if an invoice paying to this
|
||||
// hash exists within the DB.
|
||||
invoiceNum := invoiceIndex.Get(paymentHash[:])
|
||||
if invoiceNum == nil {
|
||||
return ErrInvoiceNotFound
|
||||
}
|
||||
|
||||
// An invoice matching the payment hash has been found, so
|
||||
// retrieve the record of the invoice itself.
|
||||
i, err := fetchInvoice(invoiceNum, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invoice = i
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
return invoice, nil
|
||||
}
|
||||
|
||||
// FetchAllInvoices returns all invoices currently stored within the database.
|
||||
// If the pendingOnly param is true, then only unsettled invoices will be
|
||||
// returned, skipping all invoices that are fully settled.
|
||||
|
@ -549,343 +290,6 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
|
|||
return invoices, nil
|
||||
}
|
||||
|
||||
// InvoiceQuery represents a query to the invoice database. The query allows a
|
||||
// caller to retrieve all invoices starting from a particular add index and
|
||||
// limit the number of results returned.
|
||||
type InvoiceQuery struct {
|
||||
// IndexOffset is the offset within the add indices to start at. This
|
||||
// can be used to start the response at a particular invoice.
|
||||
IndexOffset uint64
|
||||
|
||||
// NumMaxInvoices is the maximum number of invoices that should be
|
||||
// starting from the add index.
|
||||
NumMaxInvoices uint64
|
||||
|
||||
// PendingOnly, if set, returns unsettled invoices starting from the
|
||||
// add index.
|
||||
PendingOnly bool
|
||||
|
||||
// Reversed, if set, indicates that the invoices returned should start
|
||||
// from the IndexOffset and go backwards.
|
||||
Reversed bool
|
||||
}
|
||||
|
||||
// InvoiceSlice is the response to a invoice query. It includes the original
|
||||
// query, the set of invoices that match the query, and an integer which
|
||||
// represents the offset index of the last item in the set of returned invoices.
|
||||
// This integer allows callers to resume their query using this offset in the
|
||||
// event that the query's response exceeds the maximum number of returnable
|
||||
// invoices.
|
||||
type InvoiceSlice struct {
|
||||
InvoiceQuery
|
||||
|
||||
// Invoices is the set of invoices that matched the query above.
|
||||
Invoices []Invoice
|
||||
|
||||
// FirstIndexOffset is the index of the first element in the set of
|
||||
// returned Invoices above. Callers can use this to resume their query
|
||||
// in the event that the slice has too many events to fit into a single
|
||||
// response.
|
||||
FirstIndexOffset uint64
|
||||
|
||||
// LastIndexOffset is the index of the last element in the set of
|
||||
// returned Invoices above. Callers can use this to resume their query
|
||||
// in the event that the slice has too many events to fit into a single
|
||||
// response.
|
||||
LastIndexOffset uint64
|
||||
}
|
||||
|
||||
// QueryInvoices allows a caller to query the invoice database for invoices
|
||||
// within the specified add index range.
|
||||
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
|
||||
resp := InvoiceSlice{
|
||||
InvoiceQuery: q,
|
||||
}
|
||||
|
||||
err := d.View(func(tx *bbolt.Tx) error {
|
||||
// If the bucket wasn't found, then there aren't any invoices
|
||||
// within the database yet, so we can simply exit.
|
||||
invoices := tx.Bucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
invoiceAddIndex := invoices.Bucket(addIndexBucket)
|
||||
if invoiceAddIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// keyForIndex is a helper closure that retrieves the invoice
|
||||
// key for the given add index of an invoice.
|
||||
keyForIndex := func(c *bbolt.Cursor, index uint64) []byte {
|
||||
var keyIndex [8]byte
|
||||
byteOrder.PutUint64(keyIndex[:], index)
|
||||
_, invoiceKey := c.Seek(keyIndex[:])
|
||||
return invoiceKey
|
||||
}
|
||||
|
||||
// nextKey is a helper closure to determine what the next
|
||||
// invoice key is when iterating over the invoice add index.
|
||||
nextKey := func(c *bbolt.Cursor) ([]byte, []byte) {
|
||||
if q.Reversed {
|
||||
return c.Prev()
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// We'll be using a cursor to seek into the database and return
|
||||
// a slice of invoices. We'll need to determine where to start
|
||||
// our cursor depending on the parameters set within the query.
|
||||
c := invoiceAddIndex.Cursor()
|
||||
invoiceKey := keyForIndex(c, q.IndexOffset+1)
|
||||
|
||||
// If the query is specifying reverse iteration, then we must
|
||||
// handle a few offset cases.
|
||||
if q.Reversed {
|
||||
switch q.IndexOffset {
|
||||
|
||||
// This indicates the default case, where no offset was
|
||||
// specified. In that case we just start from the last
|
||||
// invoice.
|
||||
case 0:
|
||||
_, invoiceKey = c.Last()
|
||||
|
||||
// This indicates the offset being set to the very
|
||||
// first invoice. Since there are no invoices before
|
||||
// this offset, and the direction is reversed, we can
|
||||
// return without adding any invoices to the response.
|
||||
case 1:
|
||||
return nil
|
||||
|
||||
// Otherwise we start iteration at the invoice prior to
|
||||
// the offset.
|
||||
default:
|
||||
invoiceKey = keyForIndex(c, q.IndexOffset-1)
|
||||
}
|
||||
}
|
||||
|
||||
// If we know that a set of invoices exists, then we'll begin
|
||||
// our seek through the bucket in order to satisfy the query.
|
||||
// We'll continue until either we reach the end of the range, or
|
||||
// reach our max number of invoices.
|
||||
for ; invoiceKey != nil; _, invoiceKey = nextKey(c) {
|
||||
// If our current return payload exceeds the max number
|
||||
// of invoices, then we'll exit now.
|
||||
if uint64(len(resp.Invoices)) >= q.NumMaxInvoices {
|
||||
break
|
||||
}
|
||||
|
||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip any settled invoices if the caller is only
|
||||
// interested in unsettled.
|
||||
if q.PendingOnly &&
|
||||
invoice.Terms.State == ContractSettled {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// At this point, we've exhausted the offset, so we'll
|
||||
// begin collecting invoices found within the range.
|
||||
resp.Invoices = append(resp.Invoices, invoice)
|
||||
}
|
||||
|
||||
// If we iterated through the add index in reverse order, then
|
||||
// we'll need to reverse the slice of invoices to return them in
|
||||
// forward order.
|
||||
if q.Reversed {
|
||||
numInvoices := len(resp.Invoices)
|
||||
for i := 0; i < numInvoices/2; i++ {
|
||||
opposite := numInvoices - i - 1
|
||||
resp.Invoices[i], resp.Invoices[opposite] =
|
||||
resp.Invoices[opposite], resp.Invoices[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil && err != ErrNoInvoicesCreated {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Finally, record the indexes of the first and last invoices returned
|
||||
// so that the caller can resume from this point later on.
|
||||
if len(resp.Invoices) > 0 {
|
||||
resp.FirstIndexOffset = resp.Invoices[0].AddIndex
|
||||
resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// UpdateInvoice attempts to update an invoice corresponding to the passed
|
||||
// payment hash. If an invoice matching the passed payment hash doesn't exist
|
||||
// within the database, then the action will fail with a "not found" error.
|
||||
//
|
||||
// The update is performed inside the same database transaction that fetches the
|
||||
// invoice and is therefore atomic. The fields to update are controlled by the
|
||||
// supplied callback.
|
||||
func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
|
||||
callback InvoiceUpdateCallback) (*Invoice, error) {
|
||||
|
||||
var updatedInvoice *Invoice
|
||||
err := d.Update(func(tx *bbolt.Tx) error {
|
||||
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invoiceIndex, err := invoices.CreateBucketIfNotExists(
|
||||
invoiceIndexBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settleIndex, err := invoices.CreateBucketIfNotExists(
|
||||
settleIndexBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check the invoice index to see if an invoice paying to this
|
||||
// hash exists within the DB.
|
||||
invoiceNum := invoiceIndex.Get(paymentHash[:])
|
||||
if invoiceNum == nil {
|
||||
return ErrInvoiceNotFound
|
||||
}
|
||||
|
||||
updatedInvoice, err = d.updateInvoice(
|
||||
paymentHash, invoices, settleIndex, invoiceNum,
|
||||
callback,
|
||||
)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return updatedInvoice, err
|
||||
}
|
||||
|
||||
// InvoicesSettledSince can be used by callers to catch up any settled invoices
|
||||
// they missed within the settled invoice time series. We'll return all known
|
||||
// settled invoice that have a settle index higher than the passed
|
||||
// sinceSettleIndex.
|
||||
//
|
||||
// NOTE: The index starts from 1, as a result. We enforce that specifying a
|
||||
// value below the starting index value is a noop.
|
||||
func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
|
||||
var settledInvoices []Invoice
|
||||
|
||||
// If an index of zero was specified, then in order to maintain
|
||||
// backwards compat, we won't send out any new invoices.
|
||||
if sinceSettleIndex == 0 {
|
||||
return settledInvoices, nil
|
||||
}
|
||||
|
||||
var startIndex [8]byte
|
||||
byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
|
||||
|
||||
err := d.DB.View(func(tx *bbolt.Tx) error {
|
||||
invoices := tx.Bucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
settleIndex := invoices.Bucket(settleIndexBucket)
|
||||
if settleIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// We'll now run through each entry in the add index starting
|
||||
// at our starting index. We'll continue until we reach the
|
||||
// very end of the current key space.
|
||||
invoiceCursor := settleIndex.Cursor()
|
||||
|
||||
// We'll seek to the starting index, then manually advance the
|
||||
// cursor in order to skip the entry with the since add index.
|
||||
invoiceCursor.Seek(startIndex[:])
|
||||
seqNo, invoiceKey := invoiceCursor.Next()
|
||||
|
||||
for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() {
|
||||
|
||||
// For each key found, we'll look up the actual
|
||||
// invoice, then accumulate it into our return value.
|
||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settledInvoices = append(settledInvoices, invoice)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return settledInvoices, nil
|
||||
}
|
||||
|
||||
func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket,
|
||||
i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
|
||||
uint64, error) {
|
||||
|
||||
// Create the invoice key which is just the big-endian representation
|
||||
// of the invoice number.
|
||||
var invoiceKey [4]byte
|
||||
byteOrder.PutUint32(invoiceKey[:], invoiceNum)
|
||||
|
||||
// Increment the num invoice counter index so the next invoice bares
|
||||
// the proper ID.
|
||||
var scratch [4]byte
|
||||
invoiceCounter := invoiceNum + 1
|
||||
byteOrder.PutUint32(scratch[:], invoiceCounter)
|
||||
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Add the payment hash to the invoice index. This will let us quickly
|
||||
// identify if we can settle an incoming payment, and also to possibly
|
||||
// allow a single invoice to have multiple payment installations.
|
||||
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Next, we'll obtain the next add invoice index (sequence
|
||||
// number), so we can properly place this invoice within this
|
||||
// event stream.
|
||||
nextAddSeqNo, err := addIndex.NextSequence()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// With the next sequence obtained, we'll updating the event series in
|
||||
// the add index bucket to map this current add counter to the index of
|
||||
// this new invoice.
|
||||
var seqNoBytes [8]byte
|
||||
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
|
||||
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
i.AddIndex = nextAddSeqNo
|
||||
|
||||
// Finally, serialize the invoice itself to be written to the disk.
|
||||
var buf bytes.Buffer
|
||||
if err := serializeInvoice(&buf, i); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return nextAddSeqNo, nil
|
||||
}
|
||||
|
||||
// serializeInvoice serializes an invoice to a writer.
|
||||
//
|
||||
// Note: this function is in use for a migration. Before making changes that
|
||||
|
@ -1006,17 +410,6 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) {
|
||||
invoiceBytes := invoices.Get(invoiceNum)
|
||||
if invoiceBytes == nil {
|
||||
return Invoice{}, ErrInvoiceNotFound
|
||||
}
|
||||
|
||||
invoiceReader := bytes.NewReader(invoiceBytes)
|
||||
|
||||
return deserializeInvoice(invoiceReader)
|
||||
}
|
||||
|
||||
func deserializeInvoice(r io.Reader) (Invoice, error) {
|
||||
var err error
|
||||
invoice := Invoice{}
|
||||
|
@ -1155,166 +548,3 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
|
|||
|
||||
return htlcs, nil
|
||||
}
|
||||
|
||||
// copySlice allocates a new slice and copies the source into it.
|
||||
func copySlice(src []byte) []byte {
|
||||
dest := make([]byte, len(src))
|
||||
copy(dest, src)
|
||||
return dest
|
||||
}
|
||||
|
||||
// copyInvoice makes a deep copy of the supplied invoice.
|
||||
func copyInvoice(src *Invoice) *Invoice {
|
||||
dest := Invoice{
|
||||
Memo: copySlice(src.Memo),
|
||||
Receipt: copySlice(src.Receipt),
|
||||
PaymentRequest: copySlice(src.PaymentRequest),
|
||||
FinalCltvDelta: src.FinalCltvDelta,
|
||||
CreationDate: src.CreationDate,
|
||||
SettleDate: src.SettleDate,
|
||||
Terms: src.Terms,
|
||||
AddIndex: src.AddIndex,
|
||||
SettleIndex: src.SettleIndex,
|
||||
AmtPaid: src.AmtPaid,
|
||||
Htlcs: make(
|
||||
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
|
||||
),
|
||||
}
|
||||
|
||||
for k, v := range src.Htlcs {
|
||||
dest.Htlcs[k] = v
|
||||
}
|
||||
|
||||
return &dest
|
||||
}
|
||||
|
||||
// updateInvoice fetches the invoice, obtains the update descriptor from the
|
||||
// callback and applies the updates in a single db transaction.
|
||||
func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket,
|
||||
invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) {
|
||||
|
||||
invoice, err := fetchInvoice(invoiceNum, invoices)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
preUpdateState := invoice.Terms.State
|
||||
|
||||
// Create deep copy to prevent any accidental modification in the
|
||||
// callback.
|
||||
copy := copyInvoice(&invoice)
|
||||
|
||||
// Call the callback and obtain the update descriptor.
|
||||
update, err := callback(copy)
|
||||
if err != nil {
|
||||
return &invoice, err
|
||||
}
|
||||
|
||||
// Update invoice state.
|
||||
invoice.Terms.State = update.State
|
||||
|
||||
now := d.now()
|
||||
|
||||
// Update htlc set.
|
||||
for key, htlcUpdate := range update.Htlcs {
|
||||
htlc, ok := invoice.Htlcs[key]
|
||||
|
||||
// No update means the htlc needs to be canceled.
|
||||
if htlcUpdate == nil {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown htlc %v", key)
|
||||
}
|
||||
if htlc.State != HtlcStateAccepted {
|
||||
return nil, fmt.Errorf("can only cancel " +
|
||||
"accepted htlcs")
|
||||
}
|
||||
|
||||
htlc.State = HtlcStateCanceled
|
||||
htlc.ResolveTime = now
|
||||
invoice.AmtPaid -= htlc.Amt
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Add new htlc paying to the invoice.
|
||||
if ok {
|
||||
return nil, fmt.Errorf("htlc %v already exists", key)
|
||||
}
|
||||
htlc = &InvoiceHTLC{
|
||||
Amt: htlcUpdate.Amt,
|
||||
Expiry: htlcUpdate.Expiry,
|
||||
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
|
||||
AcceptTime: now,
|
||||
}
|
||||
if preUpdateState == ContractSettled {
|
||||
htlc.State = HtlcStateSettled
|
||||
htlc.ResolveTime = now
|
||||
} else {
|
||||
htlc.State = HtlcStateAccepted
|
||||
}
|
||||
|
||||
invoice.Htlcs[key] = htlc
|
||||
invoice.AmtPaid += htlc.Amt
|
||||
}
|
||||
|
||||
// If invoice moved to the settled state, update settle index and settle
|
||||
// time.
|
||||
if preUpdateState != invoice.Terms.State &&
|
||||
invoice.Terms.State == ContractSettled {
|
||||
|
||||
if update.Preimage.Hash() != hash {
|
||||
return nil, fmt.Errorf("preimage does not match")
|
||||
}
|
||||
invoice.Terms.PaymentPreimage = update.Preimage
|
||||
|
||||
// Settle all accepted htlcs.
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if htlc.State != HtlcStateAccepted {
|
||||
continue
|
||||
}
|
||||
|
||||
htlc.State = HtlcStateSettled
|
||||
htlc.ResolveTime = now
|
||||
}
|
||||
|
||||
err := setSettleFields(settleIndex, invoiceNum, &invoice, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := serializeInvoice(&buf, &invoice); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &invoice, nil
|
||||
}
|
||||
|
||||
func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte,
|
||||
invoice *Invoice, now time.Time) error {
|
||||
|
||||
// Now that we know the invoice hasn't already been settled, we'll
|
||||
// update the settle index so we can place this settle event in the
|
||||
// proper location within our time series.
|
||||
nextSettleSeqNo, err := settleIndex.NextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var seqNoBytes [8]byte
|
||||
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
|
||||
if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invoice.Terms.State = ContractSettled
|
||||
invoice.SettleDate = now
|
||||
invoice.SettleIndex = nextSettleSeqNo
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,316 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/coreos/bbolt"
|
||||
)
|
||||
|
||||
var (
|
||||
// nodeInfoBucket stores metadata pertaining to nodes that we've had
|
||||
// direct channel-based correspondence with. This bucket allows one to
|
||||
// query for all open channels pertaining to the node by exploring each
|
||||
// node's sub-bucket within the openChanBucket.
|
||||
nodeInfoBucket = []byte("nib")
|
||||
)
|
||||
|
||||
// LinkNode stores metadata related to node's that we have/had a direct
|
||||
// channel open with. Information such as the Bitcoin network the node
|
||||
// advertised, and its identity public key are also stored. Additionally, this
|
||||
// struct and the bucket its stored within have store data similar to that of
|
||||
// Bitcoin's addrmanager. The TCP address information stored within the struct
|
||||
// can be used to establish persistent connections will all channel
|
||||
// counterparties on daemon startup.
|
||||
//
|
||||
// TODO(roasbeef): also add current OnionKey plus rotation schedule?
|
||||
// TODO(roasbeef): add bitfield for supported services
|
||||
// * possibly add a wire.NetAddress type, type
|
||||
type LinkNode struct {
|
||||
// Network indicates the Bitcoin network that the LinkNode advertises
|
||||
// for incoming channel creation.
|
||||
Network wire.BitcoinNet
|
||||
|
||||
// IdentityPub is the node's current identity public key. Any
|
||||
// channel/topology related information received by this node MUST be
|
||||
// signed by this public key.
|
||||
IdentityPub *btcec.PublicKey
|
||||
|
||||
// LastSeen tracks the last time this node was seen within the network.
|
||||
// A node should be marked as seen if the daemon either is able to
|
||||
// establish an outgoing connection to the node or receives a new
|
||||
// incoming connection from the node. This timestamp (stored in unix
|
||||
// epoch) may be used within a heuristic which aims to determine when a
|
||||
// channel should be unilaterally closed due to inactivity.
|
||||
//
|
||||
// TODO(roasbeef): replace with block hash/height?
|
||||
// * possibly add a time-value metric into the heuristic?
|
||||
LastSeen time.Time
|
||||
|
||||
// Addresses is a list of IP address in which either we were able to
|
||||
// reach the node over in the past, OR we received an incoming
|
||||
// authenticated connection for the stored identity public key.
|
||||
Addresses []net.Addr
|
||||
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewLinkNode creates a new LinkNode from the provided parameters, which is
|
||||
// backed by an instance of channeldb.
|
||||
func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
|
||||
addrs ...net.Addr) *LinkNode {
|
||||
|
||||
return &LinkNode{
|
||||
Network: bitNet,
|
||||
IdentityPub: pub,
|
||||
LastSeen: time.Now(),
|
||||
Addresses: addrs,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastSeen updates the last time this node was directly encountered on
|
||||
// the Lightning Network.
|
||||
func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error {
|
||||
l.LastSeen = lastSeen
|
||||
|
||||
return l.Sync()
|
||||
}
|
||||
|
||||
// AddAddress appends the specified TCP address to the list of known addresses
|
||||
// this node is/was known to be reachable at.
|
||||
func (l *LinkNode) AddAddress(addr net.Addr) error {
|
||||
for _, a := range l.Addresses {
|
||||
if a.String() == addr.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
l.Addresses = append(l.Addresses, addr)
|
||||
|
||||
return l.Sync()
|
||||
}
|
||||
|
||||
// Sync performs a full database sync which writes the current up-to-date data
|
||||
// within the struct to the database.
|
||||
func (l *LinkNode) Sync() error {
|
||||
|
||||
// Finally update the database by storing the link node and updating
|
||||
// any relevant indexes.
|
||||
return l.db.Update(func(tx *bbolt.Tx) error {
|
||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
||||
if nodeMetaBucket == nil {
|
||||
return ErrLinkNodesNotFound
|
||||
}
|
||||
|
||||
return putLinkNode(nodeMetaBucket, l)
|
||||
})
|
||||
}
|
||||
|
||||
// putLinkNode serializes then writes the encoded version of the passed link
|
||||
// node into the nodeMetaBucket. This function is provided in order to allow
|
||||
// the ability to re-use a database transaction across many operations.
|
||||
func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error {
|
||||
// First serialize the LinkNode into its raw-bytes encoding.
|
||||
var b bytes.Buffer
|
||||
if err := serializeLinkNode(&b, l); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally insert the link-node into the node metadata bucket keyed
|
||||
// according to the its pubkey serialized in compressed form.
|
||||
nodePub := l.IdentityPub.SerializeCompressed()
|
||||
return nodeMetaBucket.Put(nodePub, b.Bytes())
|
||||
}
|
||||
|
||||
// DeleteLinkNode removes the link node with the given identity from the
|
||||
// database.
|
||||
func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
|
||||
return db.Update(func(tx *bbolt.Tx) error {
|
||||
return db.deleteLinkNode(tx, identity)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error {
|
||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
||||
if nodeMetaBucket == nil {
|
||||
return ErrLinkNodesNotFound
|
||||
}
|
||||
|
||||
pubKey := identity.SerializeCompressed()
|
||||
return nodeMetaBucket.Delete(pubKey)
|
||||
}
|
||||
|
||||
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target
|
||||
// identity public key. If a particular LinkNode for the passed identity public
|
||||
// key cannot be found, then ErrNodeNotFound if returned.
|
||||
func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
|
||||
var linkNode *LinkNode
|
||||
err := db.View(func(tx *bbolt.Tx) error {
|
||||
node, err := fetchLinkNode(tx, identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkNode = node
|
||||
return nil
|
||||
})
|
||||
|
||||
return linkNode, err
|
||||
}
|
||||
|
||||
func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) {
|
||||
// First fetch the bucket for storing node metadata, bailing out early
|
||||
// if it hasn't been created yet.
|
||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
||||
if nodeMetaBucket == nil {
|
||||
return nil, ErrLinkNodesNotFound
|
||||
}
|
||||
|
||||
// If a link node for that particular public key cannot be located,
|
||||
// then exit early with an ErrNodeNotFound.
|
||||
pubKey := targetPub.SerializeCompressed()
|
||||
nodeBytes := nodeMetaBucket.Get(pubKey)
|
||||
if nodeBytes == nil {
|
||||
return nil, ErrNodeNotFound
|
||||
}
|
||||
|
||||
// Finally, decode and allocate a fresh LinkNode object to be returned
|
||||
// to the caller.
|
||||
nodeReader := bytes.NewReader(nodeBytes)
|
||||
return deserializeLinkNode(nodeReader)
|
||||
}
|
||||
|
||||
// TODO(roasbeef): update link node addrs in server upon connection
|
||||
|
||||
// FetchAllLinkNodes starts a new database transaction to fetch all nodes with
|
||||
// whom we have active channels with.
|
||||
func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
|
||||
var linkNodes []*LinkNode
|
||||
err := db.View(func(tx *bbolt.Tx) error {
|
||||
nodes, err := db.fetchAllLinkNodes(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkNodes = nodes
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return linkNodes, nil
|
||||
}
|
||||
|
||||
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
|
||||
// with whom we have active channels with.
|
||||
func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) {
|
||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
||||
if nodeMetaBucket == nil {
|
||||
return nil, ErrLinkNodesNotFound
|
||||
}
|
||||
|
||||
var linkNodes []*LinkNode
|
||||
err := nodeMetaBucket.ForEach(func(k, v []byte) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nodeReader := bytes.NewReader(v)
|
||||
linkNode, err := deserializeLinkNode(nodeReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkNodes = append(linkNodes, linkNode)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return linkNodes, nil
|
||||
}
|
||||
|
||||
func serializeLinkNode(w io.Writer, l *LinkNode) error {
|
||||
var buf [8]byte
|
||||
|
||||
byteOrder.PutUint32(buf[:4], uint32(l.Network))
|
||||
if _, err := w.Write(buf[:4]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serializedID := l.IdentityPub.SerializeCompressed()
|
||||
if _, err := w.Write(serializedID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
seenUnix := uint64(l.LastSeen.Unix())
|
||||
byteOrder.PutUint64(buf[:], seenUnix)
|
||||
if _, err := w.Write(buf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
numAddrs := uint32(len(l.Addresses))
|
||||
byteOrder.PutUint32(buf[:4], numAddrs)
|
||||
if _, err := w.Write(buf[:4]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range l.Addresses {
|
||||
if err := serializeAddr(w, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializeLinkNode(r io.Reader) (*LinkNode, error) {
|
||||
var (
|
||||
err error
|
||||
buf [8]byte
|
||||
)
|
||||
|
||||
node := &LinkNode{}
|
||||
|
||||
if _, err := io.ReadFull(r, buf[:4]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4]))
|
||||
|
||||
var pub [33]byte
|
||||
if _, err := io.ReadFull(r, pub[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, buf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0)
|
||||
|
||||
if _, err := io.ReadFull(r, buf[:4]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numAddrs := byteOrder.Uint32(buf[:4])
|
||||
|
||||
node.Addresses = make([]net.Addr, numAddrs)
|
||||
for i := uint32(0); i < numAddrs; i++ {
|
||||
addr, err := deserializeAddr(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node.Addresses[i] = addr
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
|
@ -1,140 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
func TestLinkNodeEncodeDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First we'll create some initial data to use for populating our test
|
||||
// LinkNode instances.
|
||||
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
||||
_, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:])
|
||||
addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test addr: %v", err)
|
||||
}
|
||||
addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test addr: %v", err)
|
||||
}
|
||||
|
||||
// Create two fresh link node instances with the above dummy data, then
|
||||
// fully sync both instances to disk.
|
||||
node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1)
|
||||
node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2)
|
||||
if err := node1.Sync(); err != nil {
|
||||
t.Fatalf("unable to sync node: %v", err)
|
||||
}
|
||||
if err := node2.Sync(); err != nil {
|
||||
t.Fatalf("unable to sync node: %v", err)
|
||||
}
|
||||
|
||||
// Fetch all current link nodes from the database, they should exactly
|
||||
// match the two created above.
|
||||
originalNodes := []*LinkNode{node2, node1}
|
||||
linkNodes, err := cdb.FetchAllLinkNodes()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch nodes: %v", err)
|
||||
}
|
||||
for i, node := range linkNodes {
|
||||
if originalNodes[i].Network != node.Network {
|
||||
t.Fatalf("node networks don't match: expected %v, got %v",
|
||||
originalNodes[i].Network, node.Network)
|
||||
}
|
||||
|
||||
originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed()
|
||||
dbPubkey := node.IdentityPub.SerializeCompressed()
|
||||
if !bytes.Equal(originalPubkey, dbPubkey) {
|
||||
t.Fatalf("node pubkeys don't match: expected %x, got %x",
|
||||
originalPubkey, dbPubkey)
|
||||
}
|
||||
if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() {
|
||||
t.Fatalf("last seen timestamps don't match: expected %v got %v",
|
||||
originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix())
|
||||
}
|
||||
if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() {
|
||||
t.Fatalf("addresses don't match: expected %v, got %v",
|
||||
originalNodes[i].Addresses, node.Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
// Next, we'll exercise the methods to append additional IP
|
||||
// addresses, and also to update the last seen time.
|
||||
if err := node1.UpdateLastSeen(time.Now()); err != nil {
|
||||
t.Fatalf("unable to update last seen: %v", err)
|
||||
}
|
||||
if err := node1.AddAddress(addr2); err != nil {
|
||||
t.Fatalf("unable to update addr: %v", err)
|
||||
}
|
||||
|
||||
// Fetch the same node from the database according to its public key.
|
||||
node1DB, err := cdb.FetchLinkNode(pub1)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to find node: %v", err)
|
||||
}
|
||||
|
||||
// Both the last seen timestamp and the list of reachable addresses for
|
||||
// the node should be updated.
|
||||
if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() {
|
||||
t.Fatalf("last seen timestamps don't match: expected %v got %v",
|
||||
node1.LastSeen.Unix(), node1DB.LastSeen.Unix())
|
||||
}
|
||||
if len(node1DB.Addresses) != 2 {
|
||||
t.Fatalf("wrong length for node1 addresses: expected %v, got %v",
|
||||
2, len(node1DB.Addresses))
|
||||
}
|
||||
if node1DB.Addresses[0].String() != addr1.String() {
|
||||
t.Fatalf("wrong address for node: expected %v, got %v",
|
||||
addr1.String(), node1DB.Addresses[0].String())
|
||||
}
|
||||
if node1DB.Addresses[1].String() != addr2.String() {
|
||||
t.Fatalf("wrong address for node: expected %v, got %v",
|
||||
addr2.String(), node1DB.Addresses[1].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteLinkNode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 1337,
|
||||
}
|
||||
linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr)
|
||||
if err := linkNode.Sync(); err != nil {
|
||||
t.Fatalf("unable to write link node to db: %v", err)
|
||||
}
|
||||
|
||||
if _, err := cdb.FetchLinkNode(pubKey); err != nil {
|
||||
t.Fatalf("unable to find link node: %v", err)
|
||||
}
|
||||
|
||||
if err := cdb.DeleteLinkNode(pubKey); err != nil {
|
||||
t.Fatalf("unable to delete link node from db: %v", err)
|
||||
}
|
||||
|
||||
if _, err := cdb.FetchLinkNode(pubKey); err == nil {
|
||||
t.Fatal("should not have found link node in db, but did")
|
||||
}
|
||||
}
|
|
@ -39,24 +39,3 @@ func DefaultOptions() Options {
|
|||
|
||||
// OptionModifier is a function signature for modifying the default Options.
|
||||
type OptionModifier func(*Options)
|
||||
|
||||
// OptionSetRejectCacheSize sets the RejectCacheSize to n.
|
||||
func OptionSetRejectCacheSize(n int) OptionModifier {
|
||||
return func(o *Options) {
|
||||
o.RejectCacheSize = n
|
||||
}
|
||||
}
|
||||
|
||||
// OptionSetChannelCacheSize sets the ChannelCacheSize to n.
|
||||
func OptionSetChannelCacheSize(n int) OptionModifier {
|
||||
return func(o *Options) {
|
||||
o.ChannelCacheSize = n
|
||||
}
|
||||
}
|
||||
|
||||
// OptionSetSyncFreelist allows the database to sync its freelist.
|
||||
func OptionSetSyncFreelist(b bool) OptionModifier {
|
||||
return func(o *Options) {
|
||||
o.NoFreelistSync = !b
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,373 +1,9 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrAlreadyPaid signals we have already paid this payment hash.
|
||||
ErrAlreadyPaid = errors.New("invoice is already paid")
|
||||
|
||||
// ErrPaymentInFlight signals that payment for this payment hash is
|
||||
// already "in flight" on the network.
|
||||
ErrPaymentInFlight = errors.New("payment is in transition")
|
||||
|
||||
// ErrPaymentNotInitiated is returned if payment wasn't initiated in
|
||||
// switch.
|
||||
ErrPaymentNotInitiated = errors.New("payment isn't initiated")
|
||||
|
||||
// ErrPaymentAlreadySucceeded is returned in the event we attempt to
|
||||
// change the status of a payment already succeeded.
|
||||
ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded")
|
||||
|
||||
// ErrPaymentAlreadyFailed is returned in the event we attempt to
|
||||
// re-fail a failed payment.
|
||||
ErrPaymentAlreadyFailed = errors.New("payment has already failed")
|
||||
|
||||
// ErrUnknownPaymentStatus is returned when we do not recognize the
|
||||
// existing state of a payment.
|
||||
ErrUnknownPaymentStatus = errors.New("unknown payment status")
|
||||
|
||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
||||
"inflight payment")
|
||||
)
|
||||
|
||||
// PaymentControl implements persistence for payments and payment attempts.
|
||||
type PaymentControl struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewPaymentControl creates a new instance of the PaymentControl.
|
||||
func NewPaymentControl(db *DB) *PaymentControl {
|
||||
return &PaymentControl{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// InitPayment checks or records the given PaymentCreationInfo with the DB,
|
||||
// making sure it does not already exist as an in-flight payment. Then this
|
||||
// method returns successfully, the payment is guranteeed to be in the InFlight
|
||||
// state.
|
||||
func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
|
||||
info *PaymentCreationInfo) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := serializePaymentCreationInfo(&b, info); err != nil {
|
||||
return err
|
||||
}
|
||||
infoBytes := b.Bytes()
|
||||
|
||||
var updateErr error
|
||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
||||
// Reset the update error, to avoid carrying over an error
|
||||
// from a previous execution of the batched db transaction.
|
||||
updateErr = nil
|
||||
|
||||
bucket, err := createPaymentBucket(tx, paymentHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the existing status of this payment, if any.
|
||||
paymentStatus := fetchPaymentStatus(bucket)
|
||||
|
||||
switch paymentStatus {
|
||||
|
||||
// We allow retrying failed payments.
|
||||
case StatusFailed:
|
||||
|
||||
// This is a new payment that is being initialized for the
|
||||
// first time.
|
||||
case StatusUnknown:
|
||||
|
||||
// We already have an InFlight payment on the network. We will
|
||||
// disallow any new payments.
|
||||
case StatusInFlight:
|
||||
updateErr = ErrPaymentInFlight
|
||||
return nil
|
||||
|
||||
// We've already succeeded a payment to this payment hash,
|
||||
// forbid the switch from sending another.
|
||||
case StatusSucceeded:
|
||||
updateErr = ErrAlreadyPaid
|
||||
return nil
|
||||
|
||||
default:
|
||||
updateErr = ErrUnknownPaymentStatus
|
||||
return nil
|
||||
}
|
||||
|
||||
// Obtain a new sequence number for this payment. This is used
|
||||
// to sort the payments in order of creation, and also acts as
|
||||
// a unique identifier for each payment.
|
||||
sequenceNum, err := nextPaymentSequence(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = bucket.Put(paymentSequenceKey, sequenceNum)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the payment info to the bucket, which contains the
|
||||
// static information for this payment
|
||||
err = bucket.Put(paymentCreationInfoKey, infoBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We'll delete any lingering attempt info to start with, in
|
||||
// case we are initializing a payment that was attempted
|
||||
// earlier, but left in a state where we could retry.
|
||||
err = bucket.Delete(paymentAttemptInfoKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Also delete any lingering failure info now that we are
|
||||
// re-attempting.
|
||||
return bucket.Delete(paymentFailInfoKey)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return updateErr
|
||||
}
|
||||
|
||||
// RegisterAttempt atomically records the provided PaymentAttemptInfo to the
|
||||
// DB.
|
||||
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
||||
attempt *PaymentAttemptInfo) error {
|
||||
|
||||
// Serialize the information before opening the db transaction.
|
||||
var a bytes.Buffer
|
||||
if err := serializePaymentAttemptInfo(&a, attempt); err != nil {
|
||||
return err
|
||||
}
|
||||
attemptBytes := a.Bytes()
|
||||
|
||||
var updateErr error
|
||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
||||
// Reset the update error, to avoid carrying over an error
|
||||
// from a previous execution of the batched db transaction.
|
||||
updateErr = nil
|
||||
|
||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
||||
if err == ErrPaymentNotInitiated {
|
||||
updateErr = ErrPaymentNotInitiated
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We can only register attempts for payments that are
|
||||
// in-flight.
|
||||
if err := ensureInFlight(bucket); err != nil {
|
||||
updateErr = err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add the payment attempt to the payments bucket.
|
||||
return bucket.Put(paymentAttemptInfoKey, attemptBytes)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return updateErr
|
||||
}
|
||||
|
||||
// Success transitions a payment into the Succeeded state. After invoking this
|
||||
// method, InitPayment should always return an error to prevent us from making
|
||||
// duplicate payments to the same payment hash. The provided preimage is
|
||||
// atomically saved to the DB for record keeping.
|
||||
func (p *PaymentControl) Success(paymentHash lntypes.Hash,
|
||||
preimage lntypes.Preimage) (*route.Route, error) {
|
||||
|
||||
var (
|
||||
updateErr error
|
||||
route *route.Route
|
||||
)
|
||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
||||
// Reset the update error, to avoid carrying over an error
|
||||
// from a previous execution of the batched db transaction.
|
||||
updateErr = nil
|
||||
|
||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
||||
if err == ErrPaymentNotInitiated {
|
||||
updateErr = ErrPaymentNotInitiated
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We can only mark in-flight payments as succeeded.
|
||||
if err := ensureInFlight(bucket); err != nil {
|
||||
updateErr = err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Record the successful payment info atomically to the
|
||||
// payments record.
|
||||
err = bucket.Put(paymentSettleInfoKey, preimage[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retrieve attempt info for the notification.
|
||||
attempt, err := fetchPaymentAttempt(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route = &attempt.Route
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return route, updateErr
|
||||
}
|
||||
|
||||
// Fail transitions a payment into the Failed state, and records the reason the
|
||||
// payment failed. After invoking this method, InitPayment should return nil on
|
||||
// its next call for this payment hash, allowing the switch to make a
|
||||
// subsequent payment.
|
||||
func (p *PaymentControl) Fail(paymentHash lntypes.Hash,
|
||||
reason FailureReason) (*route.Route, error) {
|
||||
|
||||
var (
|
||||
updateErr error
|
||||
route *route.Route
|
||||
)
|
||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
||||
// Reset the update error, to avoid carrying over an error
|
||||
// from a previous execution of the batched db transaction.
|
||||
updateErr = nil
|
||||
|
||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
||||
if err == ErrPaymentNotInitiated {
|
||||
updateErr = ErrPaymentNotInitiated
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We can only mark in-flight payments as failed.
|
||||
if err := ensureInFlight(bucket); err != nil {
|
||||
updateErr = err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put the failure reason in the bucket for record keeping.
|
||||
v := []byte{byte(reason)}
|
||||
err = bucket.Put(paymentFailInfoKey, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retrieve attempt info for the notification, if available.
|
||||
attempt, err := fetchPaymentAttempt(bucket)
|
||||
if err != nil && err != errNoAttemptInfo {
|
||||
return err
|
||||
}
|
||||
if err != errNoAttemptInfo {
|
||||
route = &attempt.Route
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return route, updateErr
|
||||
}
|
||||
|
||||
// FetchPayment returns information about a payment from the database.
|
||||
func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
|
||||
*Payment, error) {
|
||||
|
||||
var payment *Payment
|
||||
err := p.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payment, err = fetchPayment(bucket)
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payment, nil
|
||||
}
|
||||
|
||||
// createPaymentBucket creates or fetches the sub-bucket assigned to this
|
||||
// payment hash.
|
||||
func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
|
||||
*bbolt.Bucket, error) {
|
||||
|
||||
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payments.CreateBucketIfNotExists(paymentHash[:])
|
||||
}
|
||||
|
||||
// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If
|
||||
// the bucket does not exist, it returns ErrPaymentNotInitiated.
|
||||
func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
|
||||
*bbolt.Bucket, error) {
|
||||
|
||||
payments := tx.Bucket(paymentsRootBucket)
|
||||
if payments == nil {
|
||||
return nil, ErrPaymentNotInitiated
|
||||
}
|
||||
|
||||
bucket := payments.Bucket(paymentHash[:])
|
||||
if bucket == nil {
|
||||
return nil, ErrPaymentNotInitiated
|
||||
}
|
||||
|
||||
return bucket, nil
|
||||
|
||||
}
|
||||
|
||||
// nextPaymentSequence returns the next sequence number to store for a new
|
||||
// payment.
|
||||
func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) {
|
||||
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seq, err := payments.NextSequence()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, seq)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// fetchPaymentStatus fetches the payment status of the payment. If the payment
|
||||
// isn't found, it will default to "StatusUnknown".
|
||||
func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
|
||||
|
@ -385,113 +21,3 @@ func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
|
|||
|
||||
return StatusUnknown
|
||||
}
|
||||
|
||||
// ensureInFlight checks whether the payment found in the given bucket has
|
||||
// status InFlight, and returns an error otherwise. This should be used to
|
||||
// ensure we only mark in-flight payments as succeeded or failed.
|
||||
func ensureInFlight(bucket *bbolt.Bucket) error {
|
||||
paymentStatus := fetchPaymentStatus(bucket)
|
||||
|
||||
switch {
|
||||
|
||||
// The payment was indeed InFlight, return.
|
||||
case paymentStatus == StatusInFlight:
|
||||
return nil
|
||||
|
||||
// Our records show the payment as unknown, meaning it never
|
||||
// should have left the switch.
|
||||
case paymentStatus == StatusUnknown:
|
||||
return ErrPaymentNotInitiated
|
||||
|
||||
// The payment succeeded previously.
|
||||
case paymentStatus == StatusSucceeded:
|
||||
return ErrPaymentAlreadySucceeded
|
||||
|
||||
// The payment was already failed.
|
||||
case paymentStatus == StatusFailed:
|
||||
return ErrPaymentAlreadyFailed
|
||||
|
||||
default:
|
||||
return ErrUnknownPaymentStatus
|
||||
}
|
||||
}
|
||||
|
||||
// fetchPaymentAttempt fetches the payment attempt from the bucket.
|
||||
func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) {
|
||||
attemptData := bucket.Get(paymentAttemptInfoKey)
|
||||
if attemptData == nil {
|
||||
return nil, errNoAttemptInfo
|
||||
}
|
||||
|
||||
r := bytes.NewReader(attemptData)
|
||||
return deserializePaymentAttemptInfo(r)
|
||||
}
|
||||
|
||||
// InFlightPayment is a wrapper around a payment that has status InFlight.
|
||||
type InFlightPayment struct {
|
||||
// Info is the PaymentCreationInfo of the in-flight payment.
|
||||
Info *PaymentCreationInfo
|
||||
|
||||
// Attempt contains information about the last payment attempt that was
|
||||
// made to this payment hash.
|
||||
//
|
||||
// NOTE: Might be nil.
|
||||
Attempt *PaymentAttemptInfo
|
||||
}
|
||||
|
||||
// FetchInFlightPayments returns all payments with status InFlight.
|
||||
func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
|
||||
var inFlights []*InFlightPayment
|
||||
err := p.db.View(func(tx *bbolt.Tx) error {
|
||||
payments := tx.Bucket(paymentsRootBucket)
|
||||
if payments == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return payments.ForEach(func(k, _ []byte) error {
|
||||
bucket := payments.Bucket(k)
|
||||
if bucket == nil {
|
||||
return fmt.Errorf("non bucket element")
|
||||
}
|
||||
|
||||
// If the status is not InFlight, we can return early.
|
||||
paymentStatus := fetchPaymentStatus(bucket)
|
||||
if paymentStatus != StatusInFlight {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
inFlight = &InFlightPayment{}
|
||||
err error
|
||||
)
|
||||
|
||||
// Get the CreationInfo.
|
||||
b := bucket.Get(paymentCreationInfoKey)
|
||||
if b == nil {
|
||||
return fmt.Errorf("unable to find creation " +
|
||||
"info for inflight payment")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(b)
|
||||
inFlight.Info, err = deserializePaymentCreationInfo(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now get the attempt info. It could be that there is
|
||||
// no attempt info yet.
|
||||
inFlight.Attempt, err = fetchPaymentAttempt(bucket)
|
||||
if err != nil && err != errNoAttemptInfo {
|
||||
return err
|
||||
}
|
||||
|
||||
inFlights = append(inFlights, inFlight)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return inFlights, nil
|
||||
}
|
||||
|
|
|
@ -1,550 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
)
|
||||
|
||||
func initDB() (*DB, error) {
|
||||
tempPath, err := ioutil.TempDir("", "switchdb")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := Open(tempPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
||||
func genPreimage() ([32]byte, error) {
|
||||
var preimage [32]byte
|
||||
if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
|
||||
return preimage, err
|
||||
}
|
||||
return preimage, nil
|
||||
}
|
||||
|
||||
func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo,
|
||||
lntypes.Preimage, error) {
|
||||
|
||||
preimage, err := genPreimage()
|
||||
if err != nil {
|
||||
return nil, nil, preimage, fmt.Errorf("unable to "+
|
||||
"generate preimage: %v", err)
|
||||
}
|
||||
|
||||
rhash := fastsha256.Sum256(preimage[:])
|
||||
return &PaymentCreationInfo{
|
||||
PaymentHash: rhash,
|
||||
Value: 1,
|
||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
||||
PaymentRequest: []byte("hola"),
|
||||
},
|
||||
&PaymentAttemptInfo{
|
||||
PaymentID: 1,
|
||||
SessionKey: priv,
|
||||
Route: testRoute,
|
||||
}, preimage, nil
|
||||
}
|
||||
|
||||
// TestPaymentControlSwitchFail checks that payment status returns to Failed
|
||||
// status after failing, and that InitPayment allows another HTLC for the
|
||||
// same payment hash.
|
||||
func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
info, attempt, preimg, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Sends base htlc message which initiate StatusInFlight.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Fail the payment, which should moved it to Failed.
|
||||
failReason := FailureReasonNoRoute
|
||||
_, err = pControl.Fail(info.PaymentHash, failReason)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fail payment hash: %v", err)
|
||||
}
|
||||
|
||||
// Verify the status is indeed Failed.
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
||||
&failReason,
|
||||
)
|
||||
|
||||
// Sends the htlc again, which should succeed since the prior payment
|
||||
// failed.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Record a new attempt.
|
||||
attempt.PaymentID = 2
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Verifies that status was changed to StatusSucceeded.
|
||||
var route *route.Route
|
||||
route, err = pControl.Success(info.PaymentHash, preimg)
|
||||
if err != nil {
|
||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
||||
}
|
||||
|
||||
err = assertRouteEqual(route, &attempt.Route)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected route returned: %v vs %v: %v",
|
||||
spew.Sdump(attempt.Route), spew.Sdump(*route), err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
||||
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
|
||||
|
||||
// Attempt a final payment, which should now fail since the prior
|
||||
// payment succeed.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != ErrAlreadyPaid {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaymentControlSwitchDoubleSend checks the ability of payment control to
|
||||
// prevent double sending of htlc message, when message is in StatusInFlight.
|
||||
func TestPaymentControlSwitchDoubleSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
info, attempt, preimg, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Sends base htlc message which initiate base status and move it to
|
||||
// StatusInFlight and verifies that it was changed.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Try to initiate double sending of htlc message with the same
|
||||
// payment hash, should result in error indicating that payment has
|
||||
// already been sent.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != ErrPaymentInFlight {
|
||||
t.Fatalf("payment control wrong behaviour: " +
|
||||
"double sending must trigger ErrPaymentInFlight error")
|
||||
}
|
||||
|
||||
// Record an attempt.
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Sends base htlc message which initiate StatusInFlight.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != ErrPaymentInFlight {
|
||||
t.Fatalf("payment control wrong behaviour: " +
|
||||
"double sending must trigger ErrPaymentInFlight error")
|
||||
}
|
||||
|
||||
// After settling, the error should be ErrAlreadyPaid.
|
||||
if _, err := pControl.Success(info.PaymentHash, preimg); err != nil {
|
||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
||||
}
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
||||
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
|
||||
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != ErrAlreadyPaid {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaymentControlSuccessesWithoutInFlight checks that the payment
|
||||
// control will disallow calls to Success when no payment is in flight.
|
||||
func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
info, _, preimg, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to complete the payment should fail.
|
||||
_, err = pControl.Success(info.PaymentHash, preimg)
|
||||
if err != ErrPaymentNotInitiated {
|
||||
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// TestPaymentControlFailsWithoutInFlight checks that a strict payment
|
||||
// control will disallow calls to Fail when no payment is in flight.
|
||||
func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
info, _, _, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Calling Fail should return an error.
|
||||
_, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute)
|
||||
if err != ErrPaymentNotInitiated {
|
||||
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil,
|
||||
)
|
||||
}
|
||||
|
||||
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only
|
||||
// deletes payments from the database that are not in-flight.
|
||||
func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
payments := []struct {
|
||||
failed bool
|
||||
success bool
|
||||
}{
|
||||
{
|
||||
failed: true,
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
failed: false,
|
||||
success: true,
|
||||
},
|
||||
{
|
||||
failed: false,
|
||||
success: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range payments {
|
||||
info, attempt, preimg, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Sends base htlc message which initiate StatusInFlight.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
if p.failed {
|
||||
// Fail the payment, which should moved it to Failed.
|
||||
failReason := FailureReasonNoRoute
|
||||
_, err = pControl.Fail(info.PaymentHash, failReason)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fail payment hash: %v", err)
|
||||
}
|
||||
|
||||
// Verify the status is indeed Failed.
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, attempt,
|
||||
lntypes.Preimage{}, &failReason,
|
||||
)
|
||||
} else if p.success {
|
||||
// Verifies that status was changed to StatusSucceeded.
|
||||
_, err := pControl.Success(info.PaymentHash, preimg)
|
||||
if err != nil {
|
||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, attempt, preimg, nil,
|
||||
)
|
||||
} else {
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, db, info.PaymentHash, info, attempt,
|
||||
lntypes.Preimage{}, nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete payments.
|
||||
if err := db.DeletePayments(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// This should leave the in-flight payment.
|
||||
dbPayments, err := db.FetchPayments()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(dbPayments) != 1 {
|
||||
t.Fatalf("expected one payment, got %d", len(dbPayments))
|
||||
}
|
||||
|
||||
status := dbPayments[0].Status
|
||||
if status != StatusInFlight {
|
||||
t.Fatalf("expected in-fligth status, got %v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func assertPaymentStatus(t *testing.T, db *DB,
|
||||
hash [32]byte, expStatus PaymentStatus) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
var paymentStatus = StatusUnknown
|
||||
err := db.View(func(tx *bbolt.Tx) error {
|
||||
payments := tx.Bucket(paymentsRootBucket)
|
||||
if payments == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bucket := payments.Bucket(hash[:])
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the existing status of this payment, if any.
|
||||
paymentStatus = fetchPaymentStatus(bucket)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch payment status: %v", err)
|
||||
}
|
||||
|
||||
if paymentStatus != expStatus {
|
||||
t.Fatalf("payment status mismatch: expected %v, got %v",
|
||||
expStatus, paymentStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error {
|
||||
b := bucket.Get(paymentCreationInfoKey)
|
||||
switch {
|
||||
case b == nil && c == nil:
|
||||
return nil
|
||||
case b == nil:
|
||||
return fmt.Errorf("expected creation info not found")
|
||||
case c == nil:
|
||||
return fmt.Errorf("unexpected creation info found")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(b)
|
||||
c2, err := deserializePaymentCreationInfo(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !reflect.DeepEqual(c, c2) {
|
||||
return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v",
|
||||
spew.Sdump(c), spew.Sdump(c2))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error {
|
||||
b := bucket.Get(paymentAttemptInfoKey)
|
||||
switch {
|
||||
case b == nil && a == nil:
|
||||
return nil
|
||||
case b == nil:
|
||||
return fmt.Errorf("expected attempt info not found")
|
||||
case a == nil:
|
||||
return fmt.Errorf("unexpected attempt info found")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(b)
|
||||
a2, err := deserializePaymentAttemptInfo(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return assertRouteEqual(&a.Route, &a2.Route)
|
||||
}
|
||||
|
||||
func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error {
|
||||
zero := lntypes.Preimage{}
|
||||
b := bucket.Get(paymentSettleInfoKey)
|
||||
switch {
|
||||
case b == nil && preimg == zero:
|
||||
return nil
|
||||
case b == nil:
|
||||
return fmt.Errorf("expected preimage not found")
|
||||
case preimg == zero:
|
||||
return fmt.Errorf("unexpected preimage found")
|
||||
}
|
||||
|
||||
var pre2 lntypes.Preimage
|
||||
copy(pre2[:], b[:])
|
||||
if preimg != pre2 {
|
||||
return fmt.Errorf("Preimages don't match: %x vs %x",
|
||||
preimg, pre2)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error {
|
||||
b := bucket.Get(paymentFailInfoKey)
|
||||
switch {
|
||||
case b == nil && failReason == nil:
|
||||
return nil
|
||||
case b == nil:
|
||||
return fmt.Errorf("expected fail info not found")
|
||||
case failReason == nil:
|
||||
return fmt.Errorf("unexpected fail info found")
|
||||
}
|
||||
|
||||
failReason2 := FailureReason(b[0])
|
||||
if *failReason != failReason2 {
|
||||
return fmt.Errorf("Failure infos don't match: %v vs %v",
|
||||
*failReason, failReason2)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash,
|
||||
c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage,
|
||||
f *FailureReason) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
err := db.View(func(tx *bbolt.Tx) error {
|
||||
payments := tx.Bucket(paymentsRootBucket)
|
||||
if payments == nil && c == nil {
|
||||
return nil
|
||||
}
|
||||
if payments == nil {
|
||||
return fmt.Errorf("sent payments not found")
|
||||
}
|
||||
|
||||
bucket := payments.Bucket(hash[:])
|
||||
if bucket == nil && c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if bucket == nil {
|
||||
return fmt.Errorf("payment not found")
|
||||
}
|
||||
|
||||
if err := checkPaymentCreationInfo(bucket, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkPaymentAttemptInfo(bucket, a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkSettleInfo(bucket, s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkFailInfo(bucket, f); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("assert payment info failed: %v", err)
|
||||
}
|
||||
|
||||
}
|
|
@ -375,48 +375,6 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// DeletePayments deletes all completed and failed payments from the DB.
|
||||
func (db *DB) DeletePayments() error {
|
||||
return db.Update(func(tx *bbolt.Tx) error {
|
||||
payments := tx.Bucket(paymentsRootBucket)
|
||||
if payments == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var deleteBuckets [][]byte
|
||||
err := payments.ForEach(func(k, _ []byte) error {
|
||||
bucket := payments.Bucket(k)
|
||||
if bucket == nil {
|
||||
// We only expect sub-buckets to be found in
|
||||
// this top-level bucket.
|
||||
return fmt.Errorf("non bucket element in " +
|
||||
"payments bucket")
|
||||
}
|
||||
|
||||
// If the status is InFlight, we cannot safely delete
|
||||
// the payment information, so we return early.
|
||||
paymentStatus := fetchPaymentStatus(bucket)
|
||||
if paymentStatus == StatusInFlight {
|
||||
return nil
|
||||
}
|
||||
|
||||
deleteBuckets = append(deleteBuckets, k)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, k := range deleteBuckets {
|
||||
if err := payments.DeleteBucket(k); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
||||
var scratch [8]byte
|
||||
|
||||
|
|
|
@ -2,55 +2,17 @@ package migration_01_to_11
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
priv, _ = btcec.NewPrivateKey(btcec.S256())
|
||||
pub = priv.PubKey()
|
||||
|
||||
tlvBytes = []byte{1, 2, 3}
|
||||
tlvEncoder = tlv.StubEncoder(tlvBytes)
|
||||
testHop1 = &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
AmtToForward: 555,
|
||||
TLVRecords: []tlv.Record{
|
||||
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
|
||||
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
|
||||
},
|
||||
}
|
||||
|
||||
testHop2 = &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
AmtToForward: 555,
|
||||
LegacyPayload: true,
|
||||
}
|
||||
|
||||
testRoute = route.Route{
|
||||
TotalTimeLock: 123,
|
||||
TotalAmount: 1234567,
|
||||
SourcePubKey: route.NewVertex(pub),
|
||||
Hops: []*route.Hop{
|
||||
testHop1,
|
||||
testHop2,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func makeFakePayment() *outgoingPayment {
|
||||
|
@ -81,27 +43,6 @@ func makeFakePayment() *outgoingPayment {
|
|||
return fakePayment
|
||||
}
|
||||
|
||||
func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) {
|
||||
var preimg lntypes.Preimage
|
||||
copy(preimg[:], rev[:])
|
||||
|
||||
c := &PaymentCreationInfo{
|
||||
PaymentHash: preimg.Hash(),
|
||||
Value: 1000,
|
||||
// Use single second precision to avoid false positive test
|
||||
// failures due to the monotonic time component.
|
||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
||||
PaymentRequest: []byte(""),
|
||||
}
|
||||
|
||||
a := &PaymentAttemptInfo{
|
||||
PaymentID: 44,
|
||||
SessionKey: priv,
|
||||
Route: testRoute,
|
||||
}
|
||||
return c, a
|
||||
}
|
||||
|
||||
// randomBytes creates random []byte with length in range [minLen, maxLen)
|
||||
func randomBytes(minLen, maxLen int) ([]byte, error) {
|
||||
randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
|
||||
|
@ -165,160 +106,3 @@ func makeRandomFakePayment() (*outgoingPayment, error) {
|
|||
|
||||
return fakePayment, nil
|
||||
}
|
||||
|
||||
func TestSentPaymentSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, s := makeFakeInfo()
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := serializePaymentCreationInfo(&b, c); err != nil {
|
||||
t.Fatalf("unable to serialize creation info: %v", err)
|
||||
}
|
||||
|
||||
newCreationInfo, err := deserializePaymentCreationInfo(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to deserialize creation info: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(c, newCreationInfo) {
|
||||
t.Fatalf("Payments do not match after "+
|
||||
"serialization/deserialization %v vs %v",
|
||||
spew.Sdump(c), spew.Sdump(newCreationInfo),
|
||||
)
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
if err := serializePaymentAttemptInfo(&b, s); err != nil {
|
||||
t.Fatalf("unable to serialize info: %v", err)
|
||||
}
|
||||
|
||||
newAttemptInfo, err := deserializePaymentAttemptInfo(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to deserialize info: %v", err)
|
||||
}
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
err = assertRouteEqual(&s.Route, &newAttemptInfo.Route)
|
||||
if err != nil {
|
||||
t.Fatalf("Routes do not match after "+
|
||||
"serialization/deserialization: %v", err)
|
||||
}
|
||||
|
||||
// Clear routes to allow DeepEqual to compare the remaining fields.
|
||||
newAttemptInfo.Route = route.Route{}
|
||||
s.Route = route.Route{}
|
||||
|
||||
if !reflect.DeepEqual(s, newAttemptInfo) {
|
||||
s.SessionKey.Curve = nil
|
||||
newAttemptInfo.SessionKey.Curve = nil
|
||||
t.Fatalf("Payments do not match after "+
|
||||
"serialization/deserialization %v vs %v",
|
||||
spew.Sdump(s), spew.Sdump(newAttemptInfo),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// assertRouteEquals compares to routes for equality and returns an error if
|
||||
// they are not equal.
|
||||
func assertRouteEqual(a, b *route.Route) error {
|
||||
err := assertRouteHopRecordsEqual(a, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TLV records have already been compared and need to be cleared to
|
||||
// properly compare the remaining fields using DeepEqual.
|
||||
copyRouteNoHops := func(r *route.Route) *route.Route {
|
||||
copy := *r
|
||||
copy.Hops = make([]*route.Hop, len(r.Hops))
|
||||
for i, hop := range r.Hops {
|
||||
hopCopy := *hop
|
||||
hopCopy.TLVRecords = nil
|
||||
copy.Hops[i] = &hopCopy
|
||||
}
|
||||
return ©
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
|
||||
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
|
||||
spew.Sdump(a), spew.Sdump(b))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
|
||||
if len(r1.Hops) != len(r2.Hops) {
|
||||
return errors.New("route hop count mismatch")
|
||||
}
|
||||
|
||||
for i := 0; i < len(r1.Hops); i++ {
|
||||
records1 := r1.Hops[i].TLVRecords
|
||||
records2 := r2.Hops[i].TLVRecords
|
||||
if len(records1) != len(records2) {
|
||||
return fmt.Errorf("route record count for hop %v "+
|
||||
"mismatch", i)
|
||||
}
|
||||
|
||||
for j := 0; j < len(records1); j++ {
|
||||
expectedRecord := records1[j]
|
||||
newRecord := records2[j]
|
||||
|
||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route record mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
|
||||
if h1.Type() != h2.Type() {
|
||||
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
|
||||
h2.Type())
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := h2.Encode(&b); err != nil {
|
||||
return fmt.Errorf("unable to encode record: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
||||
return fmt.Errorf("wrong raw record: expected %x, got %x",
|
||||
tlvBytes, b.Bytes())
|
||||
}
|
||||
|
||||
if h1.Size() != h2.Size() {
|
||||
return fmt.Errorf("wrong size: expected %v, "+
|
||||
"got %v", h1.Size(), h2.Size())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRouteSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := SerializeRoute(&b, testRoute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
route2, err := DeserializeRoute(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
err = assertRouteEqual(&testRoute, &route2)
|
||||
if err != nil {
|
||||
t.Fatalf("routes not equal: \n%v vs \n%v",
|
||||
spew.Sdump(testRoute), spew.Sdump(route2))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
// rejectFlags is a compact representation of various metadata stored by the
|
||||
// reject cache about a particular channel.
|
||||
type rejectFlags uint8
|
||||
|
||||
const (
|
||||
// rejectFlagExists is a flag indicating whether the channel exists,
|
||||
// i.e. the channel is open and has a recent channel update. If this
|
||||
// flag is not set, the channel is either a zombie or unknown.
|
||||
rejectFlagExists rejectFlags = 1 << iota
|
||||
|
||||
// rejectFlagZombie is a flag indicating whether the channel is a
|
||||
// zombie, i.e. the channel is open but has no recent channel updates.
|
||||
rejectFlagZombie
|
||||
)
|
||||
|
||||
// packRejectFlags computes the rejectFlags corresponding to the passed boolean
|
||||
// values indicating whether the edge exists or is a zombie.
|
||||
func packRejectFlags(exists, isZombie bool) rejectFlags {
|
||||
var flags rejectFlags
|
||||
if exists {
|
||||
flags |= rejectFlagExists
|
||||
}
|
||||
if isZombie {
|
||||
flags |= rejectFlagZombie
|
||||
}
|
||||
|
||||
return flags
|
||||
}
|
||||
|
||||
// unpack returns the booleans packed into the rejectFlags. The first indicates
|
||||
// if the edge exists in our graph, the second indicates if the edge is a
|
||||
// zombie.
|
||||
func (f rejectFlags) unpack() (bool, bool) {
|
||||
return f&rejectFlagExists == rejectFlagExists,
|
||||
f&rejectFlagZombie == rejectFlagZombie
|
||||
}
|
||||
|
||||
// rejectCacheEntry caches frequently accessed information about a channel,
|
||||
// including the timestamps of its latest edge policies and whether or not the
|
||||
// channel exists in the graph.
|
||||
type rejectCacheEntry struct {
|
||||
upd1Time int64
|
||||
upd2Time int64
|
||||
flags rejectFlags
|
||||
}
|
||||
|
||||
// rejectCache is an in-memory cache used to improve the performance of
|
||||
// HasChannelEdge. It caches information about the whether or channel exists, as
|
||||
// well as the most recent timestamps for each policy (if they exists).
|
||||
type rejectCache struct {
|
||||
n int
|
||||
edges map[uint64]rejectCacheEntry
|
||||
}
|
||||
|
||||
// newRejectCache creates a new rejectCache with maximum capacity of n entries.
|
||||
func newRejectCache(n int) *rejectCache {
|
||||
return &rejectCache{
|
||||
n: n,
|
||||
edges: make(map[uint64]rejectCacheEntry, n),
|
||||
}
|
||||
}
|
||||
|
||||
// get returns the entry from the cache for chanid, if it exists.
|
||||
func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) {
|
||||
entry, ok := c.edges[chanid]
|
||||
return entry, ok
|
||||
}
|
||||
|
||||
// insert adds the entry to the reject cache. If an entry for chanid already
|
||||
// exists, it will be replaced with the new entry. If the entry doesn't exists,
|
||||
// it will be inserted to the cache, performing a random eviction if the cache
|
||||
// is at capacity.
|
||||
func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) {
|
||||
// If entry exists, replace it.
|
||||
if _, ok := c.edges[chanid]; ok {
|
||||
c.edges[chanid] = entry
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, evict an entry at random and insert.
|
||||
if len(c.edges) == c.n {
|
||||
for id := range c.edges {
|
||||
delete(c.edges, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
c.edges[chanid] = entry
|
||||
}
|
||||
|
||||
// remove deletes an entry for chanid from the cache, if it exists.
|
||||
func (c *rejectCache) remove(chanid uint64) {
|
||||
delete(c.edges, chanid)
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestRejectCache checks the behavior of the rejectCache with respect to insertion,
|
||||
// eviction, and removal of cache entries.
|
||||
func TestRejectCache(t *testing.T) {
|
||||
const cacheSize = 100
|
||||
|
||||
// Create a new reject cache with the configured max size.
|
||||
c := newRejectCache(cacheSize)
|
||||
|
||||
// As a sanity check, assert that querying the empty cache does not
|
||||
// return an entry.
|
||||
_, ok := c.get(0)
|
||||
if ok {
|
||||
t.Fatalf("reject cache should be empty")
|
||||
}
|
||||
|
||||
// Now, fill up the cache entirely.
|
||||
for i := uint64(0); i < cacheSize; i++ {
|
||||
c.insert(i, entryForInt(i))
|
||||
}
|
||||
|
||||
// Assert that the cache has all of the entries just inserted, since no
|
||||
// eviction should occur until we try to surpass the max size.
|
||||
assertHasEntries(t, c, 0, cacheSize)
|
||||
|
||||
// Now, insert a new element that causes the cache to evict an element.
|
||||
c.insert(cacheSize, entryForInt(cacheSize))
|
||||
|
||||
// Assert that the cache has this last entry, as the cache should evict
|
||||
// some prior element and not the newly inserted one.
|
||||
assertHasEntries(t, c, cacheSize, cacheSize)
|
||||
|
||||
// Iterate over all inserted elements and construct a set of the evicted
|
||||
// elements.
|
||||
evicted := make(map[uint64]struct{})
|
||||
for i := uint64(0); i < cacheSize+1; i++ {
|
||||
_, ok := c.get(i)
|
||||
if !ok {
|
||||
evicted[i] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that exactly one element has been evicted.
|
||||
numEvicted := len(evicted)
|
||||
if numEvicted != 1 {
|
||||
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
|
||||
}
|
||||
|
||||
// Remove the highest item which initially caused the eviction and
|
||||
// reinsert the element that was evicted prior.
|
||||
c.remove(cacheSize)
|
||||
for i := range evicted {
|
||||
c.insert(i, entryForInt(i))
|
||||
}
|
||||
|
||||
// Since the removal created an extra slot, the last insertion should
|
||||
// not have caused an eviction and the entries for all channels in the
|
||||
// original set that filled the cache should be present.
|
||||
assertHasEntries(t, c, 0, cacheSize)
|
||||
|
||||
// Finally, reinsert the existing set back into the cache and test that
|
||||
// the cache still has all the entries. If the randomized eviction were
|
||||
// happening on inserts for existing cache items, we expect this to fail
|
||||
// with high probability.
|
||||
for i := uint64(0); i < cacheSize; i++ {
|
||||
c.insert(i, entryForInt(i))
|
||||
}
|
||||
assertHasEntries(t, c, 0, cacheSize)
|
||||
|
||||
}
|
||||
|
||||
// assertHasEntries queries the reject cache for all channels in the range [start,
|
||||
// end), asserting that they exist and their value matches the entry produced by
|
||||
// entryForInt.
|
||||
func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) {
|
||||
t.Helper()
|
||||
|
||||
for i := start; i < end; i++ {
|
||||
entry, ok := c.get(i)
|
||||
if !ok {
|
||||
t.Fatalf("reject cache should contain chan %d", i)
|
||||
}
|
||||
|
||||
expEntry := entryForInt(i)
|
||||
if !reflect.DeepEqual(entry, expEntry) {
|
||||
t.Fatalf("entry mismatch, want: %v, got: %v",
|
||||
expEntry, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// entryForInt generates a unique rejectCacheEntry given an integer.
|
||||
func entryForInt(i uint64) rejectCacheEntry {
|
||||
exists := i%2 == 0
|
||||
isZombie := i%3 == 0
|
||||
return rejectCacheEntry{
|
||||
upd1Time: int64(2 * i),
|
||||
upd2Time: int64(2*i + 1),
|
||||
flags: packRejectFlags(exists, isZombie),
|
||||
}
|
||||
}
|
|
@ -1,251 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
|
||||
"io"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// waitingProofsBucketKey byte string name of the waiting proofs store.
|
||||
waitingProofsBucketKey = []byte("waitingproofs")
|
||||
|
||||
// ErrWaitingProofNotFound is returned if waiting proofs haven't been
|
||||
// found by db.
|
||||
ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " +
|
||||
"found")
|
||||
|
||||
// ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been
|
||||
// found by db.
|
||||
ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " +
|
||||
"key already exist")
|
||||
)
|
||||
|
||||
// WaitingProofStore is the bold db map-like storage for half announcement
|
||||
// signatures. The one responsibility of this storage is to be able to
|
||||
// retrieve waiting proofs after client restart.
|
||||
type WaitingProofStore struct {
|
||||
// cache is used in order to reduce the number of redundant get
|
||||
// calls, when object isn't stored in it.
|
||||
cache map[WaitingProofKey]struct{}
|
||||
db *DB
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewWaitingProofStore creates new instance of proofs storage.
|
||||
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
|
||||
s := &WaitingProofStore{
|
||||
db: db,
|
||||
cache: make(map[WaitingProofKey]struct{}),
|
||||
}
|
||||
|
||||
if err := s.ForAll(func(proof *WaitingProof) error {
|
||||
s.cache[proof.Key()] = struct{}{}
|
||||
return nil
|
||||
}); err != nil && err != ErrWaitingProofNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add adds new waiting proof in the storage.
|
||||
func (s *WaitingProofStore) Add(proof *WaitingProof) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
||||
var err error
|
||||
var b bytes.Buffer
|
||||
|
||||
// Get or create the bucket.
|
||||
bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode the objects and place it in the bucket.
|
||||
if err := proof.Encode(&b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := proof.Key()
|
||||
|
||||
return bucket.Put(key[:], b.Bytes())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Knowing that the write succeeded, we can now update the in-memory
|
||||
// cache with the proof's key.
|
||||
s.cache[proof.Key()] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove removes the proof from storage by its key.
|
||||
func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.cache[key]; !ok {
|
||||
return ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
||||
// Get or create the top bucket.
|
||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
||||
if bucket == nil {
|
||||
return ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
return bucket.Delete(key[:])
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since the proof was successfully deleted from the store, we can now
|
||||
// remove it from the in-memory cache.
|
||||
delete(s.cache, key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForAll iterates thought all waiting proofs and passing the waiting proof
|
||||
// in the given callback.
|
||||
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
|
||||
return s.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
||||
if bucket == nil {
|
||||
return ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
// Iterate over objects buckets.
|
||||
return bucket.ForEach(func(k, v []byte) error {
|
||||
// Skip buckets fields.
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := bytes.NewReader(v)
|
||||
proof := &WaitingProof{}
|
||||
if err := proof.Decode(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cb(proof)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns the object which corresponds to the given index.
|
||||
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
|
||||
proof := &WaitingProof{}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if _, ok := s.cache[key]; !ok {
|
||||
return nil, ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
||||
if bucket == nil {
|
||||
return ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
// Iterate over objects buckets.
|
||||
v := bucket.Get(key[:])
|
||||
if v == nil {
|
||||
return ErrWaitingProofNotFound
|
||||
}
|
||||
|
||||
r := bytes.NewReader(v)
|
||||
return proof.Decode(r)
|
||||
})
|
||||
|
||||
return proof, err
|
||||
}
|
||||
|
||||
// WaitingProofKey is the proof key which uniquely identifies the waiting
|
||||
// proof object. The goal of this key is distinguish the local and remote
|
||||
// proof for the same channel id.
|
||||
type WaitingProofKey [9]byte
|
||||
|
||||
// WaitingProof is the storable object, which encapsulate the half proof and
|
||||
// the information about from which side this proof came. This structure is
|
||||
// needed to make channel proof exchange persistent, so that after client
|
||||
// restart we may receive remote/local half proof and process it.
|
||||
type WaitingProof struct {
|
||||
*lnwire.AnnounceSignatures
|
||||
isRemote bool
|
||||
}
|
||||
|
||||
// NewWaitingProof constructs a new waiting prof instance.
|
||||
func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof {
|
||||
return &WaitingProof{
|
||||
AnnounceSignatures: proof,
|
||||
isRemote: isRemote,
|
||||
}
|
||||
}
|
||||
|
||||
// OppositeKey returns the key which uniquely identifies opposite waiting proof.
|
||||
func (p *WaitingProof) OppositeKey() WaitingProofKey {
|
||||
var key [9]byte
|
||||
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
|
||||
|
||||
if !p.isRemote {
|
||||
key[8] = 1
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// Key returns the key which uniquely identifies waiting proof.
|
||||
func (p *WaitingProof) Key() WaitingProofKey {
|
||||
var key [9]byte
|
||||
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
|
||||
|
||||
if p.isRemote {
|
||||
key[8] = 1
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// Encode writes the internal representation of waiting proof in byte stream.
|
||||
func (p *WaitingProof) Encode(w io.Writer) error {
|
||||
if err := binary.Write(w, byteOrder, p.isRemote); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.AnnounceSignatures.Encode(w, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode reads the data from the byte stream and initializes the
|
||||
// waiting proof object with it.
|
||||
func (p *WaitingProof) Decode(r io.Reader) error {
|
||||
if err := binary.Read(r, byteOrder, &p.isRemote); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := &lnwire.AnnounceSignatures{}
|
||||
if err := msg.Decode(r, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*p).AnnounceSignatures = msg
|
||||
return nil
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"reflect"
|
||||
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// TestWaitingProofStore tests add/get/remove functions of the waiting proof
|
||||
// storage.
|
||||
func TestWaitingProofStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make test database: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{
|
||||
NodeSignature: wireSig,
|
||||
BitcoinSignature: wireSig,
|
||||
})
|
||||
|
||||
store, err := NewWaitingProofStore(db)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create the waiting proofs storage: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
if err := store.Add(proof1); err != nil {
|
||||
t.Fatalf("unable add proof to storage: %v", err)
|
||||
}
|
||||
|
||||
proof2, err := store.Get(proof1.Key())
|
||||
if err != nil {
|
||||
t.Fatalf("unable retrieve proof from storage: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(proof1, proof2) {
|
||||
t.Fatal("wrong proof retrieved")
|
||||
}
|
||||
|
||||
if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound {
|
||||
t.Fatalf("proof shouldn't be found: %v", err)
|
||||
}
|
||||
|
||||
if err := store.Remove(proof1.Key()); err != nil {
|
||||
t.Fatalf("unable remove proof from storage: %v", err)
|
||||
}
|
||||
|
||||
if err := store.ForAll(func(proof *WaitingProof) error {
|
||||
return errors.New("storage should be empty")
|
||||
}); err != nil && err != ErrWaitingProofNotFound {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,229 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoWitnesses is an error that's returned when no new witnesses have
|
||||
// been added to the WitnessCache.
|
||||
ErrNoWitnesses = fmt.Errorf("no witnesses")
|
||||
|
||||
// ErrUnknownWitnessType is returned if a caller attempts to
|
||||
ErrUnknownWitnessType = fmt.Errorf("unknown witness type")
|
||||
)
|
||||
|
||||
// WitnessType is enum that denotes what "type" of witness is being
|
||||
// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce
|
||||
// any structure on added witnesses, we use this type to partition the
|
||||
// witnesses on disk, and also to know how to map a witness to its look up key.
|
||||
type WitnessType uint8
|
||||
|
||||
var (
|
||||
// Sha256HashWitness is a witness that is simply the pre image to a
|
||||
// hash image. In order to map to its key, we'll use sha256.
|
||||
Sha256HashWitness WitnessType = 1
|
||||
)
|
||||
|
||||
// toDBKey is a helper method that maps a witness type to the key that we'll
|
||||
// use to store it within the database.
|
||||
func (w WitnessType) toDBKey() ([]byte, error) {
|
||||
switch w {
|
||||
|
||||
case Sha256HashWitness:
|
||||
return []byte{byte(w)}, nil
|
||||
|
||||
default:
|
||||
return nil, ErrUnknownWitnessType
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// witnessBucketKey is the name of the bucket that we use to store all
|
||||
// witnesses encountered. Within this bucket, we'll create a sub-bucket for
|
||||
// each witness type.
|
||||
witnessBucketKey = []byte("byte")
|
||||
)
|
||||
|
||||
// WitnessCache is a persistent cache of all witnesses we've encountered on the
|
||||
// network. In the case of multi-hop, multi-step contracts, a cache of all
|
||||
// witnesses can be useful in the case of partial contract resolution. If
|
||||
// negotiations break down, we may be forced to locate the witness for a
|
||||
// portion of the contract on-chain. In this case, we'll then add that witness
|
||||
// to the cache so the incoming contract can fully resolve witness.
|
||||
// Additionally, as one MUST always use a unique witness on the network, we may
|
||||
// use this cache to detect duplicate witnesses.
|
||||
//
|
||||
// TODO(roasbeef): need expiry policy?
|
||||
// * encrypt?
|
||||
type WitnessCache struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewWitnessCache returns a new instance of the witness cache.
|
||||
func (d *DB) NewWitnessCache() *WitnessCache {
|
||||
return &WitnessCache{
|
||||
db: d,
|
||||
}
|
||||
}
|
||||
|
||||
// witnessEntry is a key-value struct that holds each key -> witness pair, used
|
||||
// when inserting records into the cache.
|
||||
type witnessEntry struct {
|
||||
key []byte
|
||||
witness []byte
|
||||
}
|
||||
|
||||
// AddSha256Witnesses adds a batch of new sha256 preimages into the witness
|
||||
// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the
|
||||
// preimages' witness type.
|
||||
func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error {
|
||||
// Optimistically compute the preimages' hashes before attempting to
|
||||
// start the db transaction.
|
||||
entries := make([]witnessEntry, 0, len(preimages))
|
||||
for i := range preimages {
|
||||
hash := preimages[i].Hash()
|
||||
entries = append(entries, witnessEntry{
|
||||
key: hash[:],
|
||||
witness: preimages[i][:],
|
||||
})
|
||||
}
|
||||
|
||||
return w.addWitnessEntries(Sha256HashWitness, entries)
|
||||
}
|
||||
|
||||
// addWitnessEntries inserts the witnessEntry key-value pairs into the cache,
|
||||
// using the appropriate witness type to segment the namespace of possible
|
||||
// witness types.
|
||||
func (w *WitnessCache) addWitnessEntries(wType WitnessType,
|
||||
entries []witnessEntry) error {
|
||||
|
||||
// Exit early if there are no witnesses to add.
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
witnessTypeBucketKey, err := wType.toDBKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
|
||||
witnessTypeBucketKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
err = witnessTypeBucket.Put(entry.key, entry.witness)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If
|
||||
// the witness isn't found, ErrNoWitnesses will be returned.
|
||||
func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) {
|
||||
witness, err := w.lookupWitness(Sha256HashWitness, hash[:])
|
||||
if err != nil {
|
||||
return lntypes.Preimage{}, err
|
||||
}
|
||||
|
||||
return lntypes.MakePreimage(witness)
|
||||
}
|
||||
|
||||
// lookupWitness attempts to lookup a witness according to its type and also
|
||||
// its witness key. In the case that the witness isn't found, ErrNoWitnesses
|
||||
// will be returned.
|
||||
func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) {
|
||||
var witness []byte
|
||||
err := w.db.View(func(tx *bbolt.Tx) error {
|
||||
witnessBucket := tx.Bucket(witnessBucketKey)
|
||||
if witnessBucket == nil {
|
||||
return ErrNoWitnesses
|
||||
}
|
||||
|
||||
witnessTypeBucketKey, err := wType.toDBKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey)
|
||||
if witnessTypeBucket == nil {
|
||||
return ErrNoWitnesses
|
||||
}
|
||||
|
||||
dbWitness := witnessTypeBucket.Get(witnessKey)
|
||||
if dbWitness == nil {
|
||||
return ErrNoWitnesses
|
||||
}
|
||||
|
||||
witness = make([]byte, len(dbWitness))
|
||||
copy(witness[:], dbWitness)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return witness, nil
|
||||
}
|
||||
|
||||
// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash.
|
||||
func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error {
|
||||
return w.deleteWitness(Sha256HashWitness, hash[:])
|
||||
}
|
||||
|
||||
// deleteWitness attempts to delete a particular witness from the database.
|
||||
func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error {
|
||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
witnessTypeBucketKey, err := wType.toDBKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
|
||||
witnessTypeBucketKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return witnessTypeBucket.Delete(witnessKey)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After
|
||||
// this function return with a non-nil error,
|
||||
func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error {
|
||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
witnessTypeBucketKey, err := wType.toDBKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return witnessBucket.DeleteBucket(witnessTypeBucketKey)
|
||||
})
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
package migration_01_to_11
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
)
|
||||
|
||||
// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new
|
||||
// sha256 preimages to the witness cache.
|
||||
func TestWitnessCacheSha256Retrieval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
wCache := cdb.NewWitnessCache()
|
||||
|
||||
// We'll be attempting to add then lookup two simple sha256 preimages
|
||||
// within this test.
|
||||
preimage1 := lntypes.Preimage(rev)
|
||||
preimage2 := lntypes.Preimage(key)
|
||||
|
||||
preimages := []lntypes.Preimage{preimage1, preimage2}
|
||||
hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()}
|
||||
|
||||
// First, we'll attempt to add the preimages to the database.
|
||||
err = wCache.AddSha256Witnesses(preimages...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add witness: %v", err)
|
||||
}
|
||||
|
||||
// With the preimages stored, we'll now attempt to look them up.
|
||||
for i, hash := range hashes {
|
||||
preimage := preimages[i]
|
||||
|
||||
// We should get back the *exact* same preimage as we originally
|
||||
// stored.
|
||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to look up witness: %v", err)
|
||||
}
|
||||
|
||||
if preimage != dbPreimage {
|
||||
t.Fatalf("witnesses don't match: expected %x, got %x",
|
||||
preimage[:], dbPreimage[:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWitnessCacheSha256Deletion tests that we're able to delete a single
|
||||
// sha256 preimage, and also a class of witnesses from the cache.
|
||||
func TestWitnessCacheSha256Deletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
wCache := cdb.NewWitnessCache()
|
||||
|
||||
// We'll start by adding two preimages to the cache.
|
||||
preimage1 := lntypes.Preimage(key)
|
||||
hash1 := preimage1.Hash()
|
||||
|
||||
preimage2 := lntypes.Preimage(rev)
|
||||
hash2 := preimage2.Hash()
|
||||
|
||||
if err := wCache.AddSha256Witnesses(preimage1); err != nil {
|
||||
t.Fatalf("unable to add witness: %v", err)
|
||||
}
|
||||
|
||||
if err := wCache.AddSha256Witnesses(preimage2); err != nil {
|
||||
t.Fatalf("unable to add witness: %v", err)
|
||||
}
|
||||
|
||||
// We'll now delete the first preimage. If we attempt to look it up, we
|
||||
// should get ErrNoWitnesses.
|
||||
err = wCache.DeleteSha256Witness(hash1)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to delete witness: %v", err)
|
||||
}
|
||||
_, err = wCache.LookupSha256Witness(hash1)
|
||||
if err != ErrNoWitnesses {
|
||||
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
|
||||
}
|
||||
|
||||
// Next, we'll attempt to delete the entire witness class itself. When
|
||||
// we try to lookup the second preimage, we should again get
|
||||
// ErrNoWitnesses.
|
||||
if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil {
|
||||
t.Fatalf("unable to delete witness class: %v", err)
|
||||
}
|
||||
_, err = wCache.LookupSha256Witness(hash2)
|
||||
if err != ErrNoWitnesses {
|
||||
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to
|
||||
// query/add/delete an unknown witness.
|
||||
func TestWitnessCacheUnknownWitness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
wCache := cdb.NewWitnessCache()
|
||||
|
||||
// We'll attempt to add a new, undefined witness type to the database.
|
||||
// We should get an error.
|
||||
err = wCache.legacyAddWitnesses(234, key[:])
|
||||
if err != ErrUnknownWitnessType {
|
||||
t.Fatalf("expected ErrUnknownWitnessType, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves
|
||||
// identically to the insertion via the generalized interface.
|
||||
func TestAddSha256Witnesses(t *testing.T) {
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
wCache := cdb.NewWitnessCache()
|
||||
|
||||
// We'll start by adding a witnesses to the cache using the generic
|
||||
// AddWitnesses method.
|
||||
witness1 := rev[:]
|
||||
preimage1 := lntypes.Preimage(rev)
|
||||
hash1 := preimage1.Hash()
|
||||
|
||||
witness2 := key[:]
|
||||
preimage2 := lntypes.Preimage(key)
|
||||
hash2 := preimage2.Hash()
|
||||
|
||||
var (
|
||||
witnesses = [][]byte{witness1, witness2}
|
||||
preimages = []lntypes.Preimage{preimage1, preimage2}
|
||||
hashes = []lntypes.Hash{hash1, hash2}
|
||||
)
|
||||
|
||||
err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add witness: %v", err)
|
||||
}
|
||||
|
||||
for i, hash := range hashes {
|
||||
preimage := preimages[i]
|
||||
|
||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to lookup witness: %v", err)
|
||||
}
|
||||
|
||||
// Assert that the retrieved witness matches the original.
|
||||
if dbPreimage != preimage {
|
||||
t.Fatalf("retrieved witness mismatch, want: %x, "+
|
||||
"got: %x", preimage, dbPreimage)
|
||||
}
|
||||
|
||||
// We'll now delete the witness, as we'll be reinserting it
|
||||
// using the specialized AddSha256Witnesses method.
|
||||
err = wCache.DeleteSha256Witness(hash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to delete witness: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now, add the same witnesses using the type-safe interface for
|
||||
// lntypes.Preimages..
|
||||
err = wCache.AddSha256Witnesses(preimages...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to add sha256 preimage: %v", err)
|
||||
}
|
||||
|
||||
// Finally, iterate over the keys and assert that the returned witnesses
|
||||
// match the original witnesses. This asserts that the specialized
|
||||
// insertion method behaves identically to the generalized interface.
|
||||
for i, hash := range hashes {
|
||||
preimage := preimages[i]
|
||||
|
||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to lookup witness: %v", err)
|
||||
}
|
||||
|
||||
// Assert that the retrieved witness matches the original.
|
||||
if dbPreimage != preimage {
|
||||
t.Fatalf("retrieved witness mismatch, want: %x, "+
|
||||
"got: %x", preimage, dbPreimage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// legacyAddWitnesses adds a batch of new witnesses of wType to the witness
|
||||
// cache. The type of the witness will be used to map each witness to the key
|
||||
// that will be used to look it up. All witnesses should be of the same
|
||||
// WitnessType.
|
||||
//
|
||||
// NOTE: Previously this method exposed a generic interface for adding
|
||||
// witnesses, which has since been deprecated in favor of a strongly typed
|
||||
// interface for each witness class. We keep this method around to assert the
|
||||
// correctness of specialized witness adding methods.
|
||||
func (w *WitnessCache) legacyAddWitnesses(wType WitnessType,
|
||||
witnesses ...[]byte) error {
|
||||
|
||||
// Optimistically compute the witness keys before attempting to start
|
||||
// the db transaction.
|
||||
entries := make([]witnessEntry, 0, len(witnesses))
|
||||
for _, witness := range witnesses {
|
||||
// Map each witness to its key by applying the appropriate
|
||||
// transformation for the given witness type.
|
||||
switch wType {
|
||||
case Sha256HashWitness:
|
||||
key := sha256.Sum256(witness)
|
||||
entries = append(entries, witnessEntry{
|
||||
key: key[:],
|
||||
witness: witness,
|
||||
})
|
||||
default:
|
||||
return ErrUnknownWitnessType
|
||||
}
|
||||
}
|
||||
|
||||
return w.addWitnessEntries(wType, entries)
|
||||
}
|
Loading…
Add table
Reference in a new issue