channeldb+lnwallet: store revocation log using the new format

This commit removes the usage of the old revocation log bucket and
starts to perform db operations using the new sub-bucket.
This commit is contained in:
yyforyongyu 2022-04-08 07:36:26 +08:00
parent df810114cf
commit 37b11c4503
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
4 changed files with 114 additions and 48 deletions

View file

@ -2306,7 +2306,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
// set of local updates that the peer still needs to send us a signature for. // set of local updates that the peer still needs to send us a signature for.
// We store this set of updates in case we go down. // We store this set of updates in case we go down.
func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
updates []LogUpdate) error { updates []LogUpdate, ourOutputIndex, theirOutputIndex uint32) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -2352,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
// TODO(roasbeef): could make the deltas relative, would save // TODO(roasbeef): could make the deltas relative, would save
// space, but then tradeoff for more disk-seeks to recover the // space, but then tradeoff for more disk-seeks to recover the
// full state. // full state.
logKey := revocationLogBucketDeprecated logKey := revocationLogBucket
logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) logBucket, err := chanBucket.CreateBucketIfNotExists(logKey)
if err != nil { if err != nil {
return err return err
@ -2379,9 +2379,10 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
// With the commitment pointer swapped, we can now add the // With the commitment pointer swapped, we can now add the
// revoked (prior) state to the revocation log. // revoked (prior) state to the revocation log.
// err = putRevocationLog(
// TODO(roasbeef): store less logBucket, &c.RemoteCommitment,
err = appendChannelLogEntry(logBucket, &c.RemoteCommitment) ourOutputIndex, theirOutputIndex,
)
if err != nil { if err != nil {
return err return err
} }
@ -2591,9 +2592,9 @@ func (c *OpenChannel) revocationLogTailCommitHeight() (uint64, error) {
return err return err
} }
logBucket := chanBucket.NestedReadBucket(revocationLogBucketDeprecated) logBucket, err := fetchLogBucket(chanBucket)
if logBucket == nil { if err != nil {
return ErrNoPastDeltas return err
} }
// Once we have the bucket that stores the revocation log from // Once we have the bucket that stores the revocation log from
@ -2654,11 +2655,15 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
// intended to be used for obtaining the relevant data needed to claim all // intended to be used for obtaining the relevant data needed to claim all
// funds rightfully spendable in the case of an on-chain broadcast of the // funds rightfully spendable in the case of an on-chain broadcast of the
// commitment transaction. // commitment transaction.
func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) { func (c *OpenChannel) FindPreviousState(
updateNum uint64) (*RevocationLog, *ChannelCommitment, error) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
var commit ChannelCommitment commit := &ChannelCommitment{}
rl := &RevocationLog{}
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
@ -2667,24 +2672,24 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
return err return err
} }
logBucket := chanBucket.NestedReadBucket(revocationLogBucketDeprecated) // Find the revocation log from both the new and the old
if logBucket == nil { // bucket.
return ErrNoPastDeltas r, c, err := fetchRevocationLogCompatible(chanBucket, updateNum)
}
c, err := fetchOldRevocationLog(logBucket, updateNum)
if err != nil { if err != nil {
return err return err
} }
rl = r
commit = c commit = c
return nil return nil
}, func() {}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return &commit, nil // Either the `rl` or the `commit` is nil here. We return them as-is
// and leave it to the caller to decide its following action.
return rl, commit, nil
} }
// ClosureType is an enum like structure that details exactly _how_ a channel // ClosureType is an enum like structure that details exactly _how_ a channel
@ -2881,12 +2886,8 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
// With the base channel data deleted, attempt to delete the // With the base channel data deleted, attempt to delete the
// information stored within the revocation log. // information stored within the revocation log.
logBucket := chanBucket.NestedReadWriteBucket(revocationLogBucketDeprecated) if err := deleteLogBucket(chanBucket); err != nil {
if logBucket != nil { return err
err = chanBucket.DeleteNestedBucket(revocationLogBucketDeprecated)
if err != nil {
return err
}
} }
err = chainBucket.DeleteNestedBucket(chanPointBuf.Bytes()) err = chainBucket.DeleteNestedBucket(chanPointBuf.Bytes())
@ -3643,19 +3644,6 @@ func makeLogKey(updateNum uint64) [8]byte {
return key return key
} }
// TODO: delete
func appendChannelLogEntry(log kvdb.RwBucket,
commit *ChannelCommitment) error {
var b bytes.Buffer
if err := serializeChanCommit(&b, commit); err != nil {
return err
}
logEntrykey := makeLogKey(commit.CommitHeight)
return log.Put(logEntrykey[:], b.Bytes())
}
func fetchThawHeight(chanBucket kvdb.RBucket) (uint32, error) { func fetchThawHeight(chanBucket kvdb.RBucket) (uint32, error) {
var height uint32 var height uint32

View file

@ -52,8 +52,17 @@ var (
Port: 18555, Port: 18555,
} }
// keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding. // keyLocIndex is the KeyLocator Index we use for
// TestKeyLocatorEncoding.
keyLocIndex = uint32(2049) keyLocIndex = uint32(2049)
// dummyLocalOutputIndex specifics a default value for our output index
// in this test.
dummyLocalOutputIndex = uint32(0)
// dummyRemoteOutIndex specifics a default value for their output index
// in this test.
dummyRemoteOutIndex = uint32(1)
) )
// testChannelParams is a struct which details the specifics of how a channel // testChannelParams is a struct which details the specifics of how a channel
@ -548,6 +557,32 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
} }
} }
// assertRevocationLogEntryEqual asserts that, for all the fields of a given
// revocation log entry, their values match those on a given ChannelCommitment.
func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
r *RevocationLog) {
// Check the common fields.
require.EqualValues(
t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch",
)
// Now check the common fields from the HTLCs.
require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
for i, rHtlc := range r.HTLCEntries {
cHtlc := c.Htlcs[i]
require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
"Amt mismatch")
require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
"RefundTimeout mismatch")
require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
"OutputIndex mismatch")
require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
"Incoming mismatch")
}
}
func TestChannelStateTransition(t *testing.T) { func TestChannelStateTransition(t *testing.T) {
t.Parallel() t.Parallel()
@ -748,7 +783,9 @@ func TestChannelStateTransition(t *testing.T) {
fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
diskCommitDiff.LogUpdates, nil) diskCommitDiff.LogUpdates, nil)
err = channel.AdvanceCommitChainTail(fwdPkg, nil) err = channel.AdvanceCommitChainTail(
fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex,
)
if err != nil { if err != nil {
t.Fatalf("unable to append to revocation log: %v", err) t.Fatalf("unable to append to revocation log: %v", err)
} }
@ -761,16 +798,24 @@ func TestChannelStateTransition(t *testing.T) {
// We should be able to fetch the channel delta created above by its // We should be able to fetch the channel delta created above by its
// update number with all the state properly reconstructed. // update number with all the state properly reconstructed.
diskPrevCommit, err := channel.FindPreviousState( diskPrevCommit, _, err := channel.FindPreviousState(
oldRemoteCommit.CommitHeight, oldRemoteCommit.CommitHeight,
) )
if err != nil { if err != nil {
t.Fatalf("unable to fetch past delta: %v", err) t.Fatalf("unable to fetch past delta: %v", err)
} }
// Check the output indexes are saved as expected.
require.EqualValues(
t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
)
require.EqualValues(
t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
)
// The two deltas (the original vs the on-disk version) should // The two deltas (the original vs the on-disk version) should
// identical, and all HTLC data should properly be retained. // identical, and all HTLC data should properly be retained.
assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) assertRevocationLogEntryEqual(t, &oldRemoteCommit, diskPrevCommit)
// The state number recovered from the tail of the revocation log // The state number recovered from the tail of the revocation log
// should be identical to this current state. // should be identical to this current state.
@ -796,17 +841,30 @@ func TestChannelStateTransition(t *testing.T) {
fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)
err = channel.AdvanceCommitChainTail(fwdPkg, nil) err = channel.AdvanceCommitChainTail(
fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex,
)
if err != nil { if err != nil {
t.Fatalf("unable to append to revocation log: %v", err) t.Fatalf("unable to append to revocation log: %v", err)
} }
// Once again, fetch the state and ensure it has been properly updated. // Once again, fetch the state and ensure it has been properly updated.
prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) prevCommit, _, err := channel.FindPreviousState(
oldRemoteCommit.CommitHeight,
)
if err != nil { if err != nil {
t.Fatalf("unable to fetch past delta: %v", err) t.Fatalf("unable to fetch past delta: %v", err)
} }
assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)
// Check the output indexes are saved as expected.
require.EqualValues(
t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
)
require.EqualValues(
t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
)
assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit)
// Once again, state number recovered from the tail of the revocation // Once again, state number recovered from the tail of the revocation
// log should be identical to this current state. // log should be identical to this current state.
@ -860,7 +918,9 @@ func TestChannelStateTransition(t *testing.T) {
// Attempting to find previous states on the channel should fail as the // Attempting to find previous states on the channel should fail as the
// revocation log has been deleted. // revocation log has been deleted.
_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) _, _, err = updatedChannel[0].FindPreviousState(
oldRemoteCommit.CommitHeight,
)
if err == nil { if err == nil {
t.Fatal("revocation log search should have failed") t.Fatal("revocation log search should have failed")
} }

View file

@ -409,7 +409,9 @@ func TestRestoreChannelShells(t *testing.T) {
if err != ErrNoRestoredChannelMutation { if err != ErrNoRestoredChannelMutation {
t.Fatalf("able to mutate restored channel") t.Fatalf("able to mutate restored channel")
} }
err = channel.AdvanceCommitChainTail(nil, nil) err = channel.AdvanceCommitChainTail(
nil, nil, dummyLocalOutputIndex, dummyRemoteOutIndex,
)
if err != ErrNoRestoredChannelMutation { if err != ErrNoRestoredChannelMutation {
t.Fatalf("able to mutate restored channel") t.Fatalf("able to mutate restored channel")
} }

View file

@ -2282,7 +2282,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
// Query the on-disk revocation log for the snapshot which was recorded // Query the on-disk revocation log for the snapshot which was recorded
// at this particular state num. // at this particular state num.
revokedSnapshot, err := chanState.FindPreviousState(stateNum) _, revokedSnapshot, err := chanState.FindPreviousState(stateNum)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -4872,12 +4872,28 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
source, remoteChainTail, addUpdates, settleFailUpdates, source, remoteChainTail, addUpdates, settleFailUpdates,
) )
// We will soon be saving the current remote commitment to revocation
// log bucket, which is `lc.channelState.RemoteCommitment`. After that,
// the `RemoteCommitment` will be replaced with a newer version found
// in `CommitDiff`. Thus we need to compute the output indexes here
// before the change since the indexes are meant for the current,
// revoked remote commitment.
ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote(
revocation, lc.channelState,
)
if err != nil {
return nil, nil, nil, nil, err
}
// At this point, the revocation has been accepted, and we've rotated // At this point, the revocation has been accepted, and we've rotated
// the current revocation key+hash for the remote party. Therefore we // the current revocation key+hash for the remote party. Therefore we
// sync now to ensure the revocation producer state is consistent with // sync now to ensure the revocation producer state is consistent with
// the current commitment height and also to advance the on-disk // the current commitment height and also to advance the on-disk
// commitment chain. // commitment chain.
err = lc.channelState.AdvanceCommitChainTail(fwdPkg, localPeerUpdates) err = lc.channelState.AdvanceCommitChainTail(
fwdPkg, localPeerUpdates,
ourOutputIndex, theirOutputIndex,
)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }