package channeldb import ( "bytes" "math/rand" "net" "reflect" "runtime" "sync/atomic" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" _ "github.com/btcsuite/btcwallet/walletdb/bdb" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/stretchr/testify/require" ) var ( key = [chainhash.HashSize]byte{ 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, } rev = [chainhash.HashSize]byte{ 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, 0x2d, 0xe7, 0x93, 0xe4, } privKey, pubKey = btcec.PrivKeyFromBytes(key[:]) wireSig, _ = lnwire.NewSigFromSignature(testSig) testClock = clock.NewTestClock(testNow) // defaultPendingHeight is the default height at which we set // channels to pending. defaultPendingHeight = 100 // defaultAddr is the default address that we mark test channels pending // with. defaultAddr = &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18555, } // keyLocIndex is the KeyLocator Index we use for // TestKeyLocatorEncoding. keyLocIndex = uint32(2049) // dummyLocalOutputIndex specifics a default value for our output index // in this test. dummyLocalOutputIndex = uint32(0) // dummyRemoteOutIndex specifics a default value for their output index // in this test. dummyRemoteOutIndex = uint32(1) // uniqueOutputIndex is used to create a unique funding outpoint. // // NOTE: must be incremented when used. uniqueOutputIndex = atomic.Uint32{} ) // testChannelParams is a struct which details the specifics of how a channel // should be created. type testChannelParams struct { // channel is the channel that will be written to disk. channel *OpenChannel // addr is the address that the channel will be synced pending with. addr *net.TCPAddr // pendingHeight is the height that the channel should be recorded as // pending. pendingHeight uint32 // openChannel is set to true if the channel should be fully marked as // open if this is false, the channel will be left in pending state. openChannel bool } // testChannelOption is a functional option which can be used to alter the // default channel that is creates for testing. type testChannelOption func(params *testChannelParams) // pendingHeightOption is an option which can be used to set the height the // channel is marked as pending at. func pendingHeightOption(height uint32) testChannelOption { return func(params *testChannelParams) { params.pendingHeight = height } } // openChannelOption is an option which can be used to create a test channel // that is open. func openChannelOption() testChannelOption { return func(params *testChannelParams) { params.openChannel = true } } // localHtlcsOption is an option which allows setting of htlcs on the local // commitment. func localHtlcsOption(htlcs []HTLC) testChannelOption { return func(params *testChannelParams) { params.channel.LocalCommitment.Htlcs = htlcs } } // remoteHtlcsOption is an option which allows setting of htlcs on the remote // commitment. func remoteHtlcsOption(htlcs []HTLC) testChannelOption { return func(params *testChannelParams) { params.channel.RemoteCommitment.Htlcs = htlcs } } // loadFwdPkgs is a helper method that reads all forwarding packages for a // particular packager. func loadFwdPkgs(t *testing.T, db kvdb.Backend, packager FwdPackager) []*FwdPkg { var ( fwdPkgs []*FwdPkg err error ) err = kvdb.View(db, func(tx kvdb.RTx) error { fwdPkgs, err = packager.LoadFwdPkgs(tx) return err }, func() {}) require.NoError(t, err, "unable to load fwd pkgs") return fwdPkgs } // localShutdownOption is an option which sets the local upfront shutdown // script for the channel. func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { return func(params *testChannelParams) { params.channel.LocalShutdownScript = addr } } // remoteShutdownOption is an option which sets the remote upfront shutdown // script for the channel. func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { return func(params *testChannelParams) { params.channel.RemoteShutdownScript = addr } } // fundingPointOption is an option which sets the funding outpoint of the // channel. func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { return func(params *testChannelParams) { params.channel.FundingOutpoint = chanPoint } } // channelIDOption is an option which sets the short channel ID of the channel. var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { return func(params *testChannelParams) { params.channel.ShortChannelID = chanID } } // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. func createTestChannel(t *testing.T, cdb *ChannelStateDB, opts ...testChannelOption) *OpenChannel { // Create a default set of parameters. params := &testChannelParams{ channel: createTestChannelState(t, cdb), addr: defaultAddr, openChannel: false, pendingHeight: uint32(defaultPendingHeight), } // Apply all functional options to the test channel params. for _, o := range opts { o(params) } // Mark the channel as pending. err := params.channel.SyncPending(params.addr, params.pendingHeight) if err != nil { t.Fatalf("unable to save and serialize channel "+ "state: %v", err) } // If the parameters do not specify that we should open the channel // fully, we return the pending channel. if !params.openChannel { return params.channel } // Mark the channel as open with the short channel id provided. err = params.channel.MarkAsOpen(params.channel.ShortChannelID) require.NoError(t, err, "unable to mark channel open") return params.channel } func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { // Simulate 1000 channel updates. producer, err := shachain.NewRevocationProducerFromBytes(key[:]) require.NoError(t, err, "could not get producer") store := shachain.NewRevocationStore() for i := 0; i < 1; i++ { preImage, err := producer.AtIndex(uint64(i)) if err != nil { t.Fatalf("could not get "+ "preimage: %v", err) } if err := store.AddNextEntry(preImage); err != nil { t.Fatalf("could not add entry: %v", err) } } localCfg := ChannelConfig{ ChannelConstraints: ChannelConstraints{ DustLimit: btcutil.Amount(rand.Int63()), MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), ChanReserve: btcutil.Amount(rand.Int63()), MinHTLC: lnwire.MilliSatoshi(rand.Int63()), MaxAcceptedHtlcs: uint16(rand.Int31()), CsvDelay: uint16(rand.Int31()), }, MultiSigKey: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), }, RevocationBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), }, PaymentBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), }, DelayBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), }, HtlcBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), }, } remoteCfg := ChannelConfig{ ChannelConstraints: ChannelConstraints{ DustLimit: btcutil.Amount(rand.Int63()), MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), ChanReserve: btcutil.Amount(rand.Int63()), MinHTLC: lnwire.MilliSatoshi(rand.Int63()), MaxAcceptedHtlcs: uint16(rand.Int31()), CsvDelay: uint16(rand.Int31()), }, MultiSigKey: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), KeyLocator: keychain.KeyLocator{ Family: keychain.KeyFamilyMultiSig, Index: 9, }, }, RevocationBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), KeyLocator: keychain.KeyLocator{ Family: keychain.KeyFamilyRevocationBase, Index: 8, }, }, PaymentBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), KeyLocator: keychain.KeyLocator{ Family: keychain.KeyFamilyPaymentBase, Index: 7, }, }, DelayBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), KeyLocator: keychain.KeyLocator{ Family: keychain.KeyFamilyDelayBase, Index: 6, }, }, HtlcBasePoint: keychain.KeyDescriptor{ PubKey: privKey.PubKey(), KeyLocator: keychain.KeyLocator{ Family: keychain.KeyFamilyHtlcBase, Index: 5, }, }, } chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63())) // Increment the uniqueOutputIndex so we always get a unique value for // the funding outpoint. uniqueOutputIndex.Add(1) op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()} return &OpenChannel{ ChanType: SingleFunderBit | FrozenBit, ChainHash: key, FundingOutpoint: op, ShortChannelID: chanID, IsInitiator: true, IsPending: true, IdentityPub: pubKey, Capacity: btcutil.Amount(10000), LocalChanCfg: localCfg, RemoteChanCfg: remoteCfg, TotalMSatSent: 8, TotalMSatReceived: 2, LocalCommitment: ChannelCommitment{ CommitHeight: 0, LocalBalance: lnwire.MilliSatoshi(9000), RemoteBalance: lnwire.MilliSatoshi(3000), CommitFee: btcutil.Amount(rand.Int63()), FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), }, RemoteCommitment: ChannelCommitment{ CommitHeight: 0, LocalBalance: lnwire.MilliSatoshi(3000), RemoteBalance: lnwire.MilliSatoshi(9000), CommitFee: btcutil.Amount(rand.Int63()), FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), }, NumConfsRequired: 4, RemoteCurrentRevocation: privKey.PubKey(), RemoteNextRevocation: privKey.PubKey(), RevocationProducer: producer, RevocationStore: store, Db: cdb, Packager: NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, ThawHeight: uint32(defaultPendingHeight), InitialLocalBalance: lnwire.MilliSatoshi(9000), InitialRemoteBalance: lnwire.MilliSatoshi(3000), } } func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // Create the test channel state, with additional htlcs on the local // and remote commitment. localHtlcs := []HTLC{ { Signature: testSig.Serialize(), Incoming: true, Amt: 10, RHash: key, RefundTimeout: 1, OnionBlob: lnmock.MockOnion(), }, } remoteHtlcs := []HTLC{ { Signature: testSig.Serialize(), Incoming: false, Amt: 10, RHash: key, RefundTimeout: 1, OnionBlob: lnmock.MockOnion(), }, } state := createTestChannel( t, cdb, remoteHtlcsOption(remoteHtlcs), localHtlcsOption(localHtlcs), ) openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) require.NoError(t, err, "unable to fetch open channel") newState := openChannels[0] // The decoded channel state should be identical to what we stored // above. if !reflect.DeepEqual(state, newState) { t.Fatalf("channel state doesn't match:: %v vs %v", spew.Sdump(state), spew.Sdump(newState)) } // We'll also test that the channel is properly able to hot swap the // next revocation for the state machine. This tests the initial // post-funding revocation exchange. nextRevKey, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create new private key") if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil { t.Fatalf("unable to update revocation: %v", err) } openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) require.NoError(t, err, "unable to fetch open channel") updatedChan := openChannels[0] // Ensure that the revocation was set properly. if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) { t.Fatalf("next revocation wasn't updated") } // Finally to wrap up the test, delete the state of the channel within // the database. This involves "closing" the channel which removes all // written state, and creates a small "summary" elsewhere within the // database. closeSummary := &ChannelCloseSummary{ ChanPoint: state.FundingOutpoint, RemotePub: state.IdentityPub, SettledBalance: btcutil.Amount(500), TimeLockedBalance: btcutil.Amount(10000), IsPending: false, CloseType: CooperativeClose, } if err := state.CloseChannel(closeSummary); err != nil { t.Fatalf("unable to close channel: %v", err) } // As the channel is now closed, attempting to fetch all open channels // for our fake node ID should return an empty slice. openChans, err := cdb.FetchOpenChannels(state.IdentityPub) require.NoError(t, err, "unable to fetch open channels") if len(openChans) != 0 { t.Fatalf("all channels not deleted, found %v", len(openChans)) } // Additionally, attempting to fetch all the open channels globally // should yield no results. openChans, err = cdb.FetchAllChannels() if err != nil { t.Fatal("unable to fetch all open chans") } if len(openChans) != 0 { t.Fatalf("all channels not deleted, found %v", len(openChans)) } } // TestOptionalShutdown tests the reading and writing of channels with and // without optional shutdown script fields. func TestOptionalShutdown(t *testing.T) { local := lnwire.DeliveryAddress([]byte("local shutdown script")) remote := lnwire.DeliveryAddress([]byte("remote shutdown script")) if _, err := rand.Read(remote); err != nil { t.Fatalf("Could not create random script: %v", err) } tests := []struct { name string localShutdown lnwire.DeliveryAddress remoteShutdown lnwire.DeliveryAddress }{ { name: "no shutdown scripts", localShutdown: nil, remoteShutdown: nil, }, { name: "local shutdown script", localShutdown: local, remoteShutdown: nil, }, { name: "remote shutdown script", localShutdown: nil, remoteShutdown: remote, }, { name: "both scripts set", localShutdown: local, remoteShutdown: remote, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { fullDB, err := MakeTestDB(t) if err != nil { t.Fatalf("unable to make test database: %v", err) } cdb := fullDB.ChannelStateDB() // Create a channel with upfront scripts set as // specified in the test. state := createTestChannel( t, cdb, localShutdownOption(test.localShutdown), remoteShutdownOption(test.remoteShutdown), ) openChannels, err := cdb.FetchOpenChannels( state.IdentityPub, ) if err != nil { t.Fatalf("unable to fetch open"+ " channel: %v", err) } if len(openChannels) != 1 { t.Fatalf("Expected one channel open,"+ " got: %v", len(openChannels)) } if !bytes.Equal(openChannels[0].LocalShutdownScript, test.localShutdown) { t.Fatalf("Expected local: %x, got: %x", test.localShutdown, openChannels[0].LocalShutdownScript) } if !bytes.Equal(openChannels[0].RemoteShutdownScript, test.remoteShutdown) { t.Fatalf("Expected remote: %x, got: %x", test.remoteShutdown, openChannels[0].RemoteShutdownScript) } }) } } func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { if !reflect.DeepEqual(a, b) { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: commitments don't match: %v vs %v", line, spew.Sdump(a), spew.Sdump(b)) } } // assertRevocationLogEntryEqual asserts that, for all the fields of a given // revocation log entry, their values match those on a given ChannelCommitment. func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, r *RevocationLog) { // Check the common fields. require.EqualValues( t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch", ) // Now check the common fields from the HTLCs. require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), "Amt mismatch") require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, "RefundTimeout mismatch") require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, "OutputIndex mismatch") require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, "Incoming mismatch") } } func TestChannelStateTransition(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // First create a minimal channel, then perform a full sync in order to // persist the data. channel := createTestChannel(t, cdb) // Add some HTLCs which were added during this new state transition. // Half of the HTLCs are incoming, while the other half are outgoing. var ( htlcs []HTLC htlcAmt lnwire.MilliSatoshi ) for i := uint32(0); i < 10; i++ { var incoming bool if i > 5 { incoming = true } htlc := HTLC{ Signature: testSig.Serialize(), Incoming: incoming, Amt: 10, RHash: key, RefundTimeout: i, OutputIndex: int32(i * 3), LogIndex: uint64(i * 2), HtlcIndex: uint64(i), } copy( htlc.OnionBlob[:], bytes.Repeat([]byte{2}, lnwire.OnionPacketSize), ) htlcs = append(htlcs, htlc) htlcAmt += htlc.Amt } // Create a new channel delta which includes the above HTLCs, some // balance updates, and an increment of the current commitment height. // Additionally, modify the signature and commitment transaction. newSequence := uint32(129498) newSig := bytes.Repeat([]byte{3}, 71) newTx := channel.LocalCommitment.CommitTx.Copy() newTx.TxIn[0].Sequence = newSequence commitment := ChannelCommitment{ CommitHeight: 1, LocalLogIndex: 2, LocalHtlcIndex: 1, RemoteLogIndex: 2, RemoteHtlcIndex: 1, LocalBalance: lnwire.MilliSatoshi(1e8), RemoteBalance: lnwire.MilliSatoshi(1e8), CommitFee: 55, FeePerKw: 99, CommitTx: newTx, CommitSig: newSig, Htlcs: htlcs, } // First update the local node's broadcastable state and also add a // CommitDiff remote node's as well in order to simulate a proper state // transition. unsignedAckedUpdates := []LogUpdate{ { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ ChanID: lnwire.ChannelID{1, 2, 3}, ExtraData: make([]byte, 0), }, }, } _, err = channel.UpdateCommitment(&commitment, unsignedAckedUpdates) require.NoError(t, err, "unable to update commitment") // Assert that update is correctly written to the database. dbUnsignedAckedUpdates, err := channel.UnsignedAckedUpdates() require.NoError(t, err, "unable to fetch dangling remote updates") if len(dbUnsignedAckedUpdates) != 1 { t.Fatalf("unexpected number of dangling remote updates") } if !reflect.DeepEqual( dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0], ) { t.Fatalf("unexpected update: expected %v, got %v", spew.Sdump(unsignedAckedUpdates[0]), spew.Sdump(dbUnsignedAckedUpdates)) } // The balances, new update, the HTLCs and the changes to the fake // commitment transaction along with the modified signature should all // have been updated. updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) numDiskUpdates, err := updatedChannel[0].CommitmentHeight() require.NoError(t, err, "unable to read commitment height from disk") if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", numDiskUpdates, commitment.CommitHeight) } // Attempting to query for a commitment diff should return // ErrNoPendingCommit as we haven't yet created a new state for them. _, err = channel.RemoteCommitChainTip() if err != ErrNoPendingCommit { t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) } // To simulate us extending a new state to the remote party, we'll also // create a new commit diff for them. remoteCommit := commitment remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) remoteCommit.CommitHeight = 1 commitDiff := &CommitDiff{ Commitment: remoteCommit, CommitSig: &lnwire.CommitSig{ ChanID: lnwire.ChannelID(key), CommitSig: wireSig, HtlcSigs: []lnwire.Sig{ wireSig, wireSig, }, ExtraData: make([]byte, 0), }, LogUpdates: []LogUpdate{ { LogIndex: 1, UpdateMsg: &lnwire.UpdateAddHTLC{ ID: 1, Amount: lnwire.NewMSatFromSatoshis(100), Expiry: 25, ExtraData: make([]byte, 0), }, }, { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ ID: 2, Amount: lnwire.NewMSatFromSatoshis(200), Expiry: 50, ExtraData: make([]byte, 0), }, }, }, OpenedCircuitKeys: []models.CircuitKey{}, ClosedCircuitKeys: []models.CircuitKey{}, } copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], bytes.Repeat([]byte{1}, 32)) copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], bytes.Repeat([]byte{2}, 32)) if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { t.Fatalf("unable to add to commit chain: %v", err) } // The commitment tip should now match the commitment that we just // inserted. diskCommitDiff, err := channel.RemoteCommitChainTip() require.NoError(t, err, "unable to fetch commit diff") if !reflect.DeepEqual(commitDiff, diskCommitDiff) { t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), spew.Sdump(diskCommitDiff)) } // We'll save the old remote commitment as this will be added to the // revocation log shortly. oldRemoteCommit := channel.RemoteCommitment // Next, write to the log which tracks the necessary revocation state // needed to rectify any fishy behavior by the remote party. Modify the // current uncollapsed revocation state to simulate a state transition // by the remote party. channel.RemoteCurrentRevocation = channel.RemoteNextRevocation newPriv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to generate key") channel.RemoteNextRevocation = newPriv.PubKey() fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, diskCommitDiff.LogUpdates, nil) err = channel.AdvanceCommitChainTail( fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex, ) require.NoError(t, err, "unable to append to revocation log") // At this point, the remote commit chain should be nil, and the posted // remote commitment should match the one we added as a diff above. if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) } // We should be able to fetch the channel delta created above by its // update number with all the state properly reconstructed. diskPrevCommit, _, err := channel.FindPreviousState( oldRemoteCommit.CommitHeight, ) require.NoError(t, err, "unable to fetch past delta") // Check the output indexes are saved as expected. require.EqualValues( t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, ) require.EqualValues( t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, ) // The two deltas (the original vs the on-disk version) should // identical, and all HTLC data should properly be retained. assertRevocationLogEntryEqual(t, &oldRemoteCommit, diskPrevCommit) // The state number recovered from the tail of the revocation log // should be identical to this current state. logTailHeight, err := channel.revocationLogTailCommitHeight() require.NoError(t, err, "unable to retrieve log") if logTailHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") } oldRemoteCommit = channel.RemoteCommitment // Next modify the posted diff commitment slightly, then create a new // commitment diff and advance the tail. commitDiff.Commitment.CommitHeight = 2 commitDiff.Commitment.LocalBalance -= htlcAmt commitDiff.Commitment.RemoteBalance += htlcAmt commitDiff.LogUpdates = []LogUpdate{} if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { t.Fatalf("unable to add to commit chain: %v", err) } fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) err = channel.AdvanceCommitChainTail( fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex, ) require.NoError(t, err, "unable to append to revocation log") // Once again, fetch the state and ensure it has been properly updated. prevCommit, _, err := channel.FindPreviousState( oldRemoteCommit.CommitHeight, ) require.NoError(t, err, "unable to fetch past delta") // Check the output indexes are saved as expected. require.EqualValues( t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, ) require.EqualValues( t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, ) assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit) // Once again, state number recovered from the tail of the revocation // log should be identical to this current state. logTailHeight, err = channel.revocationLogTailCommitHeight() require.NoError(t, err, "unable to retrieve log") if logTailHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") } // The revocation state stored on-disk should now also be identical. updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { t.Fatalf("revocation state was not synced") } if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { t.Fatalf("revocation state was not synced") } // At this point, we should have 2 forwarding packages added. fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") // Now attempt to delete the channel from the database. closeSummary := &ChannelCloseSummary{ ChanPoint: channel.FundingOutpoint, RemotePub: channel.IdentityPub, SettledBalance: btcutil.Amount(500), TimeLockedBalance: btcutil.Amount(10000), IsPending: false, CloseType: RemoteForceClose, } if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { t.Fatalf("unable to delete updated channel: %v", err) } // If we attempt to fetch the target channel again, it shouldn't be // found. channels, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channels") if len(channels) != 0 { t.Fatalf("%v channels, found, but none should be", len(channels)) } // Attempting to find previous states on the channel should fail as the // revocation log has been deleted. _, _, err = updatedChannel[0].FindPreviousState( oldRemoteCommit.CommitHeight, ) if err == nil { t.Fatal("revocation log search should have failed") } // All forwarding packages of this channel has been deleted too. fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) require.Empty(t, fwdPkgs, "no forwarding packages should exist") } func TestFetchPendingChannels(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // Create a pending channel that was broadcast at height 99. const broadcastHeight = 99 createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) pendingChannels, err := cdb.FetchPendingChannels() require.NoError(t, err, "unable to list pending channels") if len(pendingChannels) != 1 { t.Fatalf("incorrect number of pending channels: expecting %v,"+ "got %v", 1, len(pendingChannels)) } // The broadcast height of the pending channel should have been set // properly. if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { t.Fatalf("broadcast height mismatch: expected %v, got %v", pendingChannels[0].FundingBroadcastHeight, broadcastHeight) } chanOpenLoc := lnwire.ShortChannelID{ BlockHeight: 5, TxIndex: 10, TxPosition: 15, } err = pendingChannels[0].MarkAsOpen(chanOpenLoc) require.NoError(t, err, "unable to mark channel as open") if pendingChannels[0].IsPending { t.Fatalf("channel marked open should no longer be pending") } if pendingChannels[0].ShortChanID() != chanOpenLoc { t.Fatalf("channel opening height not updated: expected %v, "+ "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), chanOpenLoc) } // Next, we'll re-fetch the channel to ensure that the open height was // properly set. openChans, err := cdb.FetchAllChannels() require.NoError(t, err, "unable to fetch channels") if openChans[0].ShortChanID() != chanOpenLoc { t.Fatalf("channel opening heights don't match: expected %v, "+ "got %v", spew.Sdump(openChans[0].ShortChanID()), chanOpenLoc) } if openChans[0].FundingBroadcastHeight != broadcastHeight { t.Fatalf("broadcast height mismatch: expected %v, got %v", openChans[0].FundingBroadcastHeight, broadcastHeight) } pendingChannels, err = cdb.FetchPendingChannels() require.NoError(t, err, "unable to list pending channels") if len(pendingChannels) != 0 { t.Fatalf("incorrect number of pending channels: expecting %v,"+ "got %v", 0, len(pendingChannels)) } } func TestFetchClosedChannels(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // Create an open channel in the database. state := createTestChannel(t, cdb, openChannelOption()) // Next, close the channel by including a close channel summary in the // database. summary := &ChannelCloseSummary{ ChanPoint: state.FundingOutpoint, ClosingTXID: rev, RemotePub: state.IdentityPub, Capacity: state.Capacity, SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(), TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000, CloseType: RemoteForceClose, IsPending: true, LocalChanConfig: state.LocalChanCfg, } if err := state.CloseChannel(summary); err != nil { t.Fatalf("unable to close channel: %v", err) } // Query the database to ensure that the channel has now been properly // closed. We should get the same result whether querying for pending // channels only, or not. pendingClosed, err := cdb.FetchClosedChannels(true) require.NoError(t, err, "failed fetching closed channels") if len(pendingClosed) != 1 { t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ "got %v", 1, len(pendingClosed)) } if !reflect.DeepEqual(summary, pendingClosed[0]) { t.Fatalf("database summaries don't match: expected %v got %v", spew.Sdump(summary), spew.Sdump(pendingClosed[0])) } closed, err := cdb.FetchClosedChannels(false) require.NoError(t, err, "failed fetching all closed channels") if len(closed) != 1 { t.Fatalf("incorrect number of closed channels: expecting %v, "+ "got %v", 1, len(closed)) } if !reflect.DeepEqual(summary, closed[0]) { t.Fatalf("database summaries don't match: expected %v got %v", spew.Sdump(summary), spew.Sdump(closed[0])) } // Mark the channel as fully closed. err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) require.NoError(t, err, "failed fully closing channel") // The channel should no longer be considered pending, but should still // be retrieved when fetching all the closed channels. closed, err = cdb.FetchClosedChannels(false) require.NoError(t, err, "failed fetching closed channels") if len(closed) != 1 { t.Fatalf("incorrect number of closed channels: expecting %v, "+ "got %v", 1, len(closed)) } pendingClose, err := cdb.FetchClosedChannels(true) require.NoError(t, err, "failed fetching channels pending close") if len(pendingClose) != 0 { t.Fatalf("incorrect number of closed channels: expecting %v, "+ "got %v", 0, len(closed)) } } // TestFetchWaitingCloseChannels ensures that the correct channels that are // waiting to be closed are returned. func TestFetchWaitingCloseChannels(t *testing.T) { t.Parallel() const numChannels = 2 const broadcastHeight = 99 // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() channels := make([]*OpenChannel, numChannels) for i := 0; i < numChannels; i++ { // Create a pending channel in the database at the broadcast // height. channels[i] = createTestChannel( t, cdb, pendingHeightOption(broadcastHeight), ) } // We'll only confirm the first one. channelConf := lnwire.ShortChannelID{ BlockHeight: broadcastHeight + 1, TxIndex: 10, TxPosition: 15, } if err := channels[0].MarkAsOpen(channelConf); err != nil { t.Fatalf("unable to mark channel as open: %v", err) } // Then, we'll mark the channels as if their commitments were broadcast. // This would happen in the event of a force close and should make the // channels enter a state of waiting close. for _, channel := range channels { closeTx := wire.NewMsgTx(2) closeTx.AddTxIn( &wire.TxIn{ PreviousOutPoint: channel.FundingOutpoint, }, ) if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { t.Fatalf("unable to mark commitment broadcast: %v", err) } // Now try to marking a coop close with a nil tx. This should // succeed, but it shouldn't exit when queried. if err = channel.MarkCoopBroadcasted(nil, true); err != nil { t.Fatalf("unable to mark nil coop broadcast: %v", err) } _, err := channel.BroadcastedCooperative() if err != ErrNoCloseTx { t.Fatalf("expected no closing tx error, got: %v", err) } // Finally, modify the close tx deterministically and also mark // it as coop closed. Later we will test that distinct // transactions are returned for both coop and force closes. closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { t.Fatalf("unable to mark coop broadcast: %v", err) } } // Now, we'll fetch all the channels waiting to be closed from the // database. We should expect to see both channels above, even if any of // them haven't had their funding transaction confirm on-chain. waitingCloseChannels, err := cdb.FetchWaitingCloseChannels() require.NoError(t, err, "unable to fetch all waiting close channels") if len(waitingCloseChannels) != numChannels { t.Fatalf("expected %d channels waiting to be closed, got %d", 2, len(waitingCloseChannels)) } expectedChannels := make(map[wire.OutPoint]struct{}) for _, channel := range channels { expectedChannels[channel.FundingOutpoint] = struct{}{} } for _, channel := range waitingCloseChannels { if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { t.Fatalf("expected channel %v to be waiting close", channel.FundingOutpoint) } chanPoint := channel.FundingOutpoint // Assert that the force close transaction is retrievable. forceCloseTx, err := channel.BroadcastedCommitment() if err != nil { t.Fatalf("Unable to retrieve commitment: %v", err) } if forceCloseTx.TxIn[0].PreviousOutPoint != chanPoint { t.Fatalf("expected outpoint %v, got %v", chanPoint, forceCloseTx.TxIn[0].PreviousOutPoint) } // Assert that the coop close transaction is retrievable. coopCloseTx, err := channel.BroadcastedCooperative() if err != nil { t.Fatalf("unable to retrieve coop close: %v", err) } chanPoint.Index ^= 1 if coopCloseTx.TxIn[0].PreviousOutPoint != chanPoint { t.Fatalf("expected outpoint %v, got %v", chanPoint, coopCloseTx.TxIn[0].PreviousOutPoint) } } } // TestShutdownInfo tests that a channel's shutdown info can correctly be // persisted and retrieved. func TestShutdownInfo(t *testing.T) { t.Parallel() tests := []struct { name string localInit bool }{ { name: "local node initiated", localInit: true, }, { name: "remote node initiated", localInit: false, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() testShutdownInfo(t, test.localInit) }) } } func testShutdownInfo(t *testing.T, locallyInitiated bool) { fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // First a test channel. channel := createTestChannel(t, cdb) // We haven't persisted any shutdown info for this channel yet. _, err = channel.ShutdownInfo() require.Error(t, err, ErrNoShutdownInfo) // Construct a new delivery script and create a new ShutdownInfo object. script := []byte{1, 3, 4, 5} // Create a ShutdownInfo struct. shutdownInfo := NewShutdownInfo(script, locallyInitiated) // Persist the shutdown info. require.NoError(t, channel.MarkShutdownSent(shutdownInfo)) // We should now be able to retrieve the shutdown info. info, err := channel.ShutdownInfo() require.NoError(t, err) require.True(t, info.IsSome()) // Assert that the decoded values of the shutdown info are correct. info.WhenSome(func(info ShutdownInfo) { require.EqualValues(t, script, info.DeliveryScript.Val) require.Equal(t, locallyInitiated, info.LocalInitiator.Val) }) } // TestRefresh asserts that Refresh updates the in-memory state of another // OpenChannel to reflect a preceding call to MarkOpen on a different // OpenChannel. func TestRefresh(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() // First create a test channel. state := createTestChannel(t, cdb) // Next, locate the pending channel with the database. pendingChannels, err := cdb.FetchPendingChannels() if err != nil { t.Fatalf("unable to load pending channels; %v", err) } var pendingChannel *OpenChannel for _, channel := range pendingChannels { if channel.FundingOutpoint == state.FundingOutpoint { pendingChannel = channel break } } if pendingChannel == nil { t.Fatalf("unable to find pending channel with funding "+ "outpoint=%v: %v", state.FundingOutpoint, err) } // Next, simulate the confirmation of the channel by marking it as // pending within the database. chanOpenLoc := lnwire.ShortChannelID{ BlockHeight: 105, TxIndex: 10, TxPosition: 15, } err = state.MarkAsOpen(chanOpenLoc) require.NoError(t, err, "unable to mark channel open") // The short_chan_id of the receiver to MarkAsOpen should reflect the // open location, but the other pending channel should remain unchanged. if state.ShortChanID() == pendingChannel.ShortChanID() { t.Fatalf("pending channel short_chan_ID should not have been " + "updated before refreshing short_chan_id") } // Now that the receiver's short channel id has been updated, check to // ensure that the channel packager's source has been updated as well. // This ensures that the packager will read and write to buckets // corresponding to the new short chan id, instead of the prior. if state.Packager.(*ChannelPackager).source != chanOpenLoc { t.Fatalf("channel packager source was not updated: want %v, "+ "got %v", chanOpenLoc, state.Packager.(*ChannelPackager).source) } // Now, refresh the state of the pending channel. err = pendingChannel.Refresh() require.NoError(t, err, "unable to refresh short_chan_id") // This should result in both OpenChannel's now having the same // ShortChanID. if state.ShortChanID() != pendingChannel.ShortChanID() { t.Fatalf("expected pending channel short_chan_id to be "+ "refreshed: want %v, got %v", state.ShortChanID(), pendingChannel.ShortChanID()) } // Check to ensure that the _other_ OpenChannel channel packager's // source has also been updated after the refresh. This ensures that the // other packagers will read and write to buckets corresponding to the // updated short chan id. if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { t.Fatalf("channel packager source was not updated: want %v, "+ "got %v", chanOpenLoc, pendingChannel.Packager.(*ChannelPackager).source) } // Check to ensure that this channel is no longer pending and this field // is up to date. if pendingChannel.IsPending { t.Fatalf("channel pending state wasn't updated: want false got true") } } // TestCloseInitiator tests the setting of close initiator statuses for // cooperative closes and local force closes. func TestCloseInitiator(t *testing.T) { tests := []struct { name string // updateChannel is called to update the channel as broadcast, // cooperatively or not, based on the test's requirements. updateChannel func(c *OpenChannel) error expectedStatuses []ChannelStatus }{ { name: "local coop close", // Mark the channel as cooperatively closed, initiated // by the local party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( &wire.MsgTx{}, true, ) }, expectedStatuses: []ChannelStatus{ ChanStatusLocalCloseInitiator, ChanStatusCoopBroadcasted, }, }, { name: "remote coop close", // Mark the channel as cooperatively closed, initiated // by the remote party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( &wire.MsgTx{}, false, ) }, expectedStatuses: []ChannelStatus{ ChanStatusRemoteCloseInitiator, ChanStatusCoopBroadcasted, }, }, { name: "local force close", // Mark the channel's commitment as broadcast with // local initiator. updateChannel: func(c *OpenChannel) error { return c.MarkCommitmentBroadcasted( &wire.MsgTx{}, true, ) }, expectedStatuses: []ChannelStatus{ ChanStatusLocalCloseInitiator, ChanStatusCommitBroadcasted, }, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) if err != nil { t.Fatalf("unable to make test database: %v", err) } cdb := fullDB.ChannelStateDB() // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), ) err = test.updateChannel(channel) if err != nil { t.Fatalf("unexpected error: %v", err) } // Lookup open channels in the database. dbChans, err := fetchChannels( cdb, pendingChannelFilter(false), ) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(dbChans) != 1 { t.Fatalf("expected 1 channel, got: %v", len(dbChans)) } // Check that the statuses that we expect were written // to disk. for _, status := range test.expectedStatuses { if !dbChans[0].HasChanStatus(status) { t.Fatalf("expected channel to have "+ "status: %v, has status: %v", status, dbChans[0].chanStatus) } } }) } } // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { fullDB, err := MakeTestDB(t) if err != nil { t.Fatalf("unable to make test database: %v", err) } cdb := fullDB.ChannelStateDB() // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), ) if err := channel.CloseChannel( &ChannelCloseSummary{ ChanPoint: channel.FundingOutpoint, RemotePub: channel.IdentityPub, }, ChanStatusRemoteCloseInitiator, ); err != nil { t.Fatalf("unexpected error: %v", err) } histChan, err := channel.Db.FetchHistoricalChannel( &channel.FundingOutpoint, ) require.NoError(t, err, "unexpected error") if !histChan.HasChanStatus(ChanStatusRemoteCloseInitiator) { t.Fatalf("channel should have status") } } // TestHasChanStatus asserts the behavior of HasChanStatus by checking the // behavior of various status flags in addition to the special case of // ChanStatusDefault which is treated like a flag in the code base even though // it isn't. func TestHasChanStatus(t *testing.T) { tests := []struct { name string status ChannelStatus expHas map[ChannelStatus]bool }{ { name: "default", status: ChanStatusDefault, expHas: map[ChannelStatus]bool{ ChanStatusDefault: true, ChanStatusBorked: false, }, }, { name: "single flag", status: ChanStatusBorked, expHas: map[ChannelStatus]bool{ ChanStatusDefault: false, ChanStatusBorked: true, }, }, { name: "multiple flags", status: ChanStatusBorked | ChanStatusLocalDataLoss, expHas: map[ChannelStatus]bool{ ChanStatusDefault: false, ChanStatusBorked: true, ChanStatusLocalDataLoss: true, }, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { c := &OpenChannel{ chanStatus: test.status, } for status, expHas := range test.expHas { has := c.HasChanStatus(status) if has == expHas { continue } t.Fatalf("expected chan status to "+ "have %s? %t, got: %t", status, expHas, has) } }) } } // TestKeyLocatorEncoding tests that we are able to serialize a given // keychain.KeyLocator. After successfully encoding, we check that the decode // output arrives at the same initial KeyLocator. func TestKeyLocatorEncoding(t *testing.T) { keyLoc := keychain.KeyLocator{ Family: keychain.KeyFamilyRevocationRoot, Index: keyLocIndex, } // First, we'll encode the KeyLocator into a buffer. var ( b bytes.Buffer buf [8]byte ) err := EKeyLocator(&b, &keyLoc, &buf) require.NoError(t, err, "unable to encode key locator") // Next, we'll attempt to decode the bytes into a new KeyLocator. r := bytes.NewReader(b.Bytes()) var decodedKeyLoc keychain.KeyLocator err = DKeyLocator(r, &decodedKeyLoc, &buf, 8) require.NoError(t, err, "unable to decode key locator") // Finally, we'll compare that the original KeyLocator and the decoded // version are equal. require.Equal(t, keyLoc, decodedKeyLoc) } // TestFinalHtlcs tests final htlc storage and retrieval. func TestFinalHtlcs(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t, OptionStoreFinalHtlcResolutions(true)) require.NoError(t, err, "unable to make test database") cdb := fullDB.ChannelStateDB() chanID := lnwire.ShortChannelID{ BlockHeight: 1, TxIndex: 2, TxPosition: 3, } // Test unknown htlc lookup. const unknownHtlcID = 999 _, err = cdb.LookupFinalHtlc(chanID, unknownHtlcID) require.ErrorIs(t, err, ErrHtlcUnknown) // Test offchain final htlcs. const offchainHtlcID = 1 err = kvdb.Update(cdb.backend, func(tx kvdb.RwTx) error { bucket, err := fetchFinalHtlcsBucketRw( tx, chanID, ) require.NoError(t, err) return putFinalHtlc(bucket, offchainHtlcID, FinalHtlcInfo{ Settled: true, Offchain: true, }) }, func() {}) require.NoError(t, err) info, err := cdb.LookupFinalHtlc(chanID, offchainHtlcID) require.NoError(t, err) require.True(t, info.Settled) require.True(t, info.Offchain) // Test onchain final htlcs. const onchainHtlcID = 2 err = cdb.PutOnchainFinalHtlcOutcome(chanID, onchainHtlcID, true) require.NoError(t, err) info, err = cdb.LookupFinalHtlc(chanID, onchainHtlcID) require.NoError(t, err) require.True(t, info.Settled) require.False(t, info.Offchain) // Test unknown htlc lookup for existing channel. _, err = cdb.LookupFinalHtlc(chanID, unknownHtlcID) require.ErrorIs(t, err, ErrHtlcUnknown) } // TestHTLCsExtraData tests serialization and deserialization of HTLCs // combined with extra data. func TestHTLCsExtraData(t *testing.T) { t.Parallel() mockHtlc := HTLC{ Signature: testSig.Serialize(), Incoming: false, Amt: 10, RHash: key, RefundTimeout: 1, OnionBlob: lnmock.MockOnion(), } testCases := []struct { name string htlcs []HTLC }{ { // Serialize multiple HLTCs with no extra data to // assert that there is no regression for HTLCs with // no extra data. name: "no extra data", htlcs: []HTLC{ mockHtlc, mockHtlc, }, }, { name: "mixed extra data", htlcs: []HTLC{ mockHtlc, { Signature: testSig.Serialize(), Incoming: false, Amt: 10, RHash: key, RefundTimeout: 1, OnionBlob: lnmock.MockOnion(), ExtraData: []byte{1, 2, 3}, }, mockHtlc, { Signature: testSig.Serialize(), Incoming: false, Amt: 10, RHash: key, RefundTimeout: 1, OnionBlob: lnmock.MockOnion(), ExtraData: bytes.Repeat( []byte{9}, 999, ), }, }, }, } for _, testCase := range testCases { testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() var b bytes.Buffer err := SerializeHtlcs(&b, testCase.htlcs...) require.NoError(t, err) r := bytes.NewReader(b.Bytes()) htlcs, err := DeserializeHtlcs(r) require.NoError(t, err) require.Equal(t, testCase.htlcs, htlcs) }) } } // TestOnionBlobIncorrectLength tests HTLC deserialization in the case where // the OnionBlob saved on disk is of an unexpected length. This error case is // only expected in the case of database corruption (or some severe protocol // breakdown/bug). A HTLC is manually serialized because we cannot force a // case where we write an onion blob of incorrect length. func TestOnionBlobIncorrectLength(t *testing.T) { t.Parallel() var b bytes.Buffer var numHtlcs uint16 = 1 require.NoError(t, WriteElement(&b, numHtlcs)) require.NoError(t, WriteElements( &b, // Number of HTLCs. numHtlcs, // Signature, incoming, amount, Rhash, Timeout. testSig.Serialize(), false, lnwire.MilliSatoshi(10), key, uint32(1), // Write an onion blob that is half of our expected size. bytes.Repeat([]byte{1}, lnwire.OnionPacketSize/2), )) _, err := DeserializeHtlcs(&b) require.ErrorIs(t, err, ErrOnionBlobLength) }