From 47d4eb341dcf75fca9202d40f7566b24be169da6 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 1 Mar 2023 22:13:27 -0800 Subject: [PATCH] contractcourt: store new taproot resolution info in new key We pull the information from the sign descriptors and store them in the resolutions. However, the resolvers created end up duplicating the resolution data, so we update the sign descs as needed during start up. --- contractcourt/briefcase.go | 218 ++++++++++++++++++++++++++-- contractcourt/briefcase_test.go | 5 +- contractcourt/channel_arbitrator.go | 88 ++++++++++- 3 files changed, 294 insertions(+), 17 deletions(-) diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index 2c8f6e34e..f6aa9c7f0 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -6,7 +6,6 @@ import ( "fmt" "io" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" @@ -363,6 +362,10 @@ var ( // store the confirmed active HTLC sets once we learn that a channel // has closed out on chain. commitSetKey = []byte("commit-set") + + // taprootDataKey is the key we'll use to store taproot specific data + // for the set of channels we'll need to sweep/claim. + taprootDataKey = []byte("taproot-data") ) var ( @@ -820,7 +823,26 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error } } - return nil + // If this isn't a taproot channel, then we can exit early here + // as there's no extra data to write. + switch { + case c.AnchorResolution == nil: + return nil + case !txscript.IsPayToTaproot( + c.AnchorResolution.AnchorSignDescriptor.Output.PkScript, + ): + return nil + } + + // With everything else encoded, we'll now populate the taproot + // specific items we need to store for the musig2 channels. + var tb bytes.Buffer + err = encodeTaprootAuxData(&tb, c) + if err != nil { + return err + } + + return scopeBucket.Put(taprootDataKey, tb.Bytes()) }) } @@ -861,7 +883,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er resReader, c.CommitResolution, ) if err != nil { - return err + return fmt.Errorf("unable to decode "+ + "commit res: %w", err) } } @@ -882,7 +905,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er resReader, &c.HtlcResolutions.IncomingHTLCs[i], ) if err != nil { - return err + return fmt.Errorf("unable to decode "+ + "incoming res: %w", err) } } @@ -896,7 +920,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er resReader, &c.HtlcResolutions.OutgoingHTLCs[i], ) if err != nil { - return err + return fmt.Errorf("unable to decode "+ + "outgoing res: %w", err) } } @@ -914,7 +939,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er htlc := &c.HtlcResolutions.IncomingHTLCs[i] htlc.SignDetails, err = decodeSignDetails(r) if err != nil { - return err + return fmt.Errorf("unable to decode "+ + "incoming sign desc: %w", err) } } @@ -922,7 +948,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er htlc := &c.HtlcResolutions.OutgoingHTLCs[i] htlc.SignDetails, err = decodeSignDetails(r) if err != nil { - return err + return fmt.Errorf("unable to decode "+ + "outgoing sign desc: %w", err) } } } @@ -935,7 +962,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er resReader, c.AnchorResolution, ) if err != nil { - return err + return fmt.Errorf("unable to read anchor "+ + "data: %w", err) } } @@ -947,10 +975,21 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er resReader, c.BreachResolution, ) if err != nil { - return err + return fmt.Errorf("unable to read breach "+ + "data: %w", err) } } + tapCaseBytes := scopeBucket.Get(taprootDataKey) + if tapCaseBytes != nil { + err = decodeTapRootAuxData( + bytes.NewReader(tapCaseBytes), c, + ) + if err != nil { + return fmt.Errorf("unable to read taproot "+ + "data: %w", err) + } + } return nil }, func() { c = &ContractResolutions{} @@ -1209,11 +1248,11 @@ func decodeSignDetails(r io.Reader) (*input.SignDetails, error) { if err != nil { return nil, err } - sig, err := ecdsa.ParseDERSignature(rawSig) + + s.PeerSig, err = input.ParseSignature(rawSig) if err != nil { return nil, err } - s.PeerSig = sig return &s, nil } @@ -1506,3 +1545,160 @@ func decodeCommitSet(r io.Reader) (*CommitSet, error) { return c, nil } + +func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { + tapCase := newTaprootBriefcase() + + if c.CommitResolution != nil { + commitResolution := c.CommitResolution + commitSignDesc := commitResolution.SelfOutputSignDesc + tapCase.CtrlBlocks.CommitSweepCtrlBlock = commitSignDesc.ControlBlock + } + + for _, htlc := range c.HtlcResolutions.IncomingHTLCs { + htlc := htlc + + htlcSignDesc := htlc.SweepSignDesc + ctrlBlock := htlcSignDesc.ControlBlock + + if ctrlBlock == nil { + continue + } + + if htlc.SignedSuccessTx != nil { + resID := newResolverID( + htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint, + ) + + //nolint:lll + tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock + + // For HTLCs we need to go to the second level for, we + // also need to store the control block needed to + // publish the second level transaction. + if htlc.SignDetails != nil { + //nolint:lll + bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock + //nolint:lll + tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock + } + } else { + resID := newResolverID(htlc.ClaimOutpoint) + tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = ctrlBlock + } + } + for _, htlc := range c.HtlcResolutions.OutgoingHTLCs { + htlc := htlc + + htlcSignDesc := htlc.SweepSignDesc + ctrlBlock := htlcSignDesc.ControlBlock + + if ctrlBlock == nil { + continue + } + + if htlc.SignedTimeoutTx != nil { + resID := newResolverID( + htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint, + ) + + //nolint:lll + tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock + + // For HTLCs we need to go to the second level for, we + // also need to store the control block needed to + // publish the second level transaction. + // + //nolint:lll + if htlc.SignDetails != nil { + //nolint:lll + bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock + //nolint:lll + tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock + } + } else { + resID := newResolverID(htlc.ClaimOutpoint) + tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock + } + + } + + if c.AnchorResolution != nil { + anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor + tapCase.TapTweaks.AnchorTweak = anchorSignDesc.TapTweak + } + + return tapCase.Encode(w) +} + +func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error { + tapCase := newTaprootBriefcase() + if err := tapCase.Decode(r); err != nil { + return err + } + + if c.CommitResolution != nil { + c.CommitResolution.SelfOutputSignDesc.ControlBlock = + tapCase.CtrlBlocks.CommitSweepCtrlBlock + } + + for i := range c.HtlcResolutions.IncomingHTLCs { + htlc := c.HtlcResolutions.IncomingHTLCs[i] + + var resID resolverID + if htlc.SignedSuccessTx != nil { + resID = newResolverID( + htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint, + ) + + ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] + htlc.SweepSignDesc.ControlBlock = ctrlBlock + + //nolint:lll + if htlc.SignDetails != nil { + bridgeCtrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] + htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock + } + } else { + resID = newResolverID(htlc.ClaimOutpoint) + + ctrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] + htlc.SweepSignDesc.ControlBlock = ctrlBlock + } + + c.HtlcResolutions.IncomingHTLCs[i] = htlc + } + for i := range c.HtlcResolutions.OutgoingHTLCs { + htlc := c.HtlcResolutions.OutgoingHTLCs[i] + + var resID resolverID + if htlc.SignedTimeoutTx != nil { + resID = newResolverID( + htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint, + ) + + ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] + htlc.SweepSignDesc.ControlBlock = ctrlBlock + + //nolint:lll + if htlc.SignDetails != nil { + bridgeCtrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] + htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock + } + } else { + resID = newResolverID(htlc.ClaimOutpoint) + + ctrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] + htlc.SweepSignDesc.ControlBlock = ctrlBlock + } + + c.HtlcResolutions.OutgoingHTLCs[i] = htlc + } + + if c.AnchorResolution != nil { + c.AnchorResolution.AnchorSignDescriptor.TapTweak = + tapCase.TapTweaks.AnchorTweak + } + + return nil +} diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 140a85c22..a112b7f8c 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -625,10 +625,7 @@ func TestContractResolutionsStorage(t *testing.T) { diskRes, err := testLog.FetchContractResolutions() require.NoError(t, err, "unable to read resolution from db") - if !reflect.DeepEqual(&res, diskRes) { - t.Fatalf("resolution mismatch: expected %v\n, got %v", - spew.Sdump(&res), spew.Sdump(diskRes)) - } + require.Equal(t, res, *diskRes) // We'll now delete the state, then attempt to retrieve the set of // resolvers, no resolutions should be found. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 37395f93a..3eb3ffb87 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -567,6 +567,80 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { return nil } +// maybeAugmentTaprootResolvers will update the contract resolution information +// for taproot channels. This ensures that all the resolvers have the latest +// resolution, which may also include data such as the control block and tap +// tweaks. +func maybeAugmentTaprootResolvers(chanType channeldb.ChannelType, + resolver ContractResolver, + contractResolutions *ContractResolutions) { + + if !chanType.IsTaproot() { + return + } + + // The on disk resolutions contains all the ctrl block + // information, so we'll set that now for the relevant + // resolvers. + switch r := resolver.(type) { + case *commitSweepResolver: + if contractResolutions.CommitResolution != nil { + //nolint:lll + r.commitResolution = *contractResolutions.CommitResolution + } + case *htlcOutgoingContestResolver: + //nolint:lll + htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs + for _, htlcRes := range htlcResolutions { + htlcRes := htlcRes + + if r.htlcResolution.ClaimOutpoint == + htlcRes.ClaimOutpoint { + + r.htlcResolution = htlcRes + } + } + + case *htlcTimeoutResolver: + //nolint:lll + htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs + for _, htlcRes := range htlcResolutions { + htlcRes := htlcRes + + if r.htlcResolution.ClaimOutpoint == + htlcRes.ClaimOutpoint { + + r.htlcResolution = htlcRes + } + } + + case *htlcIncomingContestResolver: + //nolint:lll + htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs + for _, htlcRes := range htlcResolutions { + htlcRes := htlcRes + + if r.htlcResolution.ClaimOutpoint == + htlcRes.ClaimOutpoint { + + r.htlcResolution = htlcRes + } + } + case *htlcSuccessResolver: + //nolint:lll + htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs + for _, htlcRes := range htlcResolutions { + htlcRes := htlcRes + + if r.htlcResolution.ClaimOutpoint == + htlcRes.ClaimOutpoint { + + r.htlcResolution = htlcRes + } + } + } +} + // relauchResolvers relaunches the set of resolvers for unresolved contracts in // order to provide them with information that's not immediately available upon // starting the ChannelArbitrator. This information should ideally be stored in @@ -651,11 +725,21 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ "resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) - for _, resolver := range unresolvedContracts { - if chanState != nil { + for i := range unresolvedContracts { + resolver := unresolvedContracts[i] + + if chanState != nil { resolver.SupplementState(chanState) } + // For taproot channels, we'll need to also make sure the + // control block information was set properly. + maybeAugmentTaprootResolvers( + chanState.ChanType, resolver, contractResolutions, + ) + + unresolvedContracts[i] = resolver + htlcResolver, ok := resolver.(htlcContractResolver) if !ok { continue