diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 3083e7252..feb5b2d5e 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -10,16 +10,16 @@ import ( "runtime" "testing" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/shachain" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" _ "github.com/btcsuite/btcwallet/walletdb/bdb" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" ) var ( diff --git a/channeldb/graph.go b/channeldb/graph.go index 49b73aad0..d716ead28 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2,6 +2,7 @@ package channeldb import ( "bytes" + "crypto/sha256" "encoding/binary" "fmt" "image/color" @@ -12,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/coreos/bbolt" @@ -2350,12 +2352,62 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) (*ChannelEdgeInfo, * return edgeInfo, policy1, policy2, nil } +// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. +func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { + if len(aPub) != 33 || len(bPub) != 33 { + return nil, fmt.Errorf("Pubkey size error. Compressed " + + "pubkeys only") + } + + // Swap to sort pubkeys if needed. Keys are sorted in lexicographical + // order. The signatures within the scriptSig must also adhere to the + // order, ensuring that the signatures for each public key appears in + // the proper order on the stack. + if bytes.Compare(aPub, bPub) == 1 { + aPub, bPub = bPub, aPub + } + + // First, we'll generate the witness script for the multi-sig. + bldr := txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_2) + bldr.AddData(aPub) // Add both pubkeys (sorted). + bldr.AddData(bPub) + bldr.AddOp(txscript.OP_2) + bldr.AddOp(txscript.OP_CHECKMULTISIG) + witnessScript, err := bldr.Script() + if err != nil { + return nil, err + } + + // With the witness script generated, we'll now turn it into a p2sh + // script: + // * OP_0 + bldr = txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_0) + scriptHash := sha256.Sum256(witnessScript) + bldr.AddData(scriptHash[:]) + + return bldr.Script() +} + +// EdgePoint couples the outpoint of a channel with the funding script that it +// creates. The FilteredChainView will use this to watch for spends of this +// edge point on chain. We require both of these values as depending on the +// concrete implementation, either the pkScript, or the out point will be used. +type EdgePoint struct { + // FundingPkScript is the p2wsh multi-sig script of the target channel. + FundingPkScript []byte + + // OutPoint is the outpoint of the target channel. + OutPoint wire.OutPoint +} + // ChannelView returns the verifiable edge information for each active channel -// within the known channel graph. The set of UTXO's returned are the ones that -// need to be watched on chain to detect channel closes on the resident -// blockchain. -func (c *ChannelGraph) ChannelView() ([]wire.OutPoint, error) { - var chanPoints []wire.OutPoint +// within the known channel graph. The set of UTXO's (along with their scripts) +// returned are the ones that need to be watched on chain to detect channel +// closes on the resident blockchain. +func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { + var edgePoints []EdgePoint if err := c.db.View(func(tx *bolt.Tx) error { // We're going to iterate over the entire channel index, so // we'll need to fetch the edgeBucket to get to the index as @@ -2368,11 +2420,15 @@ func (c *ChannelGraph) ChannelView() ([]wire.OutPoint, error) { if chanIndex == nil { return ErrGraphNoEdgesFound } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } // Once we have the proper bucket, we'll range over each key // (which is the channel point for the channel) and decode it, // accumulating each entry. - return chanIndex.ForEach(func(chanPointBytes, _ []byte) error { + return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { chanPointReader := bytes.NewReader(chanPointBytes) var chanPoint wire.OutPoint @@ -2381,14 +2437,33 @@ func (c *ChannelGraph) ChannelView() ([]wire.OutPoint, error) { return err } - chanPoints = append(chanPoints, chanPoint) + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, chanID, + ) + if err != nil { + return err + } + + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], + edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + return err + } + + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: chanPoint, + }) + return nil }) }); err != nil { return nil, err } - return chanPoints, nil + return edgePoints, nil } // NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index cc3055b3e..7c8287600 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -997,7 +997,7 @@ func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { } } -func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) { +func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { if len(a) != len(b) { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: chan views don't match", line) @@ -1005,14 +1005,34 @@ func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) { chanViewSet := make(map[wire.OutPoint]struct{}) for _, op := range a { - chanViewSet[op] = struct{}{} + chanViewSet[op.OutPoint] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[op.OutPoint]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) + } + } +} + +func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views don't match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op.OutPoint] = struct{}{} } for _, op := range b { if _, ok := chanViewSet[*op]; !ok { _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chanPoint(%v) not found in first view", - line, op) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) } } } @@ -1056,6 +1076,7 @@ func TestGraphPruning(t *testing.T) { // With the vertexes created, we'll next create a series of channels // between them. channelPoints := make([]*wire.OutPoint, 0, numNodes-1) + edgePoints := make([]EdgePoint, 0, numNodes-1) for i := 0; i < numNodes-1; i++ { txHash := sha256.Sum256([]byte{byte(i)}) chanID := uint64(i + 1) @@ -1086,6 +1107,17 @@ func TestGraphPruning(t *testing.T) { t.Fatalf("unable to add node: %v", err) } + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + t.Fatalf("unable to gen multi-sig p2wsh: %v", err) + } + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: op, + }) + // Create and add an edge with random data that points from // node_i -> node_i+1 edge := randEdgePolicy(chanID, op, db) @@ -1113,7 +1145,7 @@ func TestGraphPruning(t *testing.T) { if err != nil { t.Fatalf("unable to get graph channel view: %v", err) } - assertChanViewEqual(t, channelView, channelPoints) + assertChanViewEqual(t, channelView, edgePoints) // Now with our test graph created, we can test the pruning // capabilities of the channel graph. @@ -1145,7 +1177,7 @@ func TestGraphPruning(t *testing.T) { if err != nil { t.Fatalf("unable to get graph channel view: %v", err) } - assertChanViewEqual(t, channelView, channelPoints[2:]) + assertChanViewEqualChanPoints(t, channelView, channelPoints[2:]) // Next we'll create a block that doesn't close any channels within the // graph to test the negative error case.