diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 0eec70926..952d812af 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -17,12 +17,6 @@ import ( ) const ( - // MinVersion is the minimum blob version supported by this package. - MinVersion = 0 - - // MaxVersion is the maximumm blob version supported by this package. - MaxVersion = 0 - // NonceSize is the length of a chacha20poly1305 nonce, 24 bytes. NonceSize = chacha20poly1305.NonceSizeX @@ -53,14 +47,14 @@ const ( // nonce: 24 bytes // enciphered plaintext: n bytes // MAC: 16 bytes -func Size(ver uint16) int { - return NonceSize + PlaintextSize(ver) + CiphertextExpansion +func Size(blobType Type) int { + return NonceSize + PlaintextSize(blobType) + CiphertextExpansion } // PlaintextSize returns the size of the encoded-but-unencrypted blob in bytes. -func PlaintextSize(ver uint16) int { - switch ver { - case 0: +func PlaintextSize(blobType Type) int { + switch { + case blobType.Has(FlagCommitOutputs): return V0PlaintextSize default: return 0 @@ -71,9 +65,9 @@ var ( // byteOrder specifies a big-endian encoding of all integer values. byteOrder = binary.BigEndian - // ErrUnknownBlobVersion signals that we don't understand the requested + // ErrUnknownBlobType signals that we don't understand the requested // blob encoding scheme. - ErrUnknownBlobVersion = errors.New("unknown blob version") + ErrUnknownBlobType = errors.New("unknown blob type") // ErrCiphertextTooSmall is a decryption error signaling that the // ciphertext is smaller than the ciphertext expansion factor. @@ -229,7 +223,7 @@ func (b *JusticeKit) CommitToRemoteWitnessStack() ([][]byte, error) { // // NOTE: It is the caller's responsibility to ensure that this method is only // called once for a given (nonce, key) pair. -func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) { +func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) { // Fail if the nonce is not 32-bytes. if len(key) != KeySize { return nil, ErrKeySize @@ -238,7 +232,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) { // Encode the plaintext using the provided version, to obtain the // plaintext bytes. var ptxtBuf bytes.Buffer - err := b.encode(&ptxtBuf, version) + err := b.encode(&ptxtBuf, blobType) if err != nil { return nil, err } @@ -252,7 +246,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) { // Allocate the ciphertext, which will contain the nonce, encrypted // plaintext and MAC. plaintext := ptxtBuf.Bytes() - ciphertext := make([]byte, Size(version)) + ciphertext := make([]byte, Size(blobType)) // Generate a random 24-byte nonce in the ciphertext's prefix. nonce := ciphertext[:NonceSize] @@ -270,7 +264,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) { // Decrypt unenciphers a blob of justice by decrypting the ciphertext using // chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is // then deserialized using the given encoding version. -func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) { +func Decrypt(key, ciphertext []byte, blobType Type) (*JusticeKit, error) { switch { // Fail if the blob's overall length is less than required for the nonce @@ -305,7 +299,7 @@ func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) { // If decryption succeeded, we will then decode the plaintext bytes // using the specified blob version. boj := &JusticeKit{} - err = boj.decode(bytes.NewReader(plaintext), version) + err = boj.decode(bytes.NewReader(plaintext), blobType) if err != nil { return nil, err } @@ -315,23 +309,23 @@ func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) { // encode serializes the JusticeKit according to the version, returning an // error if the version is unknown. -func (b *JusticeKit) encode(w io.Writer, ver uint16) error { - switch ver { - case 0: +func (b *JusticeKit) encode(w io.Writer, blobType Type) error { + switch { + case blobType.Has(FlagCommitOutputs): return b.encodeV0(w) default: - return ErrUnknownBlobVersion + return ErrUnknownBlobType } } // decode deserializes the JusticeKit according to the version, returning an // error if the version is unknown. -func (b *JusticeKit) decode(r io.Reader, ver uint16) error { - switch ver { - case 0: +func (b *JusticeKit) decode(r io.Reader, blobType Type) error { + switch { + case blobType.Has(FlagCommitOutputs): return b.decodeV0(r) default: - return ErrUnknownBlobVersion + return ErrUnknownBlobType } } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index ba36f134a..922290b52 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -1,12 +1,16 @@ package blob_test import ( + "bytes" "crypto/rand" "encoding/binary" "io" "reflect" "testing" + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/txscript" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" ) @@ -38,8 +42,8 @@ func makeAddr(size int) []byte { type descriptorTest struct { name string - encVersion uint16 - decVersion uint16 + encVersion blob.Type + decVersion blob.Type sweepAddr []byte revPubKey blob.PubKey delayPubKey blob.PubKey @@ -52,11 +56,15 @@ type descriptorTest struct { decErr error } +var rewardAndCommitType = blob.TypeFromFlags( + blob.FlagReward, blob.FlagCommitOutputs, +) + var descriptorTests = []descriptorTest{ { name: "to-local only", - encVersion: 0, - decVersion: 0, + encVersion: blob.TypeDefault, + decVersion: blob.TypeDefault, sweepAddr: makeAddr(22), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -65,8 +73,8 @@ var descriptorTests = []descriptorTest{ }, { name: "to-local and p2wkh", - encVersion: 0, - decVersion: 0, + encVersion: rewardAndCommitType, + decVersion: rewardAndCommitType, sweepAddr: makeAddr(22), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -78,30 +86,30 @@ var descriptorTests = []descriptorTest{ }, { name: "unknown encrypt version", - encVersion: 1, - decVersion: 0, + encVersion: 0, + decVersion: blob.TypeDefault, sweepAddr: makeAddr(34), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), csvDelay: 144, commitToLocalSig: makeSig(1), - encErr: blob.ErrUnknownBlobVersion, + encErr: blob.ErrUnknownBlobType, }, { name: "unknown decrypt version", - encVersion: 0, - decVersion: 1, + encVersion: blob.TypeDefault, + decVersion: 0, sweepAddr: makeAddr(34), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), csvDelay: 144, commitToLocalSig: makeSig(1), - decErr: blob.ErrUnknownBlobVersion, + decErr: blob.ErrUnknownBlobType, }, { name: "sweep addr length zero", - encVersion: 0, - decVersion: 0, + encVersion: blob.TypeDefault, + decVersion: blob.TypeDefault, sweepAddr: makeAddr(0), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -110,8 +118,8 @@ var descriptorTests = []descriptorTest{ }, { name: "sweep addr max size", - encVersion: 0, - decVersion: 0, + encVersion: blob.TypeDefault, + decVersion: blob.TypeDefault, sweepAddr: makeAddr(blob.MaxSweepAddrSize), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -120,8 +128,8 @@ var descriptorTests = []descriptorTest{ }, { name: "sweep addr too long", - encVersion: 0, - decVersion: 0, + encVersion: blob.TypeDefault, + decVersion: blob.TypeDefault, sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1), revPubKey: makePubKey(0), delayPubKey: makePubKey(1), @@ -208,3 +216,195 @@ func testBlobJusticeKitEncryptDecrypt(t *testing.T, test descriptorTest) { "want: %v, got %v", boj, boj2) } } + +// TestJusticeKitRemoteWitnessConstruction tests that a JusticeKit returns the +// proper to-remote witnes script and to-remote witness stack. This should be +// equivalent to p2wkh spend. +func TestJusticeKitRemoteWitnessConstruction(t *testing.T) { + // Generate the to-remote pubkey. + toRemotePrivKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate to-remote priv key: %v", err) + } + + // Copy the to-remote pubkey into the format expected by our justice + // kit. + var toRemotePubKey blob.PubKey + copy(toRemotePubKey[:], toRemotePrivKey.PubKey().SerializeCompressed()) + + // Sign a message using the to-remote private key. The exact message + // doesn't matter as we won't be validating the signature's validity. + digest := bytes.Repeat([]byte("a"), 32) + rawToRemoteSig, err := toRemotePrivKey.Sign(digest) + if err != nil { + t.Fatalf("unable to generate to-remote signature: %v", err) + } + + // Convert the DER-encoded signature into a fixed-size sig. + commitToRemoteSig, err := lnwire.NewSigFromSignature(rawToRemoteSig) + if err != nil { + t.Fatalf("unable to convert raw to-remote signature to "+ + "Sig: %v", err) + } + + // Populate the justice kit fields relevant to the to-remote output. + justiceKit := &blob.JusticeKit{ + CommitToRemotePubKey: toRemotePubKey, + CommitToRemoteSig: commitToRemoteSig, + } + + // Now, compute the to-remote witness script returned by the justice + // kit. + toRemoteScript, err := justiceKit.CommitToRemoteWitnessScript() + if err != nil { + t.Fatalf("unable to compute to-remote witness script: %v", err) + } + + // Assert this is exactly the to-remote, compressed pubkey. + if !bytes.Equal(toRemoteScript, toRemotePubKey[:]) { + t.Fatalf("to-remote witness script should be equal to "+ + "to-remote pubkey, want: %x, got %x", + toRemotePubKey[:], toRemoteScript) + } + + // Next, compute the to-remote witness stack, which should be a p2wkh + // witness stack consisting solely of a signature. + toRemoteWitnessStack, err := justiceKit.CommitToRemoteWitnessStack() + if err != nil { + t.Fatalf("unable to compute to-remote witness stack: %v", err) + } + + // Assert that the witness stack only has one element. + if len(toRemoteWitnessStack) != 1 { + t.Fatalf("to-remote witness stack should be of length 1, is %d", + len(toRemoteWitnessStack)) + } + + // Compute the expected first element, by appending a sighash all byte + // to our raw DER-encoded signature. + rawToRemoteSigWithSigHash := append( + rawToRemoteSig.Serialize(), byte(txscript.SigHashAll), + ) + + // Assert that the expected signature matches the first element in the + // witness stack. + if !bytes.Equal(rawToRemoteSigWithSigHash, toRemoteWitnessStack[0]) { + t.Fatalf("mismatched sig in to-remote witness stack, want: %v, "+ + "got: %v", rawToRemoteSigWithSigHash, + toRemoteWitnessStack[0]) + } + + // Finally, set the CommitToRemotePubKey to be a blank value. + justiceKit.CommitToRemotePubKey = blob.PubKey{} + + // When trying to compute the witness script, this should now return + // ErrNoCommitToRemoteOutput since a valid pubkey could not be parsed + // from CommitToRemotePubKey. + _, err = justiceKit.CommitToRemoteWitnessScript() + if err != blob.ErrNoCommitToRemoteOutput { + t.Fatalf("expected ErrNoCommitToRemoteOutput, got: %v", err) + } +} + +// TestJusticeKitToLocalWitnessConstruction tests that a JusticeKit returns the +// proper to-local witness script and to-local witness stack for spending the +// revocation path. +func TestJusticeKitToLocalWitnessConstruction(t *testing.T) { + csvDelay := uint32(144) + + // Generate the revocation and delay private keys. + revPrivKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate revocation priv key: %v", err) + } + + delayPrivKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate delay priv key: %v", err) + } + + // Copy the revocation and delay pubkeys into the format expected by our + // justice kit. + var revPubKey blob.PubKey + copy(revPubKey[:], revPrivKey.PubKey().SerializeCompressed()) + + var delayPubKey blob.PubKey + copy(delayPubKey[:], delayPrivKey.PubKey().SerializeCompressed()) + + // Sign a message using the revocation private key. The exact message + // doesn't matter as we won't be validating the signature's validity. + digest := bytes.Repeat([]byte("a"), 32) + rawRevSig, err := revPrivKey.Sign(digest) + if err != nil { + t.Fatalf("unable to generate revocation signature: %v", err) + } + + // Convert the DER-encoded signature into a fixed-size sig. + commitToLocalSig, err := lnwire.NewSigFromSignature(rawRevSig) + if err != nil { + t.Fatalf("unable to convert raw revocation signature to "+ + "Sig: %v", err) + } + + // Populate the justice kit with fields relevant to the to-local output. + justiceKit := &blob.JusticeKit{ + CSVDelay: csvDelay, + RevocationPubKey: revPubKey, + LocalDelayPubKey: delayPubKey, + CommitToLocalSig: commitToLocalSig, + } + + // Compute the expected to-local script, which is a function of the CSV + // delay, revocation pubkey and delay pubkey. + expToLocalScript, err := lnwallet.CommitScriptToSelf( + csvDelay, delayPrivKey.PubKey(), revPrivKey.PubKey(), + ) + if err != nil { + t.Fatalf("unable to generate expected to-local script: %v", err) + } + + // Compute the to-local script that is returned by the justice kit. + toLocalScript, err := justiceKit.CommitToLocalWitnessScript() + if err != nil { + t.Fatalf("unable to compute to-local witness script: %v", err) + } + + // Assert that the expected to-local script matches the actual script. + if !bytes.Equal(expToLocalScript, toLocalScript) { + t.Fatalf("mismatched to-local witness script, want: %v, got %v", + expToLocalScript, toLocalScript) + } + + // Next, compute the to-local witness stack returned by the justice kit. + toLocalWitnessStack, err := justiceKit.CommitToLocalRevokeWitnessStack() + if err != nil { + t.Fatalf("unable to compute to-local witness stack: %v", err) + } + + // A valid witness that spends the revocation path should have exactly + // two elements on the stack. + if len(toLocalWitnessStack) != 2 { + t.Fatalf("to-local witness stack should be of length 2, is %d", + len(toLocalWitnessStack)) + } + + // First, we'll verify that the top element is 0x01, which triggers the + // revocation path within the to-local witness script. + if !bytes.Equal(toLocalWitnessStack[1], []byte{0x01}) { + t.Fatalf("top item on witness stack should be 0x01, found: %v", + toLocalWitnessStack[1]) + } + + // Next, compute the expected signature in the bottom element of the + // stack, by appending a sighash all flag to the raw DER signature. + rawRevSigWithSigHash := append( + rawRevSig.Serialize(), byte(txscript.SigHashAll), + ) + + // Assert that the second element on the stack matches our expected + // signature under the revocation pubkey. + if !bytes.Equal(rawRevSigWithSigHash, toLocalWitnessStack[0]) { + t.Fatalf("mismatched sig in to-local witness stack, want: %v, "+ + "got: %v", rawRevSigWithSigHash, toLocalWitnessStack[0]) + } +} diff --git a/watchtower/blob/type.go b/watchtower/blob/type.go new file mode 100644 index 000000000..3a6e03de8 --- /dev/null +++ b/watchtower/blob/type.go @@ -0,0 +1,134 @@ +package blob + +import ( + "fmt" + "strings" +) + +// Flag represents a specify option that can be present in a Type. +type Flag uint16 + +const ( + // FlagReward signals that the justice transaction should contain an + // additional output for itself. Signatures sent by the client should + // include the reward script negotiated during session creation. Without + // the flag, there is only one output sweeping clients funds back to + // them solely. + FlagReward Flag = 1 << iota + + // FlagCommitOutputs signals that the blob contains the information + // required to sweep commitment outputs. + FlagCommitOutputs +) + +// Type returns a Type consisting solely of this flag enabled. +func (f Flag) Type() Type { + return Type(f) +} + +// String returns the name of the flag. +func (f Flag) String() string { + switch f { + case FlagReward: + return "FlagReward" + case FlagCommitOutputs: + return "FlagCommitOutputs" + default: + return "FlagUnknown" + } +} + +// Type is a bit vector composed of Flags that govern various aspects of +// reconstructing the justice transaction from an encrypted blob. The flags can +// be used to signal behaviors such as which inputs are being swept, which +// outputs should be added to the justice transaction, or modify serialization +// of the blob itself. +type Type uint16 + +// TypeDefault sweeps only commitment outputs to a sweep address controlled by +// the user, and does not give the tower a reward. +const TypeDefault = Type(FlagCommitOutputs) + +// Has returns true if the Type has the passed flag enabled. +func (t Type) Has(flag Flag) bool { + return Flag(t)&flag == flag +} + +// TypeFromFlags creates a single Type from an arbitrary list of flags. +func TypeFromFlags(flags ...Flag) Type { + var typ Type + for _, flag := range flags { + typ |= Type(flag) + } + + return typ +} + +// knownFlags maps the supported flags to their name. +var knownFlags = map[Flag]struct{}{ + FlagReward: {}, + FlagCommitOutputs: {}, +} + +// String returns a human readable description of a Type. +func (t Type) String() string { + var ( + hrPieces []string + hasUnknownFlags bool + ) + + // Iterate through the possible flags from highest to lowest. This will + // ensure that the human readable names will be in the same order as the + // bits (left to right) if the type were to be printed in big-endian + // byte order. + for f := Flag(1 << 15); f != 0; f >>= 1 { + // If this flag is known, we'll add a human-readable name or its + // inverse depending on whether the type has this flag set. + if _, ok := knownFlags[f]; ok { + if t.Has(f) { + hrPieces = append(hrPieces, f.String()) + } else { + hrPieces = append(hrPieces, "No-"+f.String()) + } + } else { + // Make note of any unknown flags that this type has + // set. If any are present, we'll prepend the bit-wise + // representation of the type in the final string. + if t.Has(f) { + hasUnknownFlags = true + } + } + } + + // If there were no unknown flags, we'll simply return the list of human + // readable pieces. + if !hasUnknownFlags { + return fmt.Sprintf("[%s]", strings.Join(hrPieces, "|")) + } + + // Otherwise, we'll prepend the bit-wise representation to the human + // readable names. + return fmt.Sprintf("%016b[%s]", t, strings.Join(hrPieces, "|")) +} + +// supportedTypes is the set of all configurations known to be supported by the +// package. +var supportedTypes = map[Type]struct{}{ + FlagCommitOutputs.Type(): {}, + (FlagCommitOutputs | FlagReward).Type(): {}, +} + +// IsSupportedType returns true if the given type is supported by the package. +func IsSupportedType(blobType Type) bool { + _, ok := supportedTypes[blobType] + return ok +} + +// SupportedTypes returns a list of all supported blob types. +func SupportedTypes() []Type { + supported := make([]Type, 0, len(supportedTypes)) + for t := range supportedTypes { + supported = append(supported, t) + } + return supported +} diff --git a/watchtower/blob/type_test.go b/watchtower/blob/type_test.go new file mode 100644 index 000000000..f5e0d85fe --- /dev/null +++ b/watchtower/blob/type_test.go @@ -0,0 +1,135 @@ +package blob_test + +import ( + "testing" + + "github.com/lightningnetwork/lnd/watchtower/blob" +) + +var unknownFlag = blob.Flag(16) + +type typeStringTest struct { + name string + typ blob.Type + expStr string +} + +var typeStringTests = []typeStringTest{ + { + name: "commit no-reward", + typ: blob.TypeDefault, + expStr: "[FlagCommitOutputs|No-FlagReward]", + }, + { + name: "commit reward", + typ: (blob.FlagCommitOutputs | blob.FlagReward).Type(), + expStr: "[FlagCommitOutputs|FlagReward]", + }, + { + name: "unknown flag", + typ: unknownFlag.Type(), + expStr: "0000000000010000[No-FlagCommitOutputs|No-FlagReward]", + }, +} + +// TestTypeStrings asserts that the proper human-readable string is returned for +// various blob.Types +func TestTypeStrings(t *testing.T) { + for _, test := range typeStringTests { + t.Run(test.name, func(t *testing.T) { + typeStr := test.typ.String() + if typeStr != test.expStr { + t.Fatalf("mismatched type string, want: %v, "+ + "got %v", test.expStr, typeStr) + } + }) + } +} + +// TestUnknownFlagString asserts that the proper string is returned from +// unallocated flags. +func TestUnknownFlagString(t *testing.T) { + if unknownFlag.String() != "FlagUnknown" { + t.Fatalf("unknown flags should return FlagUnknown, instead "+ + "got: %v", unknownFlag.String()) + } +} + +type typeFromFlagTest struct { + name string + flags []blob.Flag + expType blob.Type +} + +var typeFromFlagTests = []typeFromFlagTest{ + { + name: "no flags", + flags: nil, + expType: blob.Type(0), + }, + { + name: "single flag", + flags: []blob.Flag{blob.FlagReward}, + expType: blob.Type(blob.FlagReward), + }, + { + name: "multiple flags", + flags: []blob.Flag{blob.FlagReward, blob.FlagCommitOutputs}, + expType: blob.Type(blob.FlagReward | blob.FlagCommitOutputs), + }, + { + name: "duplicate flag", + flags: []blob.Flag{blob.FlagReward, blob.FlagReward}, + expType: blob.Type(blob.FlagReward), + }, +} + +// TestTypeFromFlags asserts that blob.Types constructed using +// blob.TypeFromFlags are correct, and properly deduplicate flags. We also +// assert that Has returns true for the generated blob.Type for all of the flags +// that were used to create it. +func TestTypeFromFlags(t *testing.T) { + for _, test := range typeFromFlagTests { + t.Run(test.name, func(t *testing.T) { + blobType := blob.TypeFromFlags(test.flags...) + + // Assert that the constructed type matches our + // expectation. + if blobType != test.expType { + t.Fatalf("mismatch, expected blob type %s, "+ + "got %s", test.expType, blobType) + } + + // Assert that Has returns true for all flags used to + // construct the type. + for _, flag := range test.flags { + if blobType.Has(flag) { + continue + } + + t.Fatalf("expected type to have flag %s, "+ + "but didn't", flag) + } + }) + } +} + +// TestSupportedTypes verifies that blob.IsSupported returns true for all +// blob.Types returned from blob.SupportedTypes. It also asserts that the +// blob.DefaultType returns true. +func TestSupportedTypes(t *testing.T) { + // Assert that the package's default type is supported. + if !blob.IsSupportedType(blob.TypeDefault) { + t.Fatalf("default type %s is not supported", blob.TypeDefault) + } + + // Assert that all claimed supported types are actually supported. + for _, supType := range blob.SupportedTypes() { + if blob.IsSupportedType(supType) { + continue + } + + t.Fatalf("supposedly supported type %s is not supported", + supType) + } +} diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index d0c49e592..1fe215cc1 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) const csvDelay uint32 = 144 @@ -170,8 +171,10 @@ func TestJusticeDescriptor(t *testing.T) { // parameters that should be used in constructing the justice // transaction. sessionInfo := &wtdb.SessionInfo{ - SweepFeeRate: 2000, - RewardRate: 900000, + Policy: wtpolicy.Policy{ + SweepFeeRate: 2000, + RewardRate: 900000, + }, RewardAddress: makeAddrSlice(22), } diff --git a/watchtower/lookout/lookout.go b/watchtower/lookout/lookout.go index 546e03f2f..556db2e81 100644 --- a/watchtower/lookout/lookout.go +++ b/watchtower/lookout/lookout.go @@ -210,7 +210,7 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch, // sweep the breached commitment outputs. justiceKit, err := blob.Decrypt( commitTxID[:], match.EncryptedBlob, - match.SessionInfo.Version, + match.SessionInfo.Policy.BlobType, ) if err != nil { // If the decryption fails, this implies either that the diff --git a/watchtower/lookout/lookout_test.go b/watchtower/lookout/lookout_test.go index 606813258..4232791d5 100644 --- a/watchtower/lookout/lookout_test.go +++ b/watchtower/lookout/lookout_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) type mockPunisher struct { @@ -86,15 +87,25 @@ func TestLookoutBreachMatching(t *testing.T) { t.Fatalf("unable to start watcher: %v", err) } + rewardAndCommitType := blob.TypeFromFlags( + blob.FlagReward, blob.FlagCommitOutputs, + ) + // Create two sessions, representing two distinct clients. sessionInfo1 := &wtdb.SessionInfo{ - ID: makeArray33(1), - MaxUpdates: 10, + ID: makeArray33(1), + Policy: wtpolicy.Policy{ + BlobType: rewardAndCommitType, + MaxUpdates: 10, + }, RewardAddress: makeAddrSlice(22), } sessionInfo2 := &wtdb.SessionInfo{ - ID: makeArray33(2), - MaxUpdates: 10, + ID: makeArray33(2), + Policy: wtpolicy.Policy{ + BlobType: rewardAndCommitType, + MaxUpdates: 10, + }, RewardAddress: makeAddrSlice(22), } @@ -137,13 +148,13 @@ func TestLookoutBreachMatching(t *testing.T) { } // Encrypt the first justice kit under the txid of the first txn. - encBlob1, err := blob1.Encrypt(hash1[:], 0) + encBlob1, err := blob1.Encrypt(hash1[:], blob.FlagCommitOutputs.Type()) if err != nil { t.Fatalf("unable to encrypt sweep detail 1: %v", err) } // Encrypt the second justice kit under the txid of the second txn. - encBlob2, err := blob2.Encrypt(hash2[:], 0) + encBlob2, err := blob2.Encrypt(hash2[:], blob.FlagCommitOutputs.Type()) if err != nil { t.Fatalf("unable to encrypt sweep detail 2: %v", err) } diff --git a/watchtower/wtdb/session_info.go b/watchtower/wtdb/session_info.go index 23775fe61..5ff781729 100644 --- a/watchtower/wtdb/session_info.go +++ b/watchtower/wtdb/session_info.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/btcsuite/btcutil" - "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) var ( @@ -49,12 +49,8 @@ type SessionInfo struct { // ID is the remote public key of the watchtower client. ID SessionID - // Version specifies the plaintext blob encoding of all state updates. - Version uint16 - - // MaxUpdates is the total number of updates the client can send for - // this session. - MaxUpdates uint16 + // Policy holds the negotiated session parameters. + Policy wtpolicy.Policy // LastApplied the sequence number of the last successful state update. LastApplied uint16 @@ -62,14 +58,6 @@ type SessionInfo struct { // ClientLastApplied the last last-applied the client has echoed back. ClientLastApplied uint16 - // RewardRate the fraction of the swept amount that goes to the tower, - // expressed in millionths of the swept balance. - RewardRate uint32 - - // SweepFeeRate is the agreed upon fee rate used to sign any sweep - // transactions. - SweepFeeRate lnwallet.SatPerKWeight - // RewardAddress the address that the tower's reward will be deposited // to if a sweep transaction confirms. RewardAddress []byte @@ -96,7 +84,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error { return ErrLastAppliedReversion // Client update exceeds capacity of session. - case seqNum > s.MaxUpdates: + case seqNum > s.Policy.MaxUpdates: return ErrSessionConsumed // Client update does not match our expected next seqnum. @@ -117,7 +105,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error { func (s *SessionInfo) ComputeSweepOutputs(totalAmt btcutil.Amount, txVSize int64) (btcutil.Amount, btcutil.Amount, error) { - txFee := s.SweepFeeRate.FeeForWeight(txVSize) + txFee := s.Policy.SweepFeeRate.FeeForWeight(txVSize) if txFee > totalAmt { return 0, 0, ErrFeeExceedsInputs } @@ -126,7 +114,8 @@ func (s *SessionInfo) ComputeSweepOutputs(totalAmt btcutil.Amount, // Apply the reward rate to the remaining total, specified in millionths // of the available balance. - rewardAmt := (totalAmt*btcutil.Amount(s.RewardRate) + 999999) / 1000000 + rewardRate := btcutil.Amount(s.Policy.RewardRate) + rewardAmt := (totalAmt*rewardRate + 999999) / 1000000 sweepAmt := totalAmt - rewardAmt // TODO(conner): check dustiness diff --git a/watchtower/wtpolicy/policy.go b/watchtower/wtpolicy/policy.go new file mode 100644 index 000000000..65e4b4a9b --- /dev/null +++ b/watchtower/wtpolicy/policy.go @@ -0,0 +1,68 @@ +package wtpolicy + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/watchtower/blob" +) + +const ( + // DefaultMaxUpdates specifies the number of encrypted blobs a client + // can send to the tower in a single session. + DefaultMaxUpdates = 1024 + + // DefaultRewardRate specifies the fraction of the channel that the + // tower takes if it successfully sweeps a breach. The value is + // expressed in millionths of the channel capacity. + DefaultRewardRate = 10000 + + // DefaultSweepFeeRate specifies the fee rate used to construct justice + // transactions. The value is expressed in satoshis per kilo-weight. + DefaultSweepFeeRate = 3000 +) + +// DefaultPolicy returns a Policy containing the default parameters that can be +// used by clients or servers. +func DefaultPolicy() Policy { + return Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: DefaultMaxUpdates, + RewardRate: DefaultRewardRate, + SweepFeeRate: lnwallet.SatPerKWeight( + DefaultSweepFeeRate, + ), + } +} + +// Policy defines the negotiated parameters for a session between a client and +// server. The parameters specify the format of encrypted blobs sent to the +// tower, the reward schedule for the tower, and the number of encrypted blobs a +// client can send in one session. +type Policy struct { + // BlobType specifies the blob format that must be used by all updates sent + // under the session key used to negotiate this session. + BlobType blob.Type + + // MaxUpdates is the maximum number of updates the watchtower will honor + // for this session. + MaxUpdates uint16 + + // RewardRate is the fraction of the total balance of the revoked + // commitment that the watchtower is entitled to. This value is + // expressed in millionths of the total balance. + RewardRate uint32 + + // SweepFeeRate expresses the intended fee rate to be used when + // constructing the justice transaction. All sweep transactions created + // for this session must use this value during construction, and the + // signatures must implicitly commit to the resulting output values. + SweepFeeRate lnwallet.SatPerKWeight +} + +// String returns a human-readable description of the current policy. +func (p Policy) String() string { + return fmt.Sprintf("(blob-type=%b max-updates=%d reward-rate=%d "+ + "sweep-fee-rate=%d)", p.BlobType, p.MaxUpdates, p.RewardRate, + p.SweepFeeRate) +} diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index 24483d3ac..1218b1f5a 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -13,7 +13,9 @@ import ( "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtwire" ) @@ -246,14 +248,14 @@ func (s *Server) handleClient(peer Peer) { log.Infof("Received CreateSession from %s, "+ "version=%d nupdates=%d rewardrate=%d "+ - "sweepfeerate=%d", id, msg.BlobVersion, + "sweepfeerate=%d", id, msg.BlobType, msg.MaxUpdates, msg.RewardRate, msg.SweepFeeRate) // Attempt to open a new session for this client. err := s.handleCreateSession(peer, &id, msg) if err != nil { - log.Errorf("unable to handle CreateSession "+ + log.Errorf("Unable to handle CreateSession "+ "from %s: %v", id, err) } @@ -327,7 +329,7 @@ func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error { // session info is known about the session id. If an existing session is found, // the reward address is returned in case the client lost our reply. func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, - init *wtwire.CreateSession) error { + req *wtwire.CreateSession) error { // TODO(conner): validate accept against policy @@ -369,17 +371,28 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, rewardAddrBytes := rewardAddress.ScriptAddress() + // Ensure that the requested blob type is supported by our tower. + if !blob.IsSupportedType(req.BlobType) { + log.Debugf("Rejecting CreateSession from %s, unsupported blob "+ + "type %s", id, req.BlobType) + return s.replyCreateSession( + peer, id, wtwire.CreateSessionCodeRejectBlobType, nil, + ) + } + // TODO(conner): create invoice for upfront payment // Assemble the session info using the agreed upon parameters, reward // address, and session id. info := wtdb.SessionInfo{ ID: *id, - Version: init.BlobVersion, - MaxUpdates: init.MaxUpdates, - RewardRate: init.RewardRate, - SweepFeeRate: init.SweepFeeRate, RewardAddress: rewardAddrBytes, + Policy: wtpolicy.Policy{ + BlobType: req.BlobType, + MaxUpdates: req.MaxUpdates, + RewardRate: req.RewardRate, + SweepFeeRate: req.SweepFeeRate, + }, } // Insert the session info into the watchtower's database. If diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 3ca056af2..bbca6c139 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtserver" "github.com/lightningnetwork/lnd/watchtower/wtwire" @@ -155,7 +156,7 @@ var createSessionTests = []createSessionTestCase{ lnwire.NewRawFeatureVector(), ), createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 1000, RewardRate: 0, SweepFeeRate: 1, @@ -169,6 +170,23 @@ var createSessionTests = []createSessionTestCase{ Data: []byte(addr.ScriptAddress()), }, }, + { + name: "reject unsupported blob type", + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + lnwire.NewRawFeatureVector(), + ), + createMsg: &wtwire.CreateSession{ + BlobType: 0, + MaxUpdates: 1000, + RewardRate: 0, + SweepFeeRate: 1, + }, + expReply: &wtwire.CreateSessionReply{ + Code: wtwire.CreateSessionCodeRejectBlobType, + Data: []byte{}, + }, + }, // TODO(conner): add policy rejection tests } @@ -258,7 +276,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 3, RewardRate: 0, SweepFeeRate: 1, @@ -287,7 +305,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 4, RewardRate: 0, SweepFeeRate: 1, @@ -310,7 +328,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 4, RewardRate: 0, SweepFeeRate: 1, @@ -337,7 +355,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 4, RewardRate: 0, SweepFeeRate: 1, @@ -364,7 +382,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 4, RewardRate: 0, SweepFeeRate: 1, @@ -393,7 +411,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 4, RewardRate: 0, SweepFeeRate: 1, @@ -421,7 +439,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 3, RewardRate: 0, SweepFeeRate: 1, @@ -450,7 +468,7 @@ var stateUpdateTests = []stateUpdateTestCase{ GlobalFeatures: lnwire.NewRawFeatureVector(), }}, createMsg: &wtwire.CreateSession{ - BlobVersion: 0, + BlobType: blob.TypeDefault, MaxUpdates: 3, RewardRate: 0, SweepFeeRate: 1, diff --git a/watchtower/wtwire/create_session.go b/watchtower/wtwire/create_session.go index 8ee2c9069..6067d25ca 100644 --- a/watchtower/wtwire/create_session.go +++ b/watchtower/wtwire/create_session.go @@ -4,6 +4,7 @@ import ( "io" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/watchtower/blob" ) // CreateSession is sent from a client to tower when to negotiate a session, which @@ -11,9 +12,9 @@ import ( // An update is consumed by uploading an encrypted blob that contains // information required to sweep a revoked commitment transaction. type CreateSession struct { - // BlobVersion specifies the blob format that must be used by all - // updates sent under the session key used to negotiate this session. - BlobVersion uint16 + // BlobType specifies the blob format that must be used by all updates sent + // under the session key used to negotiate this session. + BlobType blob.Type // MaxUpdates is the maximum number of updates the watchtower will honor // for this session. @@ -41,7 +42,7 @@ var _ Message = (*CreateSession)(nil) // This is part of the wtwire.Message interface. func (m *CreateSession) Decode(r io.Reader, pver uint32) error { return ReadElements(r, - &m.BlobVersion, + &m.BlobType, &m.MaxUpdates, &m.RewardRate, &m.SweepFeeRate, @@ -54,7 +55,7 @@ func (m *CreateSession) Decode(r io.Reader, pver uint32) error { // This is part of the wtwire.Message interface. func (m *CreateSession) Encode(w io.Writer, pver uint32) error { return WriteElements(w, - m.BlobVersion, + m.BlobType, m.MaxUpdates, m.RewardRate, m.SweepFeeRate, diff --git a/watchtower/wtwire/create_session_reply.go b/watchtower/wtwire/create_session_reply.go index b1224cb84..da4867f29 100644 --- a/watchtower/wtwire/create_session_reply.go +++ b/watchtower/wtwire/create_session_reply.go @@ -25,6 +25,10 @@ const ( // CreateSessionCodeRejectSweepFeeRate the tower rejected the sweep fee // rate proposed by the client. CreateSessionCodeRejectSweepFeeRate CreateSessionCode = 63 + + // CreateSessionCodeRejectBlobType is returned when the tower does not + // support the proposed blob type. + CreateSessionCodeRejectBlobType CreateSessionCode = 64 ) // MaxCreateSessionReplyDataLength is the maximum size of the Data payload diff --git a/watchtower/wtwire/wtwire.go b/watchtower/wtwire/wtwire.go index ce3a9a639..2582c7042 100644 --- a/watchtower/wtwire/wtwire.go +++ b/watchtower/wtwire/wtwire.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/watchtower/blob" ) // WriteElement is a one-stop shop to write the big endian representation of @@ -30,6 +31,13 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case blob.Type: + var b [2]byte + binary.BigEndian.PutUint16(b[:], uint16(e)) + if _, err := w.Write(b[:]); err != nil { + return err + } + case uint32: var b [4]byte binary.BigEndian.PutUint32(b[:], e) @@ -127,6 +135,13 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = binary.BigEndian.Uint16(b[:]) + case *blob.Type: + var b [2]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = blob.Type(binary.BigEndian.Uint16(b[:])) + case *uint32: var b [4]byte if _, err := io.ReadFull(r, b[:]); err != nil {