mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 13:27:56 +01:00
htlcswitch: introduce resolutionStore to persist cnct messages
This commit is contained in:
parent
9bbee09497
commit
bfed7a088f
202
htlcswitch/resolution_store.go
Normal file
202
htlcswitch/resolution_store.go
Normal file
@ -0,0 +1,202 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// resBucketKey is used for the root level bucket that stores the
|
||||
// CircuitKey -> ResolutionMsg mapping.
|
||||
resBucketKey = []byte("resolution-store-bucket-key")
|
||||
|
||||
// errResMsgNotFound is used to let callers know that the resolution
|
||||
// message was not found for the given CircuitKey. This is used in the
|
||||
// checkResolutionMsg function.
|
||||
errResMsgNotFound = errors.New("resolution message not found")
|
||||
)
|
||||
|
||||
// resolutionStore contains ResolutionMsgs received from the contractcourt. The
|
||||
// Switch deletes these from the store when the underlying circuit has been
|
||||
// removed via DeleteCircuits. If the circuit hasn't been deleted, the Switch
|
||||
// will dispatch the ResolutionMsg to a link if this was a multi-hop HTLC or to
|
||||
// itself if the Switch initiated the payment.
|
||||
type resolutionStore struct {
|
||||
backend kvdb.Backend
|
||||
}
|
||||
|
||||
func newResolutionStore(db kvdb.Backend) *resolutionStore {
|
||||
return &resolutionStore{
|
||||
backend: db,
|
||||
}
|
||||
}
|
||||
|
||||
// addResolutionMsg persists a ResolutionMsg to the resolutionStore.
|
||||
func (r *resolutionStore) addResolutionMsg(
|
||||
resMsg *contractcourt.ResolutionMsg) error {
|
||||
|
||||
// The outKey will be the database key.
|
||||
outKey := &CircuitKey{
|
||||
ChanID: resMsg.SourceChan,
|
||||
HtlcID: resMsg.HtlcIndex,
|
||||
}
|
||||
|
||||
var resBuf bytes.Buffer
|
||||
if err := serializeResolutionMsg(&resBuf, resMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := kvdb.Update(r.backend, func(tx kvdb.RwTx) error {
|
||||
resBucket, err := tx.CreateTopLevelBucket(resBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return resBucket.Put(outKey.Bytes(), resBuf.Bytes())
|
||||
}, func() {})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkResolutionMsg returns nil if the resolution message is found in the
|
||||
// store. It returns an error if no resolution message was found for the
|
||||
// passed outKey or if a database error occurred.
|
||||
func (r *resolutionStore) checkResolutionMsg(outKey *CircuitKey) error {
|
||||
err := kvdb.View(r.backend, func(tx kvdb.RTx) error {
|
||||
resBucket := tx.ReadBucket(resBucketKey)
|
||||
if resBucket == nil {
|
||||
// Return an error if the bucket doesn't exist.
|
||||
return errResMsgNotFound
|
||||
}
|
||||
|
||||
msg := resBucket.Get(outKey.Bytes())
|
||||
if msg == nil {
|
||||
// Return the not found error since no message exists
|
||||
// for this CircuitKey.
|
||||
return errResMsgNotFound
|
||||
}
|
||||
|
||||
// Return nil to indicate that the message was found.
|
||||
return nil
|
||||
}, func() {})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchAllResolutionMsg returns a slice of all stored ResolutionMsgs. This is
|
||||
// used by the Switch on start-up.
|
||||
func (r *resolutionStore) fetchAllResolutionMsg() (
|
||||
[]*contractcourt.ResolutionMsg, error) {
|
||||
|
||||
var msgs []*contractcourt.ResolutionMsg
|
||||
|
||||
err := kvdb.View(r.backend, func(tx kvdb.RTx) error {
|
||||
resBucket := tx.ReadBucket(resBucketKey)
|
||||
if resBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return resBucket.ForEach(func(k, v []byte) error {
|
||||
kr := bytes.NewReader(k)
|
||||
outKey := &CircuitKey{}
|
||||
if err := outKey.Decode(kr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vr := bytes.NewReader(v)
|
||||
resMsg, err := deserializeResolutionMsg(vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the CircuitKey values on the ResolutionMsg.
|
||||
resMsg.SourceChan = outKey.ChanID
|
||||
resMsg.HtlcIndex = outKey.HtlcID
|
||||
|
||||
msgs = append(msgs, resMsg)
|
||||
return nil
|
||||
})
|
||||
}, func() {
|
||||
msgs = nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// deleteResolutionMsg removes a ResolutionMsg with the passed-in CircuitKey.
|
||||
func (r *resolutionStore) deleteResolutionMsg(outKey *CircuitKey) error {
|
||||
err := kvdb.Update(r.backend, func(tx kvdb.RwTx) error {
|
||||
resBucket, err := tx.CreateTopLevelBucket(resBucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return resBucket.Delete(outKey.Bytes())
|
||||
}, func() {})
|
||||
return err
|
||||
}
|
||||
|
||||
// serializeResolutionMsg writes part of a ResolutionMsg to the passed
|
||||
// io.Writer.
|
||||
func serializeResolutionMsg(w io.Writer,
|
||||
resMsg *contractcourt.ResolutionMsg) error {
|
||||
|
||||
isFail := resMsg.Failure != nil
|
||||
|
||||
if err := channeldb.WriteElement(w, isFail); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this is a failure message, then we're done serializing.
|
||||
if isFail {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Else this is a settle message, and we need to write the preimage.
|
||||
return channeldb.WriteElement(w, *resMsg.PreImage)
|
||||
}
|
||||
|
||||
// deserializeResolutionMsg reads part of a ResolutionMsg from the passed
|
||||
// io.Reader.
|
||||
func deserializeResolutionMsg(r io.Reader) (*contractcourt.ResolutionMsg,
|
||||
error) {
|
||||
|
||||
resMsg := &contractcourt.ResolutionMsg{}
|
||||
var isFail bool
|
||||
|
||||
if err := channeldb.ReadElements(r, &isFail); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If a failure resolution msg was stored, set the Failure field.
|
||||
if isFail {
|
||||
failureMsg := &lnwire.FailPermanentChannelFailure{}
|
||||
resMsg.Failure = failureMsg
|
||||
return resMsg, nil
|
||||
}
|
||||
|
||||
var preimage [32]byte
|
||||
resMsg.PreImage = &preimage
|
||||
|
||||
// Else this is a settle resolution msg and we will read the preimage.
|
||||
if err := channeldb.ReadElement(r, resMsg.PreImage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resMsg, nil
|
||||
}
|
154
htlcswitch/resolution_store_test.go
Normal file
154
htlcswitch/resolution_store_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestInsertAndDelete tests that an inserted resolution message can be
|
||||
// deleted.
|
||||
func TestInsertAndDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scid := lnwire.NewShortChanIDFromInt(1)
|
||||
|
||||
failResMsg := &contractcourt.ResolutionMsg{
|
||||
SourceChan: scid,
|
||||
HtlcIndex: 2,
|
||||
Failure: &lnwire.FailTemporaryChannelFailure{},
|
||||
}
|
||||
|
||||
settleBytes := [32]byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
||||
}
|
||||
|
||||
settleResMsg := &contractcourt.ResolutionMsg{
|
||||
SourceChan: scid,
|
||||
HtlcIndex: 3,
|
||||
PreImage: &settleBytes,
|
||||
}
|
||||
|
||||
// Create the backend database and use it to create the resolution
|
||||
// store.
|
||||
dbDir, err := ioutil.TempDir("", "resolutionStore")
|
||||
require.NoError(t, err)
|
||||
|
||||
dbPath := filepath.Join(dbDir, "testdb")
|
||||
db, err := kvdb.Create(
|
||||
kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanUp := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(dbDir)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
resStore := newResolutionStore(db)
|
||||
|
||||
// We'll add the failure resolution message first, then check that it
|
||||
// exists in the store.
|
||||
err = resStore.addResolutionMsg(failResMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that checkResolutionMsg returns nil, signalling that the
|
||||
// resolution message was properly stored.
|
||||
outKey := &CircuitKey{
|
||||
ChanID: failResMsg.SourceChan,
|
||||
HtlcID: failResMsg.HtlcIndex,
|
||||
}
|
||||
err = resStore.checkResolutionMsg(outKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
resMsgs, err := resStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(resMsgs))
|
||||
|
||||
// It should match failResMsg above.
|
||||
require.Equal(t, failResMsg.SourceChan, resMsgs[0].SourceChan)
|
||||
require.Equal(t, failResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
|
||||
require.NotNil(t, resMsgs[0].Failure)
|
||||
require.Nil(t, resMsgs[0].PreImage)
|
||||
|
||||
// We'll add the settleResMsg now.
|
||||
err = resStore.addResolutionMsg(settleResMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that checkResolutionMsg returns nil for the settle CircuitKey.
|
||||
outKey.ChanID = settleResMsg.SourceChan
|
||||
outKey.HtlcID = settleResMsg.HtlcIndex
|
||||
err = resStore.checkResolutionMsg(outKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should have two resolution messages in the store, one failure and
|
||||
// one success.
|
||||
resMsgs, err = resStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(resMsgs))
|
||||
|
||||
// The first resolution message should be the failure.
|
||||
require.Equal(t, failResMsg.SourceChan, resMsgs[0].SourceChan)
|
||||
require.Equal(t, failResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
|
||||
require.NotNil(t, resMsgs[0].Failure)
|
||||
require.Nil(t, resMsgs[0].PreImage)
|
||||
|
||||
// The second resolution message should be the success.
|
||||
require.Equal(t, settleResMsg.SourceChan, resMsgs[1].SourceChan)
|
||||
require.Equal(t, settleResMsg.HtlcIndex, resMsgs[1].HtlcIndex)
|
||||
require.Nil(t, resMsgs[1].Failure)
|
||||
require.Equal(t, settleBytes, *resMsgs[1].PreImage)
|
||||
|
||||
// We'll now delete the failure resolution message and assert that only
|
||||
// the success is left.
|
||||
failKey := &CircuitKey{
|
||||
ChanID: scid,
|
||||
HtlcID: failResMsg.HtlcIndex,
|
||||
}
|
||||
|
||||
err = resStore.deleteResolutionMsg(failKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that checkResolutionMsg returns errResMsgNotFound.
|
||||
err = resStore.checkResolutionMsg(failKey)
|
||||
require.ErrorIs(t, err, errResMsgNotFound)
|
||||
|
||||
resMsgs, err = resStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(resMsgs))
|
||||
|
||||
// Assert that the success is left.
|
||||
require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
|
||||
require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
|
||||
require.Nil(t, resMsgs[0].Failure)
|
||||
require.Equal(t, settleBytes, *resMsgs[0].PreImage)
|
||||
|
||||
// Now we'll delete the settle resolution message and assert that the
|
||||
// store is empty.
|
||||
settleKey := &CircuitKey{
|
||||
ChanID: scid,
|
||||
HtlcID: settleResMsg.HtlcIndex,
|
||||
}
|
||||
|
||||
err = resStore.deleteResolutionMsg(settleKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that checkResolutionMsg returns errResMsgNotFound for the
|
||||
// settle key.
|
||||
err = resStore.checkResolutionMsg(settleKey)
|
||||
require.ErrorIs(t, err, errResMsgNotFound)
|
||||
|
||||
resMsgs, err = resStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(resMsgs))
|
||||
}
|
Loading…
Reference in New Issue
Block a user