contractcourt: store new taproot resolution info in new key

We pull the information from the sign descriptors and store them in the
resolutions. However, the resolvers created end up duplicating the
resolution data, so we update the sign descs as needed during start up.
This commit is contained in:
Olaoluwa Osuntokun 2023-03-01 22:13:27 -08:00
parent a1788fe4a2
commit 47d4eb341d
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
3 changed files with 294 additions and 17 deletions

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
@ -363,6 +362,10 @@ var (
// store the confirmed active HTLC sets once we learn that a channel // store the confirmed active HTLC sets once we learn that a channel
// has closed out on chain. // has closed out on chain.
commitSetKey = []byte("commit-set") commitSetKey = []byte("commit-set")
// taprootDataKey is the key we'll use to store taproot specific data
// for the set of channels we'll need to sweep/claim.
taprootDataKey = []byte("taproot-data")
) )
var ( var (
@ -820,7 +823,26 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error
} }
} }
// If this isn't a taproot channel, then we can exit early here
// as there's no extra data to write.
switch {
case c.AnchorResolution == nil:
return nil return nil
case !txscript.IsPayToTaproot(
c.AnchorResolution.AnchorSignDescriptor.Output.PkScript,
):
return nil
}
// With everything else encoded, we'll now populate the taproot
// specific items we need to store for the musig2 channels.
var tb bytes.Buffer
err = encodeTaprootAuxData(&tb, c)
if err != nil {
return err
}
return scopeBucket.Put(taprootDataKey, tb.Bytes())
}) })
} }
@ -861,7 +883,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
resReader, c.CommitResolution, resReader, c.CommitResolution,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode "+
"commit res: %w", err)
} }
} }
@ -882,7 +905,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
resReader, &c.HtlcResolutions.IncomingHTLCs[i], resReader, &c.HtlcResolutions.IncomingHTLCs[i],
) )
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode "+
"incoming res: %w", err)
} }
} }
@ -896,7 +920,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
resReader, &c.HtlcResolutions.OutgoingHTLCs[i], resReader, &c.HtlcResolutions.OutgoingHTLCs[i],
) )
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode "+
"outgoing res: %w", err)
} }
} }
@ -914,7 +939,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
htlc := &c.HtlcResolutions.IncomingHTLCs[i] htlc := &c.HtlcResolutions.IncomingHTLCs[i]
htlc.SignDetails, err = decodeSignDetails(r) htlc.SignDetails, err = decodeSignDetails(r)
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode "+
"incoming sign desc: %w", err)
} }
} }
@ -922,7 +948,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
htlc := &c.HtlcResolutions.OutgoingHTLCs[i] htlc := &c.HtlcResolutions.OutgoingHTLCs[i]
htlc.SignDetails, err = decodeSignDetails(r) htlc.SignDetails, err = decodeSignDetails(r)
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode "+
"outgoing sign desc: %w", err)
} }
} }
} }
@ -935,7 +962,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
resReader, c.AnchorResolution, resReader, c.AnchorResolution,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("unable to read anchor "+
"data: %w", err)
} }
} }
@ -947,10 +975,21 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
resReader, c.BreachResolution, resReader, c.BreachResolution,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("unable to read breach "+
"data: %w", err)
} }
} }
tapCaseBytes := scopeBucket.Get(taprootDataKey)
if tapCaseBytes != nil {
err = decodeTapRootAuxData(
bytes.NewReader(tapCaseBytes), c,
)
if err != nil {
return fmt.Errorf("unable to read taproot "+
"data: %w", err)
}
}
return nil return nil
}, func() { }, func() {
c = &ContractResolutions{} c = &ContractResolutions{}
@ -1209,11 +1248,11 @@ func decodeSignDetails(r io.Reader) (*input.SignDetails, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sig, err := ecdsa.ParseDERSignature(rawSig)
s.PeerSig, err = input.ParseSignature(rawSig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.PeerSig = sig
return &s, nil return &s, nil
} }
@ -1506,3 +1545,160 @@ func decodeCommitSet(r io.Reader) (*CommitSet, error) {
return c, nil return c, nil
} }
func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
tapCase := newTaprootBriefcase()
if c.CommitResolution != nil {
commitResolution := c.CommitResolution
commitSignDesc := commitResolution.SelfOutputSignDesc
tapCase.CtrlBlocks.CommitSweepCtrlBlock = commitSignDesc.ControlBlock
}
for _, htlc := range c.HtlcResolutions.IncomingHTLCs {
htlc := htlc
htlcSignDesc := htlc.SweepSignDesc
ctrlBlock := htlcSignDesc.ControlBlock
if ctrlBlock == nil {
continue
}
if htlc.SignedSuccessTx != nil {
resID := newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
)
//nolint:lll
tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs we need to go to the second level for, we
// also need to store the control block needed to
// publish the second level transaction.
if htlc.SignDetails != nil {
//nolint:lll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:lll
tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID := newResolverID(htlc.ClaimOutpoint)
tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = ctrlBlock
}
}
for _, htlc := range c.HtlcResolutions.OutgoingHTLCs {
htlc := htlc
htlcSignDesc := htlc.SweepSignDesc
ctrlBlock := htlcSignDesc.ControlBlock
if ctrlBlock == nil {
continue
}
if htlc.SignedTimeoutTx != nil {
resID := newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
)
//nolint:lll
tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs we need to go to the second level for, we
// also need to store the control block needed to
// publish the second level transaction.
//
//nolint:lll
if htlc.SignDetails != nil {
//nolint:lll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:lll
tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID := newResolverID(htlc.ClaimOutpoint)
tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock
}
}
if c.AnchorResolution != nil {
anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor
tapCase.TapTweaks.AnchorTweak = anchorSignDesc.TapTweak
}
return tapCase.Encode(w)
}
func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
tapCase := newTaprootBriefcase()
if err := tapCase.Decode(r); err != nil {
return err
}
if c.CommitResolution != nil {
c.CommitResolution.SelfOutputSignDesc.ControlBlock =
tapCase.CtrlBlocks.CommitSweepCtrlBlock
}
for i := range c.HtlcResolutions.IncomingHTLCs {
htlc := c.HtlcResolutions.IncomingHTLCs[i]
var resID resolverID
if htlc.SignedSuccessTx != nil {
resID = newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
)
ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:lll
if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
ctrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
c.HtlcResolutions.IncomingHTLCs[i] = htlc
}
for i := range c.HtlcResolutions.OutgoingHTLCs {
htlc := c.HtlcResolutions.OutgoingHTLCs[i]
var resID resolverID
if htlc.SignedTimeoutTx != nil {
resID = newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
)
ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:lll
if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
ctrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
c.HtlcResolutions.OutgoingHTLCs[i] = htlc
}
if c.AnchorResolution != nil {
c.AnchorResolution.AnchorSignDescriptor.TapTweak =
tapCase.TapTweaks.AnchorTweak
}
return nil
}

View File

@ -625,10 +625,7 @@ func TestContractResolutionsStorage(t *testing.T) {
diskRes, err := testLog.FetchContractResolutions() diskRes, err := testLog.FetchContractResolutions()
require.NoError(t, err, "unable to read resolution from db") require.NoError(t, err, "unable to read resolution from db")
if !reflect.DeepEqual(&res, diskRes) { require.Equal(t, res, *diskRes)
t.Fatalf("resolution mismatch: expected %v\n, got %v",
spew.Sdump(&res), spew.Sdump(diskRes))
}
// We'll now delete the state, then attempt to retrieve the set of // We'll now delete the state, then attempt to retrieve the set of
// resolvers, no resolutions should be found. // resolvers, no resolutions should be found.

View File

@ -567,6 +567,80 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error {
return nil return nil
} }
// maybeAugmentTaprootResolvers will update the contract resolution information
// for taproot channels. This ensures that all the resolvers have the latest
// resolution, which may also include data such as the control block and tap
// tweaks.
func maybeAugmentTaprootResolvers(chanType channeldb.ChannelType,
resolver ContractResolver,
contractResolutions *ContractResolutions) {
if !chanType.IsTaproot() {
return
}
// The on disk resolutions contains all the ctrl block
// information, so we'll set that now for the relevant
// resolvers.
switch r := resolver.(type) {
case *commitSweepResolver:
if contractResolutions.CommitResolution != nil {
//nolint:lll
r.commitResolution = *contractResolutions.CommitResolution
}
case *htlcOutgoingContestResolver:
//nolint:lll
htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcTimeoutResolver:
//nolint:lll
htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcIncomingContestResolver:
//nolint:lll
htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcSuccessResolver:
//nolint:lll
htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
}
}
// relauchResolvers relaunches the set of resolvers for unresolved contracts in // relauchResolvers relaunches the set of resolvers for unresolved contracts in
// order to provide them with information that's not immediately available upon // order to provide them with information that's not immediately available upon
// starting the ChannelArbitrator. This information should ideally be stored in // starting the ChannelArbitrator. This information should ideally be stored in
@ -651,11 +725,21 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ log.Infof("ChannelArbitrator(%v): relaunching %v contract "+
"resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) "resolvers", c.cfg.ChanPoint, len(unresolvedContracts))
for _, resolver := range unresolvedContracts { for i := range unresolvedContracts {
resolver := unresolvedContracts[i]
if chanState != nil { if chanState != nil {
resolver.SupplementState(chanState) resolver.SupplementState(chanState)
} }
// For taproot channels, we'll need to also make sure the
// control block information was set properly.
maybeAugmentTaprootResolvers(
chanState.ChanType, resolver, contractResolutions,
)
unresolvedContracts[i] = resolver
htlcResolver, ok := resolver.(htlcContractResolver) htlcResolver, ok := resolver.(htlcContractResolver)
if !ok { if !ok {
continue continue