contractcourt: fix concurrent access to resolved

This commit makes `resolved` an atomic bool to avoid data race. This
field is now defined in `contractResolverKit` to avoid code duplication.
This commit is contained in:
yyforyongyu 2024-07-10 18:08:23 +08:00
parent 47722292c5
commit 4f5ccb8650
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
11 changed files with 111 additions and 129 deletions

View File

@ -24,9 +24,6 @@ type anchorResolver struct {
// anchor is the outpoint on the commitment transaction.
anchor wire.OutPoint
// resolved reflects if the contract has been fully resolved or not.
resolved bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@ -89,7 +86,7 @@ func (c *anchorResolver) ResolverKey() []byte {
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if c.resolved {
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil, nil
}
@ -139,7 +136,7 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) {
)
c.reportLock.Unlock()
c.resolved = true
c.markResolved()
return nil, c.PutResolverReport(nil, report)
}
@ -154,14 +151,6 @@ func (c *anchorResolver) Stop() {
close(c.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) IsResolved() bool {
return c.resolved
}
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
@ -198,7 +187,7 @@ func (c *anchorResolver) Launch() error {
c.launched = true
// If we're already resolved, then we can exit early.
if c.resolved {
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil
}

View File

@ -12,9 +12,6 @@ import (
// future, this will likely take over the duties the current BreachArbitrator
// has.
type breachResolver struct {
// resolved reflects if the contract has been fully resolved or not.
resolved bool
// subscribed denotes whether or not the breach resolver has subscribed
// to the BreachArbitrator for breach resolution.
subscribed bool
@ -62,7 +59,7 @@ func (b *breachResolver) Resolve() (ContractResolver, error) {
// If the breach resolution process is already complete, then
// we can cleanup and checkpoint the resolved state.
if complete {
b.resolved = true
b.markResolved()
return nil, b.Checkpoint(b)
}
@ -75,8 +72,9 @@ func (b *breachResolver) Resolve() (ContractResolver, error) {
// The replyChan has been closed, signalling that the breach
// has been fully resolved. Checkpoint the resolved state and
// exit.
b.resolved = true
b.markResolved()
return nil, b.Checkpoint(b)
case <-b.quit:
}
@ -89,19 +87,13 @@ func (b *breachResolver) Stop() {
close(b.quit)
}
// IsResolved returns true if the breachResolver is fully resolved and cleanup
// can occur.
func (b *breachResolver) IsResolved() bool {
return b.resolved
}
// SupplementState adds additional state to the breachResolver.
func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) {
}
// Encode encodes the breachResolver to the passed writer.
func (b *breachResolver) Encode(w io.Writer) error {
return binary.Write(w, endian, b.resolved)
return binary.Write(w, endian, b.IsResolved())
}
// newBreachResolverFromReader attempts to decode an encoded breachResolver
@ -114,9 +106,13 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) (
replyChan: make(chan struct{}),
}
if err := binary.Read(r, endian, &b.resolved); err != nil {
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
b.markResolved()
}
b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint))

View File

@ -206,8 +206,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.outputIncubating, diskRes.outputIncubating)
}
if ogRes.resolved != diskRes.resolved {
t.Fatalf("expected %v, got %v", ogRes.resolved,
diskRes.resolved)
t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@ -229,8 +229,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.outputIncubating, diskRes.outputIncubating)
}
if ogRes.resolved != diskRes.resolved {
t.Fatalf("expected %v, got %v", ogRes.resolved,
diskRes.resolved)
t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@ -275,8 +275,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.commitResolution, diskRes.commitResolution)
}
if ogRes.resolved != diskRes.resolved {
t.Fatalf("expected %v, got %v", ogRes.resolved,
diskRes.resolved)
t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@ -312,13 +312,14 @@ func TestContractInsertionRetrieval(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
resolved: true,
broadcastHeight: 102,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
successResolver := htlcSuccessResolver{
timeoutResolver.resolved.Store(true)
successResolver := &htlcSuccessResolver{
htlcResolution: lnwallet.IncomingHtlcResolution{
Preimage: testPreimage,
SignedSuccessTx: nil,
@ -327,40 +328,49 @@ func TestContractInsertionRetrieval(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
resolved: true,
broadcastHeight: 109,
htlc: channeldb.HTLC{
RHash: testPreimage,
},
}
resolvers := []ContractResolver{
&timeoutResolver,
&successResolver,
&commitSweepResolver{
commitResolution: lnwallet.CommitOutputResolution{
SelfOutPoint: testChanPoint2,
SelfOutputSignDesc: testSignDesc,
MaturityDelay: 99,
},
resolved: false,
broadcastHeight: 109,
chanPoint: testChanPoint1,
successResolver.resolved.Store(true)
commitResolver := &commitSweepResolver{
commitResolution: lnwallet.CommitOutputResolution{
SelfOutPoint: testChanPoint2,
SelfOutputSignDesc: testSignDesc,
MaturityDelay: 99,
},
broadcastHeight: 109,
chanPoint: testChanPoint1,
}
commitResolver.resolved.Store(false)
resolvers := []ContractResolver{
&timeoutResolver, successResolver, commitResolver,
}
// All resolvers require a unique ResolverKey() output. To achieve this
// for the composite resolvers, we'll mutate the underlying resolver
// with a new outpoint.
contestTimeout := timeoutResolver
contestTimeout.htlcResolution.ClaimOutpoint = randOutPoint()
contestTimeout := htlcTimeoutResolver{
htlcResolution: lnwallet.OutgoingHtlcResolution{
ClaimOutpoint: randOutPoint(),
SweepSignDesc: testSignDesc,
},
}
resolvers = append(resolvers, &htlcOutgoingContestResolver{
htlcTimeoutResolver: &contestTimeout,
})
contestSuccess := successResolver
contestSuccess.htlcResolution.ClaimOutpoint = randOutPoint()
contestSuccess := &htlcSuccessResolver{
htlcResolution: lnwallet.IncomingHtlcResolution{
ClaimOutpoint: randOutPoint(),
SweepSignDesc: testSignDesc,
},
}
resolvers = append(resolvers, &htlcIncomingContestResolver{
htlcExpiry: 100,
htlcSuccessResolver: &contestSuccess,
htlcSuccessResolver: contestSuccess,
})
// For quick lookup during the test, we'll create this map which allow
@ -438,12 +448,12 @@ func TestContractResolution(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
resolved: true,
broadcastHeight: 192,
htlc: channeldb.HTLC{
HtlcIndex: 9912,
},
}
timeoutResolver.resolved.Store(true)
// First, we'll insert the resolver into the database and ensure that
// we get the same resolver out the other side. We do not need to apply
@ -491,12 +501,13 @@ func TestContractSwapping(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
resolved: true,
broadcastHeight: 102,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
timeoutResolver.resolved.Store(true)
contestResolver := &htlcOutgoingContestResolver{
htlcTimeoutResolver: timeoutResolver,
}

View File

@ -39,9 +39,6 @@ type commitSweepResolver struct {
// this HTLC on-chain.
commitResolution lnwallet.CommitOutputResolution
// resolved reflects if the contract has been fully resolved or not.
resolved bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@ -171,7 +168,7 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) {
//nolint:funlen
func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if c.resolved {
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil, nil
}
@ -224,7 +221,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
report := c.currentReport.resolverReport(
&sweepTxID, channeldb.ResolverTypeCommit, outcome,
)
c.resolved = true
c.markResolved()
// Checkpoint the resolver with a closure that will write the outcome
// of the resolver and its sweep transaction to disk.
@ -241,14 +238,6 @@ func (c *commitSweepResolver) Stop() {
close(c.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) IsResolved() bool {
return c.resolved
}
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
@ -277,7 +266,7 @@ func (c *commitSweepResolver) Encode(w io.Writer) error {
return err
}
if err := binary.Write(w, endian, c.resolved); err != nil {
if err := binary.Write(w, endian, c.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, c.broadcastHeight); err != nil {
@ -312,9 +301,14 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) (
return nil, err
}
if err := binary.Read(r, endian, &c.resolved); err != nil {
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
c.markResolved()
}
if err := binary.Read(r, endian, &c.broadcastHeight); err != nil {
return nil, err
}
@ -383,7 +377,7 @@ func (c *commitSweepResolver) Launch() error {
c.launched = true
// If we're already resolved, then we can exit early.
if c.resolved {
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"sync/atomic"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2"
@ -119,6 +120,9 @@ type contractResolverKit struct {
// launched specifies whether the resolver has been launched. Calling
// `Launch` will be a no-op if this is true.
launched bool
// resolved reflects if the contract has been fully resolved or not.
resolved atomic.Bool
}
// newContractResolverKit instantiates the mix-in struct.
@ -137,6 +141,19 @@ func (r *contractResolverKit) initLogger(prefix string) {
r.log = log.WithPrefix(logPrefix)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (r *contractResolverKit) IsResolved() bool {
return r.resolved.Load()
}
// markResolved marks the resolver as resolved.
func (r *contractResolverKit) markResolved() {
r.resolved.Store(true)
}
var (
// errResolverShuttingDown is returned when the resolver stops
// progressing because it received the quit signal.

View File

@ -124,7 +124,7 @@ func (h *htlcIncomingContestResolver) Launch() error {
func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// If we're already full resolved, then we don't have anything further
// to do.
if h.resolved {
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
@ -140,7 +140,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// will time it out and get their funds back. This situation
// can present itself when we crash before processRemoteAdds in
// the link has ran.
h.resolved = true
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@ -193,7 +193,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
log.Infof("%T(%v): HTLC has timed out (expiry=%v, height=%v), "+
"abandoning", h, h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
h.resolved = true
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@ -234,7 +234,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
h.resolved = true
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@ -396,7 +396,8 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
"(expiry=%v, height=%v), abandoning", h,
h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
h.resolved = true
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@ -517,14 +518,6 @@ func (h *htlcIncomingContestResolver) Stop() {
close(h.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) IsResolved() bool {
return h.resolved
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//

View File

@ -82,7 +82,7 @@ func (h *htlcOutgoingContestResolver) Launch() error {
func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) {
// If we're already full resolved, then we don't have anything further
// to do.
if h.resolved {
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
@ -215,14 +215,6 @@ func (h *htlcOutgoingContestResolver) Stop() {
close(h.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcOutgoingContestResolver) IsResolved() bool {
return h.resolved
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//

View File

@ -42,9 +42,6 @@ type htlcSuccessResolver struct {
// second-level output (true).
outputIncubating bool
// resolved reflects if the contract has been fully resolved or not.
resolved bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@ -122,7 +119,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
switch {
// If we're already resolved, then we can exit early.
case h.resolved:
case h.IsResolved():
h.log.Errorf("already resolved")
// If this is an output on the remote party's commitment transaction,
@ -226,7 +223,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error {
}
// Finally, we checkpoint the resolver with our report(s).
h.resolved = true
h.markResolved()
return h.Checkpoint(h, reports...)
}
@ -241,14 +238,6 @@ func (h *htlcSuccessResolver) Stop() {
close(h.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) IsResolved() bool {
return h.resolved
}
// report returns a report on the resolution state of the contract.
func (h *htlcSuccessResolver) report() *ContractReport {
// If the sign details are nil, the report will be created by handled
@ -298,7 +287,7 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error {
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
if err := binary.Write(w, endian, h.resolved); err != nil {
if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
@ -337,9 +326,15 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) (
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
if err := binary.Read(r, endian, &h.resolved); err != nil {
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
h.markResolved()
}
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
@ -745,7 +740,7 @@ func (h *htlcSuccessResolver) Launch() error {
switch {
// If we're already resolved, then we can exit early.
case h.resolved:
case h.IsResolved():
h.log.Errorf("already resolved")
return nil

View File

@ -616,11 +616,11 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
var resolved, incubating bool
if h, ok := resolver.(*htlcSuccessResolver); ok {
resolved = h.resolved
resolved = h.resolved.Load()
incubating = h.outputIncubating
}
if h, ok := resolver.(*htlcTimeoutResolver); ok {
resolved = h.resolved
resolved = h.resolved.Load()
incubating = h.outputIncubating
}

View File

@ -38,9 +38,6 @@ type htlcTimeoutResolver struct {
// incubator (utxo nursery).
outputIncubating bool
// resolved reflects if the contract has been fully resolved or not.
resolved bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@ -238,7 +235,7 @@ func (h *htlcTimeoutResolver) claimCleanUp(
}); err != nil {
return err
}
h.resolved = true
h.markResolved()
// Checkpoint our resolver with a report which reflects the preimage
// claim by the remote party.
@ -424,7 +421,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool {
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if h.resolved {
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
@ -622,14 +619,6 @@ func (h *htlcTimeoutResolver) Stop() {
close(h.quit)
}
// IsResolved returns true if the stored state in the resolve is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) IsResolved() bool {
return h.resolved
}
// report returns a report on the resolution state of the contract.
func (h *htlcTimeoutResolver) report() *ContractReport {
// If we have a SignedTimeoutTx but no SignDetails, this is a local
@ -689,7 +678,7 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error {
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
if err := binary.Write(w, endian, h.resolved); err != nil {
if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
@ -730,9 +719,15 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) (
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
if err := binary.Read(r, endian, &h.resolved); err != nil {
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
h.markResolved()
}
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
@ -1149,7 +1144,7 @@ func (h *htlcTimeoutResolver) checkpointClaim(
}
// Finally, we checkpoint the resolver with our report(s).
h.resolved = true
h.markResolved()
return h.Checkpoint(h, report)
}
@ -1285,7 +1280,7 @@ func (h *htlcTimeoutResolver) Launch() error {
switch {
// If we're already resolved, then we can exit early.
case h.resolved:
case h.IsResolved():
h.log.Errorf("already resolved")
return nil

View File

@ -532,7 +532,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
wg.Wait()
// Finally, the resolver should be marked as resolved.
if !resolver.resolved {
if !resolver.resolved.Load() {
t.Fatalf("resolver should be marked as resolved")
}
}