Merge pull request #9183 from lightningnetwork/0-18-4-branch-rc1

release: create branch for v0.18.4-beta.rc1
Oliver Gugger committed via GitHub on 2024-11-21 20:19:40 +01:00
commit c1129bb086
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
309 changed files with 24299 additions and 9708 deletions


@@ -21,7 +21,7 @@ defaults:
shell: bash
env:
-BITCOIN_VERSION: "27"
+BITCOIN_VERSION: "28"
TRANCHES: 8
@@ -31,7 +31,7 @@ env:
# /dev.Dockerfile
# /make/builder.Dockerfile
# /.github/workflows/release.yml
-GO_VERSION: 1.22.5
+GO_VERSION: 1.22.6
jobs:
########################


@@ -11,12 +11,11 @@ defaults:
env:
# If you change this value, please change it in the following files as well:
-# /.travis.yml
# /Dockerfile
# /dev.Dockerfile
# /make/builder.Dockerfile
# /.github/workflows/main.yml
-GO_VERSION: 1.22.5
+GO_VERSION: 1.22.6
jobs:
main:

.gitignore

@@ -66,6 +66,7 @@ profile.tmp
.DS_Store
.vscode
+*.code-workspace
# Coverage test
coverage.txt


@@ -1,18 +1,8 @@
run:
-# timeout for analysis
-deadline: 10m
-# Skip autogenerated files for mobile and gRPC as well as copied code for
-# internal use.
-skip-files:
-- "mobile\\/.*generated\\.go"
-- "\\.pb\\.go$"
-- "\\.pb\\.gw\\.go$"
-- "internal\\/musig2v040"
-skip-dirs:
-- channeldb/migration_01_to_11
-- channeldb/migration/lnwire21
+go: "1.22.6"
+# Abort after 10 minutes.
+timeout: 10m
build-tags:
- autopilotrpc
@@ -57,7 +47,6 @@ linters-settings:
- G306 # Poor file permissions used when writing to a new file.
staticcheck:
-go: "1.22.5"
checks: ["-SA1019"]
lll:
@@ -133,25 +122,15 @@ linters:
- gochecknoinits
# Deprecated linters. See https://golangci-lint.run/usage/linters/.
-- interfacer
-- golint
-- maligned
-- scopelint
-- exhaustivestruct
- bodyclose
- contextcheck
- nilerr
- noctx
- rowserrcheck
- sqlclosecheck
-- structcheck
- tparallel
- unparam
- wastedassign
-- ifshort
-- varcheck
-- deadcode
-- nosnakecase
# Disable gofumpt as it has weird behavior regarding formatting multiple
@@ -191,7 +170,7 @@ linters:
- wrapcheck
# Allow dynamic errors.
-- goerr113
+- err113
# We use ErrXXX instead.
- errname
@@ -207,15 +186,41 @@ linters:
# The linter is too aggressive and doesn't add much value since reviewers
# will also catch magic numbers that make sense to extract.
- gomnd
+- mnd
# Some of the tests cannot be parallelized. On the other hand, we don't
# gain much performance with this check so we disable it for now until
# unit tests become our CI bottleneck.
- paralleltest
# New linters that we haven't had time to address yet.
- testifylint
- perfsprint
- inamedparam
- copyloopvar
- tagalign
- protogetter
- revive
- depguard
- gosmopolitan
- intrange
issues:
# Only show newly introduced problems.
-new-from-rev: 8c66353e4c02329abdacb5a8df29998035ec2e24
+new-from-rev: 77c7f776d5cbf9e147edc81d65ae5ba177a684e5
# Skip autogenerated files for mobile and gRPC as well as copied code for
# internal use.
skip-files:
- "mobile\\/.*generated\\.go"
- "\\.pb\\.go$"
- "\\.pb\\.gw\\.go$"
- "internal\\/musig2v040"
skip-dirs:
- channeldb/migration_01_to_11
- channeldb/migration/lnwire21
exclude-rules:
# Exclude gosec from running for tests so that tests with weak randomness
@@ -256,8 +261,8 @@ issues:
- forbidigo
- godot
-# Allow fmt.Printf() in lncli.
-- path: cmd/lncli/*
+# Allow fmt.Printf() in commands.
+- path: cmd/commands/*
linters:
- forbidigo


@@ -3,7 +3,7 @@
# /make/builder.Dockerfile
# /.github/workflows/main.yml
# /.github/workflows/release.yml
-FROM golang:1.22.5-alpine as builder
+FROM golang:1.22.6-alpine as builder
# Force Go to use the cgo based DNS resolver. This is required to ensure DNS
# queries required to connect to linked containers succeed.


@@ -35,7 +35,7 @@ endif
# GO_VERSION is the Go version used for the release build, docker files, and
# GitHub Actions. This is the reference version for the project. All other Go
# versions are checked against this version.
-GO_VERSION = 1.22.5
+GO_VERSION = 1.22.6
GOBUILD := $(LOOPVARFIX) go build -v
GOINSTALL := $(LOOPVARFIX) go install -v


@@ -5,11 +5,22 @@ import (
"fmt"
"sync"
+"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
+"golang.org/x/exp/maps"
)
// UpdateLinkAliases is a function type for a function that locates the active
// link that matches the given shortID and triggers an update based on the
// latest values of the alias manager.
type UpdateLinkAliases func(shortID lnwire.ShortChannelID) error
// ScidAliasMap is a map from a base short channel ID to a set of alias short
// channel IDs.
type ScidAliasMap map[lnwire.ShortChannelID][]lnwire.ShortChannelID
var (
// aliasBucket stores aliases as keys and their base SCIDs as values.
// This is used to populate the maps that the Manager uses. The keys
@@ -47,17 +58,18 @@ var (
// operations.
byteOrder = binary.BigEndian
-// startBlockHeight is the starting block height of the alias range.
-startingBlockHeight = 16_000_000
+// AliasStartBlockHeight is the starting block height of the alias
+// range.
+AliasStartBlockHeight uint32 = 16_000_000
-// endBlockHeight is the ending block height of the alias range.
-endBlockHeight = 16_250_000
+// AliasEndBlockHeight is the ending block height of the alias range.
+AliasEndBlockHeight uint32 = 16_250_000
// StartingAlias is the first alias ShortChannelID that will get
// assigned by RequestAlias. The starting BlockHeight is chosen so that
// legitimate SCIDs in integration tests aren't mistaken for an alias.
StartingAlias = lnwire.ShortChannelID{
-BlockHeight: uint32(startingBlockHeight),
+BlockHeight: AliasStartBlockHeight,
TxIndex: 0,
TxPosition: 0,
}
@@ -68,6 +80,10 @@ var (
// errNoPeerAlias is returned when the peer's alias for a given
// channel is not found.
errNoPeerAlias = fmt.Errorf("no peer alias found")
// ErrAliasNotFound is returned when the alias is not found and can't
// be mapped to a base SCID.
ErrAliasNotFound = fmt.Errorf("alias not found")
)
// Manager is a struct that handles aliases for LND. It has an underlying
@@ -77,10 +93,14 @@ var (
type Manager struct {
backend kvdb.Backend
// linkAliasUpdater is a function used by the alias manager to
// facilitate live update of aliases in other subsystems.
linkAliasUpdater UpdateLinkAliases
// baseToSet is a mapping from the "base" SCID to the set of aliases
// for this channel. This mapping includes all channels that
// negotiated the option-scid-alias feature bit.
-baseToSet map[lnwire.ShortChannelID][]lnwire.ShortChannelID
+baseToSet ScidAliasMap
// aliasToBase is a mapping that maps all aliases for a given channel
// to its base SCID. This is only used for channels that have
@@ -98,9 +118,15 @@ type Manager struct {
}
// NewManager initializes an alias Manager from the passed database backend.
-func NewManager(db kvdb.Backend) (*Manager, error) {
-m := &Manager{backend: db}
-m.baseToSet = make(map[lnwire.ShortChannelID][]lnwire.ShortChannelID)
+func NewManager(db kvdb.Backend, linkAliasUpdater UpdateLinkAliases) (*Manager,
+error) {
m := &Manager{
backend: db,
baseToSet: make(ScidAliasMap),
linkAliasUpdater: linkAliasUpdater,
}
m.aliasToBase = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
m.peerAlias = make(map[lnwire.ChannelID]lnwire.ShortChannelID)
@@ -215,12 +241,22 @@ func (m *Manager) populateMaps() error {
// AddLocalAlias adds a database mapping from the passed alias to the passed
// base SCID. The gossip boolean marks whether or not to create a mapping
// that the gossiper will use. It is set to false for the upgrade path where
-// the feature-bit is toggled on and there are existing channels.
+// the feature-bit is toggled on and there are existing channels. The linkUpdate
// flag is used to signal whether this function should also trigger an update
// on the htlcswitch scid alias maps.
func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID,
-gossip bool) error {
+gossip, linkUpdate bool) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock()
-defer m.Unlock()
+unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
// If the caller does not want to allow the alias to be used
@@ -270,6 +306,18 @@ func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID,
m.aliasToBase[alias] = baseScid
}
// We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
// Finally, we trigger a htlcswitch update if the flag is set, in order
// for any future htlc that references the added alias to be properly
// routed.
if linkUpdate {
return m.linkAliasUpdater(baseScid)
}
return nil
}
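The locking comments in AddLocalAlias above (and in DeleteLocalAlias below) are worth a closer look: the Manager must release its mutex before invoking linkAliasUpdater, otherwise a callback that re-enters the Manager would deadlock, while the deferred unlock still has to cover every early error return. Below is a minimal, self-contained sketch of that unlock-once pattern (an editor's illustration, not code from this PR; the manager, update and notify names are invented):

package main

import (
    "fmt"
    "sync"
)

type manager struct {
    mu    sync.Mutex
    state map[string]string
}

// update mutates the protected state and then calls notify outside the lock.
func (m *manager) update(key, value string, notify func(key string) error) error {
    m.mu.Lock()

    // The deferred unlock and the manual unlock below share a sync.Once,
    // so the mutex is released exactly once on every return path.
    unlockOnce := sync.Once{}
    unlock := func() { unlockOnce.Do(m.mu.Unlock) }
    defer unlock()

    m.state[key] = value

    // Release the lock before invoking the callback, which may call back
    // into the manager and would otherwise deadlock.
    unlock()

    return notify(key)
}

func main() {
    m := &manager{state: make(map[string]string)}
    err := m.update("alias", "assigned", func(key string) error {
        fmt.Println("link updated for", key)
        return nil
    })
    fmt.Println("err:", err)
}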
@@ -340,6 +388,74 @@ func (m *Manager) DeleteSixConfs(baseScid lnwire.ShortChannelID) error {
return nil
}
// DeleteLocalAlias removes a mapping from the database and the Manager's maps.
func (m *Manager) DeleteLocalAlias(alias,
baseScid lnwire.ShortChannelID) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock()
unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket)
if err != nil {
return err
}
var aliasBytes [8]byte
byteOrder.PutUint64(aliasBytes[:], alias.ToUint64())
// If the user attempts to delete an alias that doesn't exist,
// we'll want to inform them about it and not just do nothing.
if aliasToBaseBucket.Get(aliasBytes[:]) == nil {
return ErrAliasNotFound
}
return aliasToBaseBucket.Delete(aliasBytes[:])
}, func() {})
if err != nil {
return err
}
// Now that the database state has been updated, we'll delete the
// mapping from the Manager's maps.
aliasSet, ok := m.baseToSet[baseScid]
if !ok {
return ErrAliasNotFound
}
// We'll filter the alias set and remove the alias from it.
aliasSet = fn.Filter(func(a lnwire.ShortChannelID) bool {
return a.ToUint64() != alias.ToUint64()
}, aliasSet)
// If the alias set is empty, we'll delete the base SCID from the
// baseToSet map.
if len(aliasSet) == 0 {
delete(m.baseToSet, baseScid)
} else {
m.baseToSet[baseScid] = aliasSet
}
// Finally, we'll delete the aliasToBase mapping from the Manager's
// cache (but this is only set if we gossip the alias).
delete(m.aliasToBase, alias)
// We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
return m.linkAliasUpdater(baseScid)
}
// PutPeerAlias stores the peer's alias SCID once we learn of it in the
// channel_ready message.
func (m *Manager) PutPeerAlias(chanID lnwire.ChannelID,
@@ -392,6 +508,19 @@ func (m *Manager) GetPeerAlias(chanID lnwire.ChannelID) (lnwire.ShortChannelID,
func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
var nextAlias lnwire.ShortChannelID
m.RLock()
defer m.RUnlock()
// haveAlias returns true if the passed alias is already assigned to a
// channel in the baseToSet map.
haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool {
return fn.Any(func(aliasList []lnwire.ShortChannelID) bool {
return fn.Any(func(alias lnwire.ShortChannelID) bool {
return alias == maybeNextAlias
}, aliasList)
}, maps.Values(m.baseToSet))
}
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(aliasAllocBucket)
if err != nil {
@@ -404,6 +533,29 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
// StartingAlias to it.
nextAlias = StartingAlias
// If the very first alias is already assigned, we'll
// keep incrementing until we find an unassigned alias.
// This is to avoid collision with custom added SCID
// aliases that fall into the same range as the ones we
// generate here monotonically. Those custom SCIDs are
// stored in a different bucket, but we can just check
// the in-memory map for simplicity.
for {
if !haveAlias(nextAlias) {
break
}
nextAlias = getNextScid(nextAlias)
// Abort if we've reached the end of the range.
if nextAlias.BlockHeight >=
AliasEndBlockHeight {
return fmt.Errorf("range for custom " +
"aliases exhausted")
}
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
return bucket.Put(lastAliasKey, scratch[:])
@@ -418,6 +570,26 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
)
nextAlias = getNextScid(lastScid)
// If the next alias is already assigned, we'll keep
// incrementing until we find an unassigned alias. This is to
// avoid collision with custom added SCID aliases that fall into
// the same range as the ones we generate here monotonically.
// Those custom SCIDs are stored in a different bucket, but we
// can just check the in-memory map for simplicity.
for {
if !haveAlias(nextAlias) {
break
}
nextAlias = getNextScid(nextAlias)
// Abort if we've reached the end of the range.
if nextAlias.BlockHeight >= AliasEndBlockHeight {
return fmt.Errorf("range for custom " +
"aliases exhausted")
}
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
return bucket.Put(lastAliasKey, scratch[:])
@@ -433,11 +605,11 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
// ListAliases returns a carbon copy of baseToSet. This is used by the rpc
// layer.
-func (m *Manager) ListAliases() map[lnwire.ShortChannelID][]lnwire.ShortChannelID {
+func (m *Manager) ListAliases() ScidAliasMap {
m.RLock()
defer m.RUnlock()
-baseCopy := make(map[lnwire.ShortChannelID][]lnwire.ShortChannelID)
+baseCopy := make(ScidAliasMap)
for k, v := range m.baseToSet {
setCopy := make([]lnwire.ShortChannelID, len(v))
@@ -496,10 +668,10 @@ func getNextScid(last lnwire.ShortChannelID) lnwire.ShortChannelID {
// IsAlias returns true if the passed SCID is an alias. The function determines
// this by looking at the BlockHeight. If the BlockHeight is greater than
-// startingBlockHeight and less than endBlockHeight, then it is an alias
+// AliasStartBlockHeight and less than AliasEndBlockHeight, then it is an alias
// assigned by RequestAlias. These bounds only apply to aliases we generate.
// Our peers are free to use any range they choose.
func IsAlias(scid lnwire.ShortChannelID) bool {
-return scid.BlockHeight >= uint32(startingBlockHeight) &&
-scid.BlockHeight < uint32(endBlockHeight)
+return scid.BlockHeight >= AliasStartBlockHeight &&
+scid.BlockHeight < AliasEndBlockHeight
}
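Taken together, these aliasmgr changes reshape the package's public surface: NewManager now requires an UpdateLinkAliases callback, AddLocalAlias gains a linkUpdate flag, DeleteLocalAlias and ErrAliasNotFound are new, and ListAliases returns the ScidAliasMap type. The following is a hedged usage sketch assembled only from signatures visible in this diff; the import paths, the temporary database path and the ignored errors are editor assumptions, not PR code:

package main

import (
    "fmt"
    "path/filepath"

    "github.com/lightningnetwork/lnd/aliasmgr"
    "github.com/lightningnetwork/lnd/kvdb"
    "github.com/lightningnetwork/lnd/lnwire"
)

func main() {
    // Open a bolt-backed kvdb store, mirroring the test setup in this diff.
    dbPath := filepath.Join("/tmp", "alias-demo")
    db, _ := kvdb.Create(
        kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
    )
    defer db.Close()

    // The link updater runs (outside the Manager's lock) whenever an alias
    // set changes, so the htlcswitch can refresh its own maps.
    linkUpdater := func(base lnwire.ShortChannelID) error {
        fmt.Println("refresh link for", base)
        return nil
    }

    mgr, _ := aliasmgr.NewManager(db, linkUpdater)

    base := lnwire.NewShortChanIDFromInt(123123123)
    alias := lnwire.NewShortChanIDFromInt(456456456)

    // gossip=false keeps the alias out of the gossiper's map; the final
    // true asks the Manager to trigger the link updater.
    _ = mgr.AddLocalAlias(alias, base, false, true)
    fmt.Println("aliases:", mgr.ListAliases()[base])

    // Removing the alias also triggers the link updater.
    _ = mgr.DeleteLocalAlias(alias, base)
}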


@@ -23,7 +23,11 @@ func TestAliasStorePeerAlias(t *testing.T) {
require.NoError(t, err)
defer db.Close()
-aliasStore, err := NewManager(db)
+linkUpdater := func(shortID lnwire.ShortChannelID) error {
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err)
var chanID1 [32]byte
@@ -52,7 +56,11 @@ func TestAliasStoreRequest(t *testing.T) {
require.NoError(t, err)
defer db.Close()
-aliasStore, err := NewManager(db)
+linkUpdater := func(shortID lnwire.ShortChannelID) error {
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err)
// We'll assert that the very first alias we receive is StartingAlias.
@@ -68,6 +76,118 @@ func TestAliasStoreRequest(t *testing.T) {
require.Equal(t, nextAlias, alias2)
}
// TestAliasLifecycle tests that the aliases can be created and deleted.
func TestAliasLifecycle(t *testing.T) {
t.Parallel()
// Create the backend database and use this to create the aliasStore.
dbPath := filepath.Join(t.TempDir(), "testdb")
db, err := kvdb.Create(
kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
)
require.NoError(t, err)
defer db.Close()
updateChan := make(chan struct{}, 1)
linkUpdater := func(shortID lnwire.ShortChannelID) error {
updateChan <- struct{}{}
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err)
const (
base = uint64(123123123)
alias = uint64(456456456)
)
// Parse the aliases and base to short channel ID format.
baseScid := lnwire.NewShortChanIDFromInt(base)
aliasScid := lnwire.NewShortChanIDFromInt(alias)
aliasScid2 := lnwire.NewShortChanIDFromInt(alias + 1)
// Add the first alias.
err = aliasStore.AddLocalAlias(aliasScid, baseScid, false, true)
require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify the results.
aliasList := aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 1)
require.Contains(t, aliasList, aliasScid)
// Add the second alias.
err = aliasStore.AddLocalAlias(aliasScid2, baseScid, false, true)
require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify the results.
aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 2)
require.Contains(t, aliasList, aliasScid)
require.Contains(t, aliasList, aliasScid2)
// Delete the first alias.
err = aliasStore.DeleteLocalAlias(aliasScid, baseScid)
require.NoError(t, err)
// The link updater should be called.
<-updateChan
// We expect to get an error if we attempt to delete the same alias
// again.
err = aliasStore.DeleteLocalAlias(aliasScid, baseScid)
require.ErrorIs(t, err, ErrAliasNotFound)
// The link updater should _not_ be called.
select {
case <-updateChan:
t.Fatal("link alias updater should not have been called")
default:
}
// Query the aliases and verify that first one doesn't exist anymore.
aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 1)
require.Contains(t, aliasList, aliasScid2)
require.NotContains(t, aliasList, aliasScid)
// Delete the second alias.
err = aliasStore.DeleteLocalAlias(aliasScid2, baseScid)
require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify that none exists.
aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 0)
// We now request an alias generated by the aliasStore. This should give
// the first from the pre-defined list of allocated aliases.
firstRequested, err := aliasStore.RequestAlias()
require.NoError(t, err)
require.Equal(t, StartingAlias, firstRequested)
// We now manually add the next alias from the range as a custom alias.
secondAlias := getNextScid(firstRequested)
err = aliasStore.AddLocalAlias(secondAlias, baseScid, false, true)
require.NoError(t, err)
// When we now request another alias from the allocation list, we expect
// the third one (tx position 2) to be returned.
thirdRequested, err := aliasStore.RequestAlias()
require.NoError(t, err)
require.Equal(t, getNextScid(secondAlias), thirdRequested)
require.EqualValues(t, 2, thirdRequested.TxPosition)
}
// TestGetNextScid tests that given a current lnwire.ShortChannelID,
// getNextScid returns the expected alias to use next.
func TestGetNextScid(t *testing.T) {
@@ -80,7 +200,7 @@ func TestGetNextScid(t *testing.T) {
name: "starting alias",
current: StartingAlias,
expected: lnwire.ShortChannelID{
-BlockHeight: uint32(startingBlockHeight),
+BlockHeight: AliasStartBlockHeight,
TxIndex: 0,
TxPosition: 1,
},


@@ -43,11 +43,11 @@ const (
AppMinor uint = 18
// AppPatch defines the application patch for this binary.
-AppPatch uint = 3
+AppPatch uint = 4
// AppPreRelease MUST only contain characters from semanticAlphabet per
// the semantic versioning spec.
-AppPreRelease = "beta"
+AppPreRelease = "beta.rc1"
)
func init() {


@@ -24,6 +24,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
+"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
@@ -63,6 +64,14 @@ type Config struct {
// state.
ChanStateDB *channeldb.ChannelStateDB
// AuxLeafStore is an optional store that can be used to store auxiliary
// leaves for certain custom channel types.
AuxLeafStore fn.Option[lnwallet.AuxLeafStore]
// AuxSigner is an optional signer that can be used to sign auxiliary
// leaves for certain custom channel types.
AuxSigner fn.Option[lnwallet.AuxSigner]
// BlockCache is the main cache for storing block information.
BlockCache *blockcache.BlockCache
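AuxLeafStore and AuxSigner are declared as fn.Option values, so a node without custom-channel support simply leaves them unset and consumers check for presence instead of comparing against nil. A small sketch of that Option pattern, assuming fn.Some and WhenSome behave as they are used elsewhere in this diff (the blob example is illustrative, not chainreg code):

package main

import (
    "fmt"

    "github.com/lightningnetwork/lnd/fn"
)

func main() {
    // The zero value of fn.Option is "none", which is how the new Aux*
    // fields look on a node without custom channel support.
    var blob fn.Option[[]byte]

    blob.WhenSome(func(b []byte) {
        fmt.Println("have blob:", b) // not reached
    })

    // Populating the option, mirroring the fn.Some(...) uses elsewhere in
    // this diff (e.g. CustomBlob: fn.Some([]byte{1, 2, 3})).
    blob = fn.Some([]byte{1, 2, 3})

    blob.WhenSome(func(b []byte) {
        fmt.Println("have blob:", b)
    })
}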


@@ -356,6 +356,30 @@ func (r *RPCAcceptor) sendAcceptRequests(errChan chan error,
):
commitmentType = lnrpc.CommitmentType_SIMPLE_TAPROOT
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ZeroConfRequired,
lnwire.ScidAliasRequired,
):
commitmentType = lnrpc.CommitmentType_SIMPLE_TAPROOT_OVERLAY
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ZeroConfRequired,
):
commitmentType = lnrpc.CommitmentType_SIMPLE_TAPROOT_OVERLAY
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ScidAliasRequired,
):
commitmentType = lnrpc.CommitmentType_SIMPLE_TAPROOT_OVERLAY
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
):
commitmentType = lnrpc.CommitmentType_SIMPLE_TAPROOT_OVERLAY
case channelFeatures.OnlyContains(
lnwire.StaticRemoteKeyRequired,
):


@@ -226,28 +226,109 @@ const (
// A tlv type definition used to serialize an outpoint's indexStatus
// for use in the outpoint index.
indexStatusType tlv.Type = 0
-// A tlv type definition used to serialize and deserialize a KeyLocator
-// from the database.
-keyLocType tlv.Type = 1
-// A tlv type used to serialize and deserialize the
-// `InitialLocalBalance` field.
-initialLocalBalanceType tlv.Type = 2
-// A tlv type used to serialize and deserialize the
-// `InitialRemoteBalance` field.
-initialRemoteBalanceType tlv.Type = 3
-// A tlv type definition used to serialize and deserialize the
-// confirmed ShortChannelID for a zero-conf channel.
-realScidType tlv.Type = 4
-// A tlv type definition used to serialize and deserialize the
-// Memo for the channel channel.
-channelMemoType tlv.Type = 5
)
// openChannelTlvData houses the new data fields that are stored for each
// channel in a TLV stream within the root bucket. This is stored as a TLV
// stream appended to the existing hard-coded fields in the channel's root
// bucket. New fields being added to the channel state should be added here.
//
// NOTE: This struct is used for serialization purposes only and its fields
// should be accessed via the OpenChannel struct while in memory.
type openChannelTlvData struct {
// revokeKeyLoc is the key locator for the revocation key.
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]
// initialLocalBalance is the initial local balance of the channel.
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]
// initialRemoteBalance is the initial remote balance of the channel.
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]
// realScid is the real short channel ID of the channel corresponding to
// the on-chain outpoint.
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]
// memo is an optional text field that gives context to the user about
// the channel.
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]
// tapscriptRoot is the optional Tapscript root the channel funding
// output commits to.
tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte]
// customBlob is an optional TLV encoded blob of data representing
// custom channel funding information.
customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob]
}
// encode serializes the openChannelTlvData to the given io.Writer.
func (c *openChannelTlvData) encode(w io.Writer) error {
tlvRecords := []tlv.Record{
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
}
c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record())
})
c.tapscriptRoot.WhenSome(
func(root tlv.RecordT[tlv.TlvType6, [32]byte]) {
tlvRecords = append(tlvRecords, root.Record())
},
)
c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) {
tlvRecords = append(tlvRecords, blob.Record())
})
// Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// decode deserializes the openChannelTlvData from the given io.Reader.
func (c *openChannelTlvData) decode(r io.Reader) error {
memo := c.memo.Zero()
tapscriptRoot := c.tapscriptRoot.Zero()
blob := c.customBlob.Zero()
// Create the tlv stream.
tlvStream, err := tlv.NewStream(
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
memo.Record(),
tapscriptRoot.Record(),
blob.Record(),
)
if err != nil {
return err
}
tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if _, ok := tlvs[memo.TlvType()]; ok {
c.memo = tlv.SomeRecordT(memo)
}
if _, ok := tlvs[tapscriptRoot.TlvType()]; ok {
c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot)
}
if _, ok := tlvs[c.customBlob.TlvType()]; ok {
c.customBlob = tlv.SomeRecordT(blob)
}
return nil
}
// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently only two possible values.
type indexStatus uint8
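openChannelTlvData and commitTlvData both follow the same storage idea spelled out in the serialization comments further down: the hard-coded fields keep their existing encoding, and anything newer is appended as a TLV stream at the end of the value, so records stay optional and can be detected individually on read. A compact sketch of that encode/decode shape, using only tlv helpers that appear in this diff (the auxData type and its record number are illustrative, not the PR's own types):

package main

import (
    "bytes"
    "fmt"

    "github.com/lightningnetwork/lnd/tlv"
)

type auxData struct {
    // customBlob is optional, so its presence must be detected on decode.
    customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob]
}

func (a *auxData) encode(w *bytes.Buffer) error {
    var records []tlv.Record
    a.customBlob.WhenSome(func(b tlv.RecordT[tlv.TlvType1, tlv.Blob]) {
        records = append(records, b.Record())
    })

    stream, err := tlv.NewStream(records...)
    if err != nil {
        return err
    }

    return stream.Encode(w)
}

func (a *auxData) decode(r *bytes.Buffer) error {
    blob := a.customBlob.Zero()

    stream, err := tlv.NewStream(blob.Record())
    if err != nil {
        return err
    }

    // DecodeWithParsedTypes reports which records were actually present,
    // so an absent optional record simply stays unset.
    parsed, err := stream.DecodeWithParsedTypes(r)
    if err != nil {
        return err
    }
    if _, ok := parsed[a.customBlob.TlvType()]; ok {
        a.customBlob = tlv.SomeRecordT(blob)
    }

    return nil
}

func main() {
    var buf bytes.Buffer

    in := auxData{
        customBlob: tlv.SomeRecordT(
            tlv.NewPrimitiveRecord[tlv.TlvType1](tlv.Blob{1, 2, 3}),
        ),
    }
    if err := in.encode(&buf); err != nil {
        panic(err)
    }

    var out auxData
    if err := out.decode(&buf); err != nil {
        panic(err)
    }
    out.customBlob.WhenSomeV(func(b tlv.Blob) {
        fmt.Println("decoded blob:", b)
    })
}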
@@ -325,6 +406,11 @@ const (
// SimpleTaprootFeatureBit indicates that the simple-taproot-chans
// feature bit was negotiated during the lifetime of the channel.
SimpleTaprootFeatureBit ChannelType = 1 << 10
// TapscriptRootBit indicates that this is a MuSig2 channel with a top
// level tapscript commitment. This MUST be set along with the
// SimpleTaprootFeatureBit.
TapscriptRootBit ChannelType = 1 << 11
)
// IsSingleFunder returns true if the channel type if one of the known single
@@ -395,6 +481,12 @@ func (c ChannelType) IsTaproot() bool {
return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit
}
// HasTapscriptRoot returns true if the channel is using a top level tapscript
// root commitment.
func (c ChannelType) HasTapscriptRoot() bool {
return c&TapscriptRootBit == TapscriptRootBit
}
// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel
// that are responsible for providing bounds on the state space of the abstract
// channel state. These values must be remembered for normal channel operation
@@ -496,6 +588,53 @@ type ChannelConfig struct {
HtlcBasePoint keychain.KeyDescriptor
}
// commitTlvData stores all the optional data that may be stored as a TLV stream
// at the _end_ of the normal serialized commit on disk.
type commitTlvData struct {
// customBlob is a custom blob that may store extra data for custom
// channels.
customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob]
}
// encode encodes the aux data into the passed io.Writer.
func (c *commitTlvData) encode(w io.Writer) error {
var tlvRecords []tlv.Record
c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) {
tlvRecords = append(tlvRecords, blob.Record())
})
// Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// decode attempts to decode the aux data from the passed io.Reader.
func (c *commitTlvData) decode(r io.Reader) error {
blob := c.customBlob.Zero()
tlvStream, err := tlv.NewStream(
blob.Record(),
)
if err != nil {
return err
}
tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if _, ok := tlvs[c.customBlob.TlvType()]; ok {
c.customBlob = tlv.SomeRecordT(blob)
}
return nil
}
// ChannelCommitment is a snapshot of the commitment state at a particular
// point in the commitment chain. With each state transition, a snapshot of the
// current state along with all non-settled HTLCs are recorded. These snapshots
@@ -562,6 +701,11 @@ type ChannelCommitment struct {
// able by us.
CommitTx *wire.MsgTx
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This may track some custom
// specific state for this given commitment.
CustomBlob fn.Option[tlv.Blob]
// CommitSig is one half of the signature required to fully complete
// the script for the commitment transaction above. This is the
// signature signed by the remote party for our version of the
@@ -571,9 +715,26 @@ type ChannelCommitment struct {
// Htlcs is the set of HTLC's that are pending at this particular
// commitment height.
Htlcs []HTLC
}
-// TODO(roasbeef): pending commit pointer?
-// * lets just walk through
+// amendTlvData updates the channel with the given auxiliary TLV data.
+func (c *ChannelCommitment) amendTlvData(auxData commitTlvData) {
auxData.customBlob.WhenSomeV(func(blob tlv.Blob) {
c.CustomBlob = fn.Some(blob)
})
}
// extractTlvData creates a new commitTlvData from the given commitment.
func (c *ChannelCommitment) extractTlvData() commitTlvData {
var auxData commitTlvData
c.CustomBlob.WhenSome(func(blob tlv.Blob) {
auxData.customBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType1](blob),
)
})
return auxData
} }
// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in
@@ -867,6 +1028,16 @@ type OpenChannel struct {
// channel that will be useful to our future selves.
Memo []byte
// TapscriptRoot is an optional tapscript root used to derive the MuSig2
// funding output.
TapscriptRoot fn.Option[chainhash.Hash]
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This information is only created
// at channel funding time, and after wards is to be considered
// immutable.
CustomBlob fn.Option[tlv.Blob]
// TODO(roasbeef): eww
Db *ChannelStateDB
@@ -1025,6 +1196,64 @@ func (c *OpenChannel) SetBroadcastHeight(height uint32) {
c.FundingBroadcastHeight = height
}
// amendTlvData updates the channel with the given auxiliary TLV data.
func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) {
c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator
c.InitialLocalBalance = lnwire.MilliSatoshi(
auxData.initialLocalBalance.Val,
)
c.InitialRemoteBalance = lnwire.MilliSatoshi(
auxData.initialRemoteBalance.Val,
)
c.confirmedScid = auxData.realScid.Val
auxData.memo.WhenSomeV(func(memo []byte) {
c.Memo = memo
})
auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) {
c.TapscriptRoot = fn.Some[chainhash.Hash](h)
})
auxData.customBlob.WhenSomeV(func(blob tlv.Blob) {
c.CustomBlob = fn.Some(blob)
})
}
// extractTlvData creates a new openChannelTlvData from the given channel.
func (c *OpenChannel) extractTlvData() openChannelTlvData {
auxData := openChannelTlvData{
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1](
keyLocRecord{c.RevocationKeyLocator},
),
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint64(c.InitialLocalBalance),
),
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3](
uint64(c.InitialRemoteBalance),
),
realScid: tlv.NewRecordT[tlv.TlvType4](
c.confirmedScid,
),
}
if len(c.Memo) != 0 {
auxData.memo = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo),
)
}
c.TapscriptRoot.WhenSome(func(h chainhash.Hash) {
auxData.tapscriptRoot = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h),
)
})
c.CustomBlob.WhenSome(func(blob tlv.Blob) {
auxData.customBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType7](blob),
)
})
return auxData
}
// Refresh updates the in-memory channel state using the latest state observed
// on disk.
func (c *OpenChannel) Refresh() error {
@@ -2351,6 +2580,12 @@ type HTLC struct {
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
BlindingPoint lnwire.BlindingPointRecord
// CustomRecords is a set of custom TLV records that are associated with
// this HTLC. These records are used to store additional information
// about the HTLC that is not part of the standard HTLC fields. This
// field is encoded within the ExtraData field.
CustomRecords lnwire.CustomRecords
}
// serializeExtraData encodes a TLV stream of extra data to be stored with a
@@ -2369,6 +2604,11 @@ func (h *HTLC) serializeExtraData() error {
records = append(records, &b)
})
records, err := h.CustomRecords.ExtendRecordProducers(records)
if err != nil {
return err
}
return h.ExtraData.PackRecords(records...)
}
@@ -2390,8 +2630,19 @@ func (h *HTLC) deserializeExtraData() error {
if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)
// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(tlvMap, h.BlindingPoint.TlvType())
} }
// Set the custom records field to the remaining TLV records.
customRecords, err := lnwire.NewCustomRecords(tlvMap)
if err != nil {
return err
}
h.CustomRecords = customRecords
return nil
}
@@ -2529,6 +2780,8 @@ func (h *HTLC) Copy() HTLC {
copy(clone.Signature[:], h.Signature)
copy(clone.RHash[:], h.RHash[:])
copy(clone.ExtraData, h.ExtraData)
clone.BlindingPoint = h.BlindingPoint
clone.CustomRecords = h.CustomRecords.Copy()
return clone
}
@@ -2690,6 +2943,14 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl
}
}
// We'll also encode the commit aux data stream here. We do this here
// rather than above (at the call to serializeChanCommit), to ensure
// backwards compat for reads to existing non-custom channels.
auxData := diff.Commitment.extractTlvData()
if err := auxData.encode(w); err != nil {
return fmt.Errorf("unable to write aux data: %w", err)
}
return nil
}
@@ -2750,6 +3011,17 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
}
}
// As a final step, we'll read out any aux commit data that we have at
// the end of this byte stream. We do this here to ensure backward
// compatibility, as otherwise we risk erroneously reading into the
// wrong field.
var auxData commitTlvData
if err := auxData.decode(r); err != nil {
return nil, fmt.Errorf("unable to decode aux data: %w", err)
}
d.Commitment.amendTlvData(auxData)
return &d, nil
}
@@ -3728,6 +4000,13 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot {
},
}
localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) {
blobCopy := make([]byte, len(blob))
copy(blobCopy, blob)
snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy)
})
// Copy over the current set of HTLCs to ensure the caller can't mutate
// our internal state.
snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs))
@@ -4030,32 +4309,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
return err
}
-// Convert balance fields into uint64.
-localBalance := uint64(channel.InitialLocalBalance)
-remoteBalance := uint64(channel.InitialRemoteBalance)
-// Create the tlv stream.
-tlvStream, err := tlv.NewStream(
-// Write the RevocationKeyLocator as the first entry in a tlv
-// stream.
-MakeKeyLocRecord(
-keyLocType, &channel.RevocationKeyLocator,
-),
-tlv.MakePrimitiveRecord(
-initialLocalBalanceType, &localBalance,
-),
-tlv.MakePrimitiveRecord(
-initialRemoteBalanceType, &remoteBalance,
-),
-MakeScidRecord(realScidType, &channel.confirmedScid),
-tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo),
-)
-if err != nil {
-return err
-}
-if err := tlvStream.Encode(&w); err != nil {
-return err
-}
+auxData := channel.extractTlvData()
+if err := auxData.encode(&w); err != nil {
+return fmt.Errorf("unable to encode aux data: %w", err)
+}
if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
@@ -4142,6 +4398,12 @@ func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment,
return err
}
// Before we write to disk, we'll also write our aux data as well.
auxData := c.extractTlvData()
if err := auxData.encode(&b); err != nil {
return fmt.Errorf("unable to write aux data: %w", err)
}
return chanBucket.Put(commitKey, b.Bytes())
}
@@ -4244,45 +4506,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
}
}
-// Create balance fields in uint64, and Memo field as byte slice.
-var (
-localBalance uint64
-remoteBalance uint64
-memo []byte
-)
-// Create the tlv stream.
-tlvStream, err := tlv.NewStream(
-// Write the RevocationKeyLocator as the first entry in a tlv
-// stream.
-MakeKeyLocRecord(
-keyLocType, &channel.RevocationKeyLocator,
-),
-tlv.MakePrimitiveRecord(
-initialLocalBalanceType, &localBalance,
-),
-tlv.MakePrimitiveRecord(
-initialRemoteBalanceType, &remoteBalance,
-),
-MakeScidRecord(realScidType, &channel.confirmedScid),
-tlv.MakePrimitiveRecord(channelMemoType, &memo),
-)
-if err != nil {
-return err
-}
-if err := tlvStream.Decode(r); err != nil {
-return err
-}
-// Attach the balance fields.
-channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
-channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)
-// Attach the memo field if non-empty.
-if len(memo) > 0 {
-channel.Memo = memo
-}
+var auxData openChannelTlvData
+if err := auxData.decode(r); err != nil {
+return fmt.Errorf("unable to decode aux data: %w", err)
+}
+// Assign all the relevant fields from the aux data into the actual
+// open channel.
+channel.amendTlvData(auxData)
channel.Packager = NewChannelPackager(channel.ShortChannelID)
@@ -4318,7 +4549,9 @@ func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
return c, nil
}
-func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment, error) {
+func fetchChanCommitment(chanBucket kvdb.RBucket,
+local bool) (ChannelCommitment, error) {
var commitKey []byte
if local {
commitKey = append(chanCommitmentKey, byte(0x00))
@@ -4332,7 +4565,23 @@ func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment
}
r := bytes.NewReader(commitBytes)
-return deserializeChanCommit(r)
+chanCommit, err := deserializeChanCommit(r)
if err != nil {
return ChannelCommitment{}, fmt.Errorf("unable to decode "+
"chan commit: %w", err)
}
// We'll also check to see if we have any aux data stored as the end of
// the stream.
var auxData commitTlvData
if err := auxData.decode(r); err != nil {
return ChannelCommitment{}, fmt.Errorf("unable to decode "+
"chan aux data: %w", err)
}
chanCommit.amendTlvData(auxData)
return chanCommit, nil
}
func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error {
@@ -4440,6 +4689,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error {
return chanBucket.Delete(frozenChanKey)
}
// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
// tlv.RecordProducer interface.
type keyLocRecord struct {
keychain.KeyLocator
}
// Record creates a Record out of a KeyLocator using the passed Type and the
// EKeyLocator and DKeyLocator functions. The size will always be 8 as
// KeyFamily is uint32 and the Index is uint32.
//
// NOTE: This is part of the tlv.RecordProducer interface.
func (k *keyLocRecord) Record() tlv.Record {
// Note that we set the type here as zero, as when used with a
// tlv.RecordT, the type param will be used as the type.
return tlv.MakeStaticRecord(
0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator,
)
}
// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
@@ -4468,22 +4736,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}
-// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
-// Type and the EKeyLocator and DKeyLocator functions. The size will always be
-// 8 as KeyFamily is uint32 and the Index is uint32.
-func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
-return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
-}
-// MakeScidRecord creates a Record out of a ShortChannelID using the passed
-// Type and the EShortChannelID and DShortChannelID functions. The size will
-// always be 8 for the ShortChannelID.
-func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
-return tlv.MakeStaticRecord(
-typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
-)
-}
// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {


@@ -17,6 +17,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
+"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
@@ -173,7 +174,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
}
// channelIDOption is an option which sets the short channel ID of the channel.
-var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
+func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption {
return func(params *testChannelParams) {
params.channel.ShortChannelID = chanID
}
@@ -326,6 +327,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
uniqueOutputIndex.Add(1)
op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()}
var tapscriptRoot chainhash.Hash
copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32))
return &OpenChannel{
ChanType: SingleFunderBit | FrozenBit,
ChainHash: key,
@@ -347,6 +351,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
FeePerKw: btcutil.Amount(5000),
CommitTx: channels.TestFundingTx,
CommitSig: bytes.Repeat([]byte{1}, 71),
CustomBlob: fn.Some([]byte{1, 2, 3}),
},
RemoteCommitment: ChannelCommitment{
CommitHeight: 0,
@@ -356,6 +361,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
FeePerKw: btcutil.Amount(5000),
CommitTx: channels.TestFundingTx,
CommitSig: bytes.Repeat([]byte{1}, 71),
CustomBlob: fn.Some([]byte{4, 5, 6}),
},
NumConfsRequired: 4,
RemoteCurrentRevocation: privKey.PubKey(),
@@ -368,6 +374,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
ThawHeight: uint32(defaultPendingHeight),
InitialLocalBalance: lnwire.MilliSatoshi(9000),
InitialRemoteBalance: lnwire.MilliSatoshi(3000),
Memo: []byte("test"),
TapscriptRoot: fn.Some(tapscriptRoot),
CustomBlob: fn.Some([]byte{1, 2, 3}),
}
}
@@ -575,24 +584,32 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
r *RevocationLog) {
+t.Helper()
// Check the common fields.
require.EqualValues(
-t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch",
+t, r.CommitTxHash.Val, c.CommitTx.TxHash(), "CommitTx mismatch",
)
// Now check the common fields from the HTLCs.
require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
for i, rHtlc := range r.HTLCEntries {
cHtlc := c.Htlcs[i]
-require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
+require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash")
-require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
-"Amt mismatch")
+require.Equal(
+t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt",
+)
-require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
-"RefundTimeout mismatch")
+require.Equal(
+t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout,
+"RefundTimeout",
+)
-require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
-"OutputIndex mismatch")
+require.EqualValues(
+t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex,
+"OutputIndex",
+)
-require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
-"Incoming mismatch")
+require.Equal(
+t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming",
+)
}
}
@@ -657,6 +674,7 @@ func TestChannelStateTransition(t *testing.T) {
CommitTx: newTx,
CommitSig: newSig,
Htlcs: htlcs,
CustomBlob: fn.Some([]byte{4, 5, 6}),
}
// First update the local node's broadcastable state and also add a
@@ -694,9 +712,14 @@ func TestChannelStateTransition(t *testing.T) {
// have been updated.
updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
require.NoError(t, err, "unable to fetch updated channel")
-assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
assertCommitmentEqual(
t, &commitment, &updatedChannel[0].LocalCommitment,
)
numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
require.NoError(t, err, "unable to read commitment height from disk")
if numDiskUpdates != uint64(commitment.CommitHeight) {
t.Fatalf("num disk updates doesn't match: %v vs %v",
numDiskUpdates, commitment.CommitHeight)
@@ -799,10 +822,10 @@ func TestChannelStateTransition(t *testing.T) {
// Check the output indexes are saved as expected.
require.EqualValues(
-t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
+t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val,
)
require.EqualValues(
-t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
+t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val,
)
// The two deltas (the original vs the on-disk version) should
@@ -844,10 +867,10 @@ func TestChannelStateTransition(t *testing.T) {
// Check the output indexes are saved as expected.
require.EqualValues(
-t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
+t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val,
)
require.EqualValues(
-t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
+t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val,
)
assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit)
@ -1642,6 +1665,24 @@ func TestHTLCsExtraData(t *testing.T) {
), ),
} }
// Custom channel data htlc with a blinding point.
customDataHTLC := HTLC{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
BlindingPoint: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pubKey,
),
),
CustomRecords: map[uint64][]byte{
uint64(lnwire.MinCustomRecordsTlvType + 3): {1, 2, 3},
},
}
testCases := []struct { testCases := []struct {
name string name string
htlcs []HTLC htlcs []HTLC
@ -1663,6 +1704,7 @@ func TestHTLCsExtraData(t *testing.T) {
mockHtlc, mockHtlc,
blindingPointHTLC, blindingPointHTLC,
mockHtlc, mockHtlc,
customDataHTLC,
}, },
}, },
} }

View file

@ -286,6 +286,27 @@ func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
} }
} }
// SourceRef is a convenience method that returns an AddRef to this forwarding
// package for the index in the argument. It is the caller's responsibility
// to ensure that the index is in bounds.
func (f *FwdPkg) SourceRef(i uint16) AddRef {
return AddRef{
Height: f.Height,
Index: i,
}
}
// DestRef is a convenience method that returns a SettleFailRef to this
// forwarding package for the index in the argument. It is the caller's
// responsibility to ensure that the index is in bounds.
func (f *FwdPkg) DestRef(i uint16) SettleFailRef {
return SettleFailRef{
Source: f.Source,
Height: f.Height,
Index: i,
}
}
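These two helpers only bundle the package's own height and source with a caller-supplied index. A minimal usage sketch (the fwdPkg value and index 0 are illustrative; bounds checking remains the caller's job, as noted above):

	// Reference the first Add of this forwarding package.
	addRef := fwdPkg.SourceRef(0)

	// Reference the first Settle/Fail destined for this package; unlike
	// AddRef, this also carries the package's source channel.
	failRef := fwdPkg.DestRef(0)

	_, _ = addRef, failRef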
// ID returns an unique identifier for this package, used to ensure that sphinx // ID returns an unique identifier for this package, used to ensure that sphinx
// replay processing of this batch is idempotent. // replay processing of this batch is idempotent.
func (f *FwdPkg) ID() []byte { func (f *FwdPkg) ID() []byte {

View file

@ -2382,7 +2382,7 @@ func TestStressTestChannelGraphAPI(t *testing.T) {
methodsMu.Unlock() methodsMu.Unlock()
err := fn() err := fn()
require.NoErrorf(t, err, fmt.Sprintf(name)) require.NoErrorf(t, err, name)
} }
}) })
} }

View file

@ -2,7 +2,6 @@ package migration_01_to_11
import ( import (
"bytes" "bytes"
"fmt"
"testing" "testing"
"time" "time"
@ -154,12 +153,7 @@ func signDigestCompact(hash []byte) ([]byte, error) {
privKey, _ := btcec.PrivKeyFromBytes(testPrivKeyBytes) privKey, _ := btcec.PrivKeyFromBytes(testPrivKeyBytes)
// ecdsa.SignCompact returns a pubkey-recoverable signature // ecdsa.SignCompact returns a pubkey-recoverable signature
sig, err := ecdsa.SignCompact(privKey, hash, isCompressedKey) return ecdsa.SignCompact(privKey, hash, isCompressedKey), nil
if err != nil {
return nil, fmt.Errorf("can't sign the hash: %w", err)
}
return sig, nil
} }
// getPayReq creates a payment request for the given net. // getPayReq creates a payment request for the given net.

View file

@ -8,6 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn"
) )
// ChannelEdgeInfo represents a fully authenticated channel along with all its // ChannelEdgeInfo represents a fully authenticated channel along with all its
@ -62,6 +63,11 @@ type ChannelEdgeInfo struct {
// the value output in the outpoint that created this channel. // the value output in the outpoint that created this channel.
Capacity btcutil.Amount Capacity btcutil.Amount
// TapscriptRoot is the optional Merkle root of the tapscript tree if
// this channel is a taproot channel that also commits to a tapscript
// tree (custom channel).
TapscriptRoot fn.Option[chainhash.Hash]
// ExtraOpaqueData is the set of data that was appended to this // ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or // message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to // parse. By holding onto this data, we ensure that we're able to

View file

@ -195,6 +195,11 @@ type PaymentCreationInfo struct {
// PaymentRequest is the full payment request, if any. // PaymentRequest is the full payment request, if any.
PaymentRequest []byte PaymentRequest []byte
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message only and therefore do not affect the onion payload size.
FirstHopCustomRecords lnwire.CustomRecords
} }
// htlcBucketKey creates a composite key from prefix and id where the result is // htlcBucketKey creates a composite key from prefix and id where the result is
@ -1010,10 +1015,21 @@ func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
return err return err
} }
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
err := c.FirstHopCustomRecords.SerializeTo(w)
if err != nil {
return err
}
return nil return nil
} }
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo,
error) {
var scratch [8]byte var scratch [8]byte
c := &PaymentCreationInfo{} c := &PaymentCreationInfo{}
@ -1046,6 +1062,15 @@ func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
} }
c.PaymentRequest = payReq c.PaymentRequest = payReq
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
if err != nil {
return nil, err
}
return c, nil return c, nil
} }
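As a rough round-trip sketch of the new trailing TLV section (using only the lnwire helpers shown above; the buffer and record values are illustrative), the custom records are appended after the fixed fields on write and parsed back from whatever bytes remain on read:

	records := lnwire.CustomRecords{
		lnwire.MinCustomRecordsTlvType: {1, 2, 3},
	}

	var buf bytes.Buffer

	// Write: append the records as a single TLV stream at the end.
	if err := records.SerializeTo(&buf); err != nil {
		return err
	}

	// Read: parse the remainder of the stream back into a record map.
	parsed, err := lnwire.ParseCustomRecordsFrom(&buf)
	if err != nil {
		return err
	}
	// parsed should now equal records.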
@ -1071,6 +1096,25 @@ func serializeHTLCAttemptInfo(w io.Writer, a *HTLCAttemptInfo) error {
return err return err
} }
// Merge the fixed/known records together with the custom records to
// serialize them as a single blob. We can't do this in SerializeRoute
// because we're in the middle of the byte stream there. We can only do
// TLV serialization at the end of the stream, since EOF is allowed for
// a stream if no more data is expected.
producers := []tlv.RecordProducer{
&a.Route.FirstHopAmount,
}
tlvData, err := lnwire.MergeAndEncode(
producers, nil, a.Route.FirstHopWireCustomRecords,
)
if err != nil {
return err
}
if _, err := w.Write(tlvData); err != nil {
return err
}
return nil return nil
} }
@ -1108,6 +1152,22 @@ func deserializeHTLCAttemptInfo(r io.Reader) (*HTLCAttemptInfo, error) {
a.Hash = &hash a.Hash = &hash
// Read any remaining data (if any) and parse it into the known records
// and custom records.
extraData, err := io.ReadAll(r)
if err != nil {
return nil, err
}
customRecords, _, _, err := lnwire.ParseAndExtractCustomRecords(
extraData, &a.Route.FirstHopAmount,
)
if err != nil {
return nil, err
}
a.Route.FirstHopWireCustomRecords = customRecords
return a, nil return a, nil
} }
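Putting the two halves together, a hedged sketch of the merge-and-split round trip (the record type parameters follow the usage above; values are illustrative):

	firstHopAmt := tlv.NewRecordT[tlv.TlvType0](
		tlv.NewBigSizeT(lnwire.MilliSatoshi(1234)),
	)
	custom := lnwire.CustomRecords{
		lnwire.MinCustomRecordsTlvType + 3: {4, 5, 6},
	}

	// Encode the known record and the custom records as one TLV blob.
	blob, err := lnwire.MergeAndEncode(
		[]tlv.RecordProducer{&firstHopAmt}, nil, custom,
	)
	if err != nil {
		return err
	}

	// Decode: the known record is populated in place, the custom records
	// are handed back separately.
	var decodedAmt tlv.RecordT[tlv.TlvType0, tlv.BigSizeT[lnwire.MilliSatoshi]]
	parsedCustom, _, _, err := lnwire.ParseAndExtractCustomRecords(
		blob, &decodedAmt,
	)
	if err != nil {
		return err
	}
	_ = parsedCustom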
@ -1373,6 +1433,8 @@ func SerializeRoute(w io.Writer, r route.Route) error {
} }
} }
// Any new/extra TLV data is encoded in serializeHTLCAttemptInfo!
return nil return nil
} }
@ -1406,5 +1468,7 @@ func DeserializeRoute(r io.Reader) (route.Route, error) {
} }
rt.Hops = hops rt.Hops = hops
// Any new/extra TLV data is decoded in deserializeHTLCAttemptInfo!
return rt, nil return rt, nil
} }

View file

@ -13,8 +13,10 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -108,7 +110,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) {
// Use single second precision to avoid false positive test // Use single second precision to avoid false positive test
// failures due to the monotonic time component. // failures due to the monotonic time component.
CreationTime: time.Unix(time.Now().Unix(), 0), CreationTime: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte(""), PaymentRequest: []byte("test"),
} }
a := NewHtlcAttempt( a := NewHtlcAttempt(
@ -124,51 +126,64 @@ func TestSentPaymentSerialization(t *testing.T) {
c, s := makeFakeInfo() c, s := makeFakeInfo()
var b bytes.Buffer var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, c); err != nil { require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
t.Fatalf("unable to serialize creation info: %v", err)
} // Assert the length of the serialized creation info is as expected,
// without any custom records.
baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest)
require.Len(t, b.Bytes(), baseLength)
newCreationInfo, err := deserializePaymentCreationInfo(&b) newCreationInfo, err := deserializePaymentCreationInfo(&b)
require.NoError(t, err, "unable to deserialize creation info") require.NoError(t, err, "deserialize")
require.Equal(t, c, newCreationInfo)
if !reflect.DeepEqual(c, newCreationInfo) {
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(c), spew.Sdump(newCreationInfo),
)
}
b.Reset() b.Reset()
if err := serializeHTLCAttemptInfo(&b, s); err != nil {
t.Fatalf("unable to serialize info: %v", err) // Now we add some custom records to the creation info and serialize it
// again.
c.FirstHopCustomRecords = lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
} }
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
newCreationInfo, err = deserializePaymentCreationInfo(&b)
require.NoError(t, err, "deserialize")
require.Equal(t, c, newCreationInfo)
b.Reset()
require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize")
newWireInfo, err := deserializeHTLCAttemptInfo(&b) newWireInfo, err := deserializeHTLCAttemptInfo(&b)
require.NoError(t, err, "unable to deserialize info") require.NoError(t, err, "deserialize")
newWireInfo.AttemptID = s.AttemptID
// First we verify all the records match up porperly, as they aren't // First we verify all the records match up properly.
// able to be properly compared using reflect.DeepEqual. require.Equal(t, s.Route, newWireInfo.Route)
err = assertRouteEqual(&s.Route, &newWireInfo.Route)
if err != nil { // We now add the new fields and custom records to the route and
t.Fatalf("Routes do not match after "+ // serialize it again.
"serialization/deserialization: %v", err) b.Reset()
s.Route.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0](
tlv.NewBigSizeT(lnwire.MilliSatoshi(1234)),
)
s.Route.FirstHopWireCustomRecords = lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType + 3: []byte{4, 5, 6},
} }
require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize")
newWireInfo, err = deserializeHTLCAttemptInfo(&b)
require.NoError(t, err, "deserialize")
require.Equal(t, s.Route, newWireInfo.Route)
// Clear routes to allow DeepEqual to compare the remaining fields. // Clear routes to allow DeepEqual to compare the remaining fields.
newWireInfo.Route = route.Route{} newWireInfo.Route = route.Route{}
s.Route = route.Route{} s.Route = route.Route{}
newWireInfo.AttemptID = s.AttemptID
// Call session key method to set our cached session key so we can use // Call session key method to set our cached session key so we can use
// DeepEqual, and assert that our key equals the original key. // DeepEqual, and assert that our key equals the original key.
require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey()) require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey())
if !reflect.DeepEqual(s, newWireInfo) { require.Equal(t, s, newWireInfo)
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(s), spew.Sdump(newWireInfo),
)
}
} }
// assertRouteEqual compares two routes for equality and returns an error if // assertRouteEqual compares two routes for equality and returns an error if

View file

@ -7,6 +7,7 @@ import (
"math" "math"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -16,16 +17,15 @@ import (
const ( const (
// OutputIndexEmpty is used when the output index doesn't exist. // OutputIndexEmpty is used when the output index doesn't exist.
OutputIndexEmpty = math.MaxUint16 OutputIndexEmpty = math.MaxUint16
)
// A set of tlv type definitions used to serialize the body of type (
// revocation logs to the database. // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount.
// BigSizeAmount = tlv.BigSizeT[btcutil.Amount]
// NOTE: A migration should be added whenever this list changes.
revLogOurOutputIndexType tlv.Type = 0 // BigSizeMilliSatoshi is a type alias for a TLV record of a
revLogTheirOutputIndexType tlv.Type = 1 // lnwire.MilliSatoshi.
revLogCommitTxHashType tlv.Type = 2 BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi]
revLogOurBalanceType tlv.Type = 3
revLogTheirBalanceType tlv.Type = 4
) )
var ( var (
@ -54,6 +54,74 @@ var (
ErrOutputIndexTooBig = errors.New("output index is over uint16") ErrOutputIndexTooBig = errors.New("output index is over uint16")
) )
// SparsePayHash is a type alias for a 32 byte array, which when serialized is
// able to save some space by not including an empty payment hash on disk.
type SparsePayHash [32]byte
// NewSparsePayHash creates a new SparsePayHash from a 32 byte array.
func NewSparsePayHash(rHash [32]byte) SparsePayHash {
return SparsePayHash(rHash)
}
// Record returns a tlv record for the SparsePayHash.
func (s *SparsePayHash) Record() tlv.Record {
// We use a zero for the type here, as this'll be used along with the
// RecordT type.
return tlv.MakeDynamicRecord(
0, s, s.hashLen,
sparseHashEncoder, sparseHashDecoder,
)
}
// hashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (s *SparsePayHash) hashLen() uint64 {
if bytes.Equal(s[:], lntypes.ZeroHash[:]) {
return 0
}
return 32
}
// sparseHashEncoder is the customized encoder which skips encoding the empty
// hash.
func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the value is an empty hash, we will skip encoding it.
if bytes.Equal(v[:], lntypes.ZeroHash[:]) {
return nil
}
vArray := (*[32]byte)(v)
return tlv.EBytes32(w, vArray, buf)
}
// sparseHashDecoder is the customized decoder which skips decoding the empty
// hash.
func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the length is zero, we will skip decoding the empty hash.
if l == 0 {
return nil
}
vArray := (*[32]byte)(v)
return tlv.DBytes32(r, vArray, buf, 32)
}
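A small sketch of the space saving this buys, assuming the encoder above: a zero hash costs only the two-byte TLV header, while a real hash costs the header plus the full 32 bytes.

	var empty, filled SparsePayHash
	copy(filled[:], bytes.Repeat([]byte{10}, 32))

	encodedLen := func(h *SparsePayHash) int {
		var b bytes.Buffer
		stream, err := tlv.NewStream(h.Record())
		if err != nil {
			return -1
		}
		if err := stream.Encode(&b); err != nil {
			return -1
		}
		return b.Len()
	}

	// encodedLen(&empty) == 2 (type + zero length), while
	// encodedLen(&filled) == 34 (type + length + 32-byte hash).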
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
// historical HTLCs, which is useful for constructing RevocationLog when a // historical HTLCs, which is useful for constructing RevocationLog when a
// breach is detected. // breach is detected.
@ -72,116 +140,90 @@ var (
// made into tlv records without further conversion. // made into tlv records without further conversion.
type HTLCEntry struct { type HTLCEntry struct {
// RHash is the payment hash of the HTLC. // RHash is the payment hash of the HTLC.
RHash [32]byte RHash tlv.RecordT[tlv.TlvType0, SparsePayHash]
// RefundTimeout is the absolute timeout on the HTLC that the sender // RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo. // must wait before reclaiming the funds in limbo.
RefundTimeout uint32 RefundTimeout tlv.RecordT[tlv.TlvType1, uint32]
// OutputIndex is the output index for this particular HTLC output // OutputIndex is the output index for this particular HTLC output
// within the commitment transaction. // within the commitment transaction.
// //
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
// gives us a max number of HTLCs of 65K. // gives us a max number of HTLCs of 65K.
OutputIndex uint16 OutputIndex tlv.RecordT[tlv.TlvType2, uint16]
// Incoming denotes whether we're the receiver or the sender of this // Incoming denotes whether we're the receiver or the sender of this
// HTLC. // HTLC.
// Incoming tlv.RecordT[tlv.TlvType3, bool]
// NOTE: this field is the memory representation of the field
// incomingUint.
Incoming bool
// Amt is the amount of satoshis this HTLC escrows. // Amt is the amount of satoshis this HTLC escrows.
// Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]]
// NOTE: this field is the memory representation of the field amtUint.
Amt btcutil.Amount
// amtTlv is the uint64 format of Amt. This field is created so we can // CustomBlob is an optional blob that can be used to store information
// easily make it into a tlv record and save it to disk. // specific to revocation handling for a custom channel type.
// CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob]
// NOTE: we keep this field for accounting purpose only. If the disk
// space becomes an issue, we could delete this field to save us extra
// 8 bytes.
amtTlv uint64
// incomingTlv is the uint8 format of Incoming. This field is created // HtlcIndex is the index of the HTLC in the channel.
// so we can easily make it into a tlv record and save it to disk. HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, uint16]
incomingTlv uint8
}
// RHashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (h *HTLCEntry) RHashLen() uint64 {
if h.RHash == lntypes.ZeroHash {
return 0
}
return 32
}
// RHashEncoder is the customized encoder which skips encoding the empty hash.
func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
// If the value is an empty hash, we will skip encoding it.
if *v == lntypes.ZeroHash {
return nil
}
return tlv.EBytes32(w, v, buf)
}
// RHashDecoder is the customized decoder which skips decoding the empty hash.
func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
// If the length is zero, we will skip encoding the empty hash.
if l == 0 {
return nil
}
return tlv.DBytes32(r, v, buf, 32)
} }
// toTlvStream converts an HTLCEntry record into a tlv representation. // toTlvStream converts an HTLCEntry record into a tlv representation.
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
const ( records := []tlv.Record{
// A set of tlv type definitions used to serialize htlc entries h.RHash.Record(),
// to the database. We define it here instead of the head of h.RefundTimeout.Record(),
// the file to avoid naming conflicts. h.OutputIndex.Record(),
// h.Incoming.Record(),
// NOTE: A migration should be added whenever this list h.Amt.Record(),
// changes. }
rHashType tlv.Type = 0
refundTimeoutType tlv.Type = 1
outputIndexType tlv.Type = 2
incomingType tlv.Type = 3
amtType tlv.Type = 4
)
return tlv.NewStream( h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) {
tlv.MakeDynamicRecord( records = append(records, r.Record())
rHashType, &h.RHash, h.RHashLen, })
RHashEncoder, RHashDecoder,
h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, uint16]) {
records = append(records, r.Record())
})
tlv.SortRecords(records)
return tlv.NewStream(records...)
}
// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC.
func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) {
h := &HTLCEntry{
RHash: tlv.NewRecordT[tlv.TlvType0](
NewSparsePayHash(htlc.RHash),
), ),
tlv.MakePrimitiveRecord( RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1](
refundTimeoutType, &h.RefundTimeout, htlc.RefundTimeout,
), ),
tlv.MakePrimitiveRecord( OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
outputIndexType, &h.OutputIndex, uint16(htlc.OutputIndex),
), ),
tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming),
// We will save 3 bytes if the amount is less or equal to Amt: tlv.NewRecordT[tlv.TlvType4](
// 4,294,967,295 msat, or roughly 0.043 bitcoin. tlv.NewBigSizeT(htlc.Amt.ToSatoshis()),
tlv.MakeBigSizeRecord(amtType, &h.amtTlv), ),
) HtlcIndex: tlv.SomeRecordT(tlv.NewPrimitiveRecord[tlv.TlvType6](
uint16(htlc.HtlcIndex),
)),
}
if len(htlc.CustomRecords) != 0 {
blob, err := htlc.CustomRecords.Serialize()
if err != nil {
return nil, err
}
h.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
}
return h, nil
} }
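As a usage sketch (field values are illustrative), converting an in-memory HTLC to its on-disk entry now also captures the serialized custom records and the HTLC index:

	htlc := HTLC{
		RHash:         [32]byte{1},
		Amt:           lnwire.NewMSatFromSatoshis(1_000),
		RefundTimeout: 740_000,
		OutputIndex:   10,
		HtlcIndex:     7,
		Incoming:      true,
		CustomRecords: lnwire.CustomRecords{
			lnwire.MinCustomRecordsTlvType: {1, 2, 3},
		},
	}

	entry, err := NewHTLCEntryFromHTLC(htlc)
	if err != nil {
		return err
	}

	// entry.Amt.Val.Int() is 1_000 sat, entry.HtlcIndex holds 7, and
	// entry.CustomBlob carries the TLV-encoded custom records.
	_ = entry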
// RevocationLog stores the info needed to construct a breach retribution. Its // RevocationLog stores the info needed to construct a breach retribution. Its
@ -191,15 +233,15 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
type RevocationLog struct { type RevocationLog struct {
// OurOutputIndex specifies our output index in this commitment. In a // OurOutputIndex specifies our output index in this commitment. In a
// remote commitment transaction, this is the to remote output index. // remote commitment transaction, this is the to remote output index.
OurOutputIndex uint16 OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16]
// TheirOutputIndex specifies their output index in this commitment. In // TheirOutputIndex specifies their output index in this commitment. In
// a remote commitment transaction, this is the to local output index. // a remote commitment transaction, this is the to local output index.
TheirOutputIndex uint16 TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16]
// CommitTxHash is the hash of the latest version of the commitment // CommitTxHash is the hash of the latest version of the commitment
// state, broadcastable by us. // state, broadcastable by us.
CommitTxHash [32]byte CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte]
// HTLCEntries is the set of HTLCEntry's that are pending at this // HTLCEntries is the set of HTLCEntry's that are pending at this
// particular commitment height. // particular commitment height.
@ -209,21 +251,65 @@ type RevocationLog struct {
// directly spendable by us. In other words, it is the value of the // directly spendable by us. In other words, it is the value of the
// to_remote output on the remote parties' commitment transaction. // to_remote output on the remote parties' commitment transaction.
// //
// NOTE: this is a pointer so that it is clear if the value is zero or // NOTE: this is an option so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include // nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for // this field, it could be the case that the field is not present for
// all revocation logs. // all revocation logs.
OurBalance *lnwire.MilliSatoshi OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi]
// TheirBalance is the current available balance within the channel // TheirBalance is the current available balance within the channel
// directly spendable by the remote node. In other words, it is the // directly spendable by the remote node. In other words, it is the
// value of the to_local output on the remote parties' commitment. // value of the to_local output on the remote parties' commitment.
// //
// NOTE: this is a pointer so that it is clear if the value is zero or // NOTE: this is an option so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include // nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for // this field, it could be the case that the field is not present for
// all revocation logs. // all revocation logs.
TheirBalance *lnwire.MilliSatoshi TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi]
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This information is only created
// at channel funding time, and afterwards is to be considered
// immutable.
CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob]
}
// NewRevocationLog creates a new RevocationLog from the given parameters.
func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16,
commitHash [32]byte, ourBalance,
theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry,
customBlob fn.Option[tlv.Blob]) RevocationLog {
rl := RevocationLog{
OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0](
ourOutputIndex,
),
TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1](
theirOutputIndex,
),
CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash),
HTLCEntries: htlcs,
}
ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) {
rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3](
tlv.NewBigSizeT(balance),
))
})
theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) {
rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(balance),
))
})
customBlob.WhenSome(func(blob tlv.Blob) {
rl.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
})
return rl
} }
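A brief sketch of the two common ways to build a log with this constructor (commitHash and htlcs assumed from surrounding context; amounts are illustrative):

	// A log without amount data, as putRevocationLog writes when
	// noAmtData is set.
	bare := NewRevocationLog(
		0, 1, commitHash, fn.None[lnwire.MilliSatoshi](),
		fn.None[lnwire.MilliSatoshi](), htlcs, fn.None[tlv.Blob](),
	)

	// A full log: both balances and the custom channel blob become
	// optional TLV records on disk.
	full := NewRevocationLog(
		0, 1, commitHash, fn.Some(lnwire.MilliSatoshi(9_000)),
		fn.Some(lnwire.MilliSatoshi(3_000)), htlcs,
		fn.Some[tlv.Blob]([]byte{4, 5, 6}),
	)
	_, _ = bare, full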
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
@ -242,15 +328,32 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
} }
rl := &RevocationLog{ rl := &RevocationLog{
OurOutputIndex: uint16(ourOutputIndex), OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0](
TheirOutputIndex: uint16(theirOutputIndex), uint16(ourOutputIndex),
CommitTxHash: commit.CommitTx.TxHash(), ),
HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint16(theirOutputIndex),
),
CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte](
commit.CommitTx.TxHash(),
),
HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)),
} }
commit.CustomBlob.WhenSome(func(blob tlv.Blob) {
rl.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
})
if !noAmtData { if !noAmtData {
rl.OurBalance = &commit.LocalBalance rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3](
rl.TheirBalance = &commit.RemoteBalance tlv.NewBigSizeT(commit.LocalBalance),
))
rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(commit.RemoteBalance),
))
} }
for _, htlc := range commit.Htlcs { for _, htlc := range commit.Htlcs {
@ -265,12 +368,9 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
return ErrOutputIndexTooBig return ErrOutputIndexTooBig
} }
entry := &HTLCEntry{ entry, err := NewHTLCEntryFromHTLC(htlc)
RHash: htlc.RHash, if err != nil {
RefundTimeout: htlc.RefundTimeout, return err
Incoming: htlc.Incoming,
OutputIndex: uint16(htlc.OutputIndex),
Amt: htlc.Amt.ToSatoshis(),
} }
rl.HTLCEntries = append(rl.HTLCEntries, entry) rl.HTLCEntries = append(rl.HTLCEntries, entry)
} }
@ -306,31 +406,27 @@ func fetchRevocationLog(log kvdb.RBucket,
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// Add the tlv records for all non-optional fields. // Add the tlv records for all non-optional fields.
records := []tlv.Record{ records := []tlv.Record{
tlv.MakePrimitiveRecord( rl.OurOutputIndex.Record(),
revLogOurOutputIndexType, &rl.OurOutputIndex, rl.TheirOutputIndex.Record(),
), rl.CommitTxHash.Record(),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
} }
// Now we add any optional fields that are non-nil. // Now we add any optional fields that are non-nil.
if rl.OurBalance != nil { rl.OurBalance.WhenSome(
lb := uint64(*rl.OurBalance) func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) {
records = append(records, tlv.MakeBigSizeRecord( records = append(records, r.Record())
revLogOurBalanceType, &lb, },
)) )
}
if rl.TheirBalance != nil { rl.TheirBalance.WhenSome(
rb := uint64(*rl.TheirBalance) func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) {
records = append(records, tlv.MakeBigSizeRecord( records = append(records, r.Record())
revLogTheirBalanceType, &rb, },
)) )
}
rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) {
records = append(records, r.Record())
})
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := tlv.NewStream(records...) tlvStream, err := tlv.NewStream(records...)
@ -351,14 +447,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// format. // format.
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
for _, htlc := range htlcs { for _, htlc := range htlcs {
// Patch the incomingTlv field.
if htlc.Incoming {
htlc.incomingTlv = 1
}
// Patch the amtTlv field.
htlc.amtTlv = uint64(htlc.Amt)
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := htlc.toTlvStream() tlvStream, err := htlc.toTlvStream()
if err != nil { if err != nil {
@ -376,27 +464,20 @@ func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
// deserializeRevocationLog deserializes a RevocationLog based on tlv format. // deserializeRevocationLog deserializes a RevocationLog based on tlv format.
func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
var ( var rl RevocationLog
rl RevocationLog
ourBalance uint64 ourBalance := rl.OurBalance.Zero()
theirBalance uint64 theirBalance := rl.TheirBalance.Zero()
) customBlob := rl.CustomBlob.Zero()
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord( rl.OurOutputIndex.Record(),
revLogOurOutputIndexType, &rl.OurOutputIndex, rl.TheirOutputIndex.Record(),
), rl.CommitTxHash.Record(),
tlv.MakePrimitiveRecord( ourBalance.Record(),
revLogTheirOutputIndexType, &rl.TheirOutputIndex, theirBalance.Record(),
), customBlob.Record(),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
tlv.MakeBigSizeRecord(revLogOurBalanceType, &ourBalance),
tlv.MakeBigSizeRecord(
revLogTheirBalanceType, &theirBalance,
),
) )
if err != nil { if err != nil {
return rl, err return rl, err
@ -408,14 +489,16 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
return rl, err return rl, err
} }
if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil { if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil {
lb := lnwire.MilliSatoshi(ourBalance) rl.OurBalance = tlv.SomeRecordT(ourBalance)
rl.OurBalance = &lb
} }
if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil { if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil {
rb := lnwire.MilliSatoshi(theirBalance) rl.TheirBalance = tlv.SomeRecordT(theirBalance)
rl.TheirBalance = &rb }
if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil {
rl.CustomBlob = tlv.SomeRecordT(customBlob)
} }
// Read the HTLC entries. // Read the HTLC entries.
@ -432,14 +515,28 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
for { for {
var htlc HTLCEntry var htlc HTLCEntry
customBlob := htlc.CustomBlob.Zero()
htlcIndex := htlc.HtlcIndex.Zero()
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := htlc.toTlvStream() records := []tlv.Record{
htlc.RHash.Record(),
htlc.RefundTimeout.Record(),
htlc.OutputIndex.Record(),
htlc.Incoming.Record(),
htlc.Amt.Record(),
customBlob.Record(),
htlcIndex.Record(),
}
tlvStream, err := tlv.NewStream(records...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Read the HTLC entry. // Read the HTLC entry.
if _, err := readTlvStream(r, tlvStream); err != nil { parsedTypes, err := readTlvStream(r, tlvStream)
if err != nil {
// We've reached the end when hitting an EOF. // We've reached the end when hitting an EOF.
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
break break
@ -447,13 +544,13 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
return nil, err return nil, err
} }
// Patch the Incoming field. if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil {
if htlc.incomingTlv == 1 { htlc.CustomBlob = tlv.SomeRecordT(customBlob)
htlc.Incoming = true
} }
// Patch the Amt field. if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil {
htlc.Amt = btcutil.Amount(htlc.amtTlv) htlc.HtlcIndex = tlv.SomeRecordT(htlcIndex)
}
// Append the entry. // Append the entry.
htlcs = append(htlcs, &htlc) htlcs = append(htlcs, &htlc)
@ -469,6 +566,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
if err := s.Encode(&b); err != nil { if err := s.Encode(&b); err != nil {
return err return err
} }
// Write the stream's length as a varint. // Write the stream's length as a varint.
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
if err != nil { if err != nil {

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -33,17 +34,38 @@ var (
0xff, // value = 255 0xff, // value = 255
} }
customRecords = lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType + 1: []byte("custom data"),
}
blobBytes = []byte{
// Corresponds to the encoded version of the above custom
// records.
0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, 0x74,
0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61,
}
testHTLCEntry = HTLCEntry{ testHTLCEntry = HTLCEntry{
RefundTimeout: 740_000, RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32](
OutputIndex: 10, 740_000,
Incoming: true, ),
Amt: 1000_000, OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](
amtTlv: 1000_000, 10,
incomingTlv: 1, ),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(btcutil.Amount(1_000_000)),
),
CustomBlob: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes),
),
HtlcIndex: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](0x33),
),
} }
testHTLCEntryBytes = []byte{ testHTLCEntryBytes = []byte{
// Body length 23. // Body length 45.
0x16, 0x2d,
// Rhash tlv. // Rhash tlv.
0x0, 0x0, 0x0, 0x0,
// RefundTimeout tlv. // RefundTimeout tlv.
@ -54,6 +76,45 @@ var (
0x3, 0x1, 0x1, 0x3, 0x1, 0x1,
// Amt tlv. // Amt tlv.
0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40,
// Custom blob tlv.
0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73,
0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61,
// HLTC index tlv.
0x6, 0x2, 0x0, 0x33,
}
testHTLCEntryHash = HTLCEntry{
RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash(
[32]byte{0x33, 0x44, 0x55},
)),
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32](
740_000,
),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](
10,
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(btcutil.Amount(1_000_000)),
),
}
testHTLCEntryHashBytes = []byte{
// Body length 54.
0x36,
// Rhash tlv.
0x0, 0x20,
0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// RefundTimeout tlv.
0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0,
// OutputIndex tlv.
0x2, 0x2, 0x0, 0xa,
// Incoming tlv.
0x3, 0x1, 0x1,
// Amt tlv.
0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40,
} }
localBalance = lnwire.MilliSatoshi(9000) localBalance = lnwire.MilliSatoshi(9000)
@ -68,24 +129,29 @@ var (
CommitTx: channels.TestFundingTx, CommitTx: channels.TestFundingTx,
CommitSig: bytes.Repeat([]byte{1}, 71), CommitSig: bytes.Repeat([]byte{1}, 71),
Htlcs: []HTLC{{ Htlcs: []HTLC{{
RefundTimeout: testHTLCEntry.RefundTimeout, RefundTimeout: testHTLCEntry.RefundTimeout.Val,
OutputIndex: int32(testHTLCEntry.OutputIndex), OutputIndex: int32(testHTLCEntry.OutputIndex.Val),
Incoming: testHTLCEntry.Incoming, HtlcIndex: uint64(
Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.HtlcIndex.ValOpt().
testHTLCEntry.Amt, UnsafeFromSome(),
), ),
Incoming: testHTLCEntry.Incoming.Val,
Amt: lnwire.NewMSatFromSatoshis(
testHTLCEntry.Amt.Val.Int(),
),
CustomRecords: customRecords,
}}, }},
CustomBlob: fn.Some(blobBytes),
} }
testRevocationLogNoAmts = RevocationLog{ testRevocationLogNoAmts = NewRevocationLog(
OurOutputIndex: 0, 0, 1, testChannelCommit.CommitTx.TxHash(),
TheirOutputIndex: 1, fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](),
CommitTxHash: testChannelCommit.CommitTx.TxHash(), []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes),
HTLCEntries: []*HTLCEntry{&testHTLCEntry}, )
}
testRevocationLogNoAmtsBytes = []byte{ testRevocationLogNoAmtsBytes = []byte{
// Body length 42. // Body length 61.
0x2a, 0x3d,
// OurOutputIndex tlv. // OurOutputIndex tlv.
0x0, 0x2, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0,
// TheirOutputIndex tlv. // TheirOutputIndex tlv.
@ -96,19 +162,19 @@ var (
0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6, 0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6,
0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff, 0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff,
0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd,
// Custom blob tlv.
0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73,
0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61,
} }
testRevocationLogWithAmts = RevocationLog{ testRevocationLogWithAmts = NewRevocationLog(
OurOutputIndex: 0, 0, 1, testChannelCommit.CommitTx.TxHash(),
TheirOutputIndex: 1, fn.Some(localBalance), fn.Some(remoteBalance),
CommitTxHash: testChannelCommit.CommitTx.TxHash(), []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes),
HTLCEntries: []*HTLCEntry{&testHTLCEntry}, )
OurBalance: &localBalance,
TheirBalance: &remoteBalance,
}
testRevocationLogWithAmtsBytes = []byte{ testRevocationLogWithAmtsBytes = []byte{
// Body length 52. // Body length 71.
0x34, 0x47,
// OurOutputIndex tlv. // OurOutputIndex tlv.
0x0, 0x2, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0,
// TheirOutputIndex tlv. // TheirOutputIndex tlv.
@ -123,6 +189,9 @@ var (
0x3, 0x3, 0xfd, 0x23, 0x28, 0x3, 0x3, 0xfd, 0x23, 0x28,
// Remote Balance. // Remote Balance.
0x4, 0x3, 0xfd, 0x0b, 0xb8, 0x4, 0x3, 0xfd, 0x0b, 0xb8,
// Custom blob tlv.
0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73,
0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61,
} }
) )
@ -193,11 +262,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) {
// Copy the testHTLCEntry. // Copy the testHTLCEntry.
entry := testHTLCEntry entry := testHTLCEntry
// Set the internal fields to empty values so we can test the bytes are
// padded.
entry.incomingTlv = 0
entry.amtTlv = 0
// Write the tlv stream. // Write the tlv stream.
buf := bytes.NewBuffer([]byte{}) buf := bytes.NewBuffer([]byte{})
err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry})
@ -207,6 +271,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) {
require.Equal(t, testHTLCEntryBytes, buf.Bytes()) require.Equal(t, testHTLCEntryBytes, buf.Bytes())
} }
func TestSerializeHTLCEntriesWithRHash(t *testing.T) {
t.Parallel()
// Copy the testHTLCEntry.
entry := testHTLCEntryHash
// Write the tlv stream.
buf := bytes.NewBuffer([]byte{})
err := serializeHTLCEntries(buf, []*HTLCEntry{&entry})
require.NoError(t, err)
// Check the bytes are read as expected.
require.Equal(t, testHTLCEntryHashBytes, buf.Bytes())
}
func TestSerializeHTLCEntries(t *testing.T) { func TestSerializeHTLCEntries(t *testing.T) {
t.Parallel() t.Parallel()
@ -215,7 +294,7 @@ func TestSerializeHTLCEntries(t *testing.T) {
// Create a fake rHash. // Create a fake rHash.
rHashBytes := bytes.Repeat([]byte{10}, 32) rHashBytes := bytes.Repeat([]byte{10}, 32)
copy(entry.RHash[:], rHashBytes) copy(entry.RHash.Val[:], rHashBytes)
// Construct the serialized bytes. // Construct the serialized bytes.
// //
@ -224,7 +303,7 @@ func TestSerializeHTLCEntries(t *testing.T) {
partialBytes := testHTLCEntryBytes[3:] partialBytes := testHTLCEntryBytes[3:]
// Write the total length and RHash tlv. // Write the total length and RHash tlv.
expectedBytes := []byte{0x36, 0x0, 0x20} expectedBytes := []byte{0x4d, 0x0, 0x20}
expectedBytes = append(expectedBytes, rHashBytes...) expectedBytes = append(expectedBytes, rHashBytes...)
// Append the rest. // Append the rest.
@ -269,7 +348,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) {
t, &test.revLog, test.revLogBytes, t, &test.revLog, test.revLogBytes,
) )
testDerializeRevocationLog( testDeserializeRevocationLog(
t, &test.revLog, test.revLogBytes, t, &test.revLog, test.revLogBytes,
) )
}) })
@ -293,7 +372,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog,
require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex]) require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex])
} }
func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog,
revLogBytes []byte) { revLogBytes []byte) {
// Construct the full bytes. // Construct the full bytes.
@ -309,7 +388,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog,
require.Equal(t, *revLog, rl) require.Equal(t, *revLog, rl)
} }
func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) {
t.Parallel() t.Parallel()
// Read the tlv stream. // Read the tlv stream.
@ -322,7 +401,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) {
require.Equal(t, &testHTLCEntry, htlcs[0]) require.Equal(t, &testHTLCEntry, htlcs[0])
} }
func TestDerializeHTLCEntries(t *testing.T) { func TestDeserializeHTLCEntries(t *testing.T) {
t.Parallel() t.Parallel()
// Copy the testHTLCEntry. // Copy the testHTLCEntry.
@ -330,7 +409,7 @@ func TestDerializeHTLCEntries(t *testing.T) {
// Create a fake rHash. // Create a fake rHash.
rHashBytes := bytes.Repeat([]byte{10}, 32) rHashBytes := bytes.Repeat([]byte{10}, 32)
copy(entry.RHash[:], rHashBytes) copy(entry.RHash.Val[:], rHashBytes)
// Construct the serialized bytes. // Construct the serialized bytes.
// //
@ -339,7 +418,7 @@ func TestDerializeHTLCEntries(t *testing.T) {
partialBytes := testHTLCEntryBytes[3:] partialBytes := testHTLCEntryBytes[3:]
// Write the total length and RHash tlv. // Write the total length and RHash tlv.
testBytes := append([]byte{0x36, 0x0, 0x20}, rHashBytes...) testBytes := append([]byte{0x4d, 0x0, 0x20}, rHashBytes...)
// Append the rest. // Append the rest.
testBytes = append(testBytes, partialBytes...) testBytes = append(testBytes, partialBytes...)
@ -398,11 +477,11 @@ func TestDeleteLogBucket(t *testing.T) {
err = kvdb.Update(backend, func(tx kvdb.RwTx) error { err = kvdb.Update(backend, func(tx kvdb.RwTx) error {
// Create the buckets. // Create the buckets.
chanBucket, _, err := createTestRevocatoinLogBuckets(tx) chanBucket, _, err := createTestRevocationLogBuckets(tx)
require.NoError(t, err) require.NoError(t, err)
// Create the buckets again should give us an error. // Create the buckets again should give us an error.
_, _, err = createTestRevocatoinLogBuckets(tx) _, _, err = createTestRevocationLogBuckets(tx)
require.ErrorIs(t, err, kvdb.ErrBucketExists) require.ErrorIs(t, err, kvdb.ErrBucketExists)
// Delete both buckets. // Delete both buckets.
@ -410,7 +489,7 @@ func TestDeleteLogBucket(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Create the buckets again should give us NO error. // Create the buckets again should give us NO error.
_, _, err = createTestRevocatoinLogBuckets(tx) _, _, err = createTestRevocationLogBuckets(tx)
return err return err
}, func() {}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
@ -516,7 +595,7 @@ func TestPutRevocationLog(t *testing.T) {
// Construct the testing db transaction. // Construct the testing db transaction.
dbTx := func(tx kvdb.RwTx) (RevocationLog, error) { dbTx := func(tx kvdb.RwTx) (RevocationLog, error) {
// Create the buckets. // Create the buckets.
_, bucket, err := createTestRevocatoinLogBuckets(tx) _, bucket, err := createTestRevocationLogBuckets(tx)
require.NoError(t, err) require.NoError(t, err)
// Save the log. // Save the log.
@ -686,7 +765,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) {
} }
} }
func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket,
kvdb.RwBucket, error) { kvdb.RwBucket, error) {
chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket)

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"regexp" "regexp"
@ -42,7 +42,7 @@ func parseTime(s string, base time.Time) (uint64, error) {
var lightningPrefix = "lightning:" var lightningPrefix = "lightning:"
// stripPrefix removes accidentally copied 'lightning:' prefix. // StripPrefix removes accidentally copied 'lightning:' prefix.
func stripPrefix(s string) string { func StripPrefix(s string) string {
return strings.TrimSpace(strings.TrimPrefix(s, lightningPrefix)) return strings.TrimSpace(strings.TrimPrefix(s, lightningPrefix))
} }
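For example (invoice string shortened for illustration), both of these normalize to the same bare payment request:

	s1 := StripPrefix("lightning:lnbc20m1exampleinvoice")
	s2 := StripPrefix("  lnbc20m1exampleinvoice ")
	// Both s1 and s2 are "lnbc20m1exampleinvoice".
	_, _ = s1, s2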

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"testing" "testing"
@ -111,7 +111,7 @@ func TestStripPrefix(t *testing.T) {
t.Parallel() t.Parallel()
for _, test := range stripPrefixTests { for _, test := range stripPrefixTests {
actual := stripPrefix(test.in) actual := StripPrefix(test.in)
require.Equal(t, test.expected, actual) require.Equal(t, test.expected, actual)
} }
} }

View file

@ -1,7 +1,7 @@
//go:build autopilotrpc //go:build autopilotrpc
// +build autopilotrpc // +build autopilotrpc
package main package commands
import ( import (
"github.com/lightningnetwork/lnd/lnrpc/autopilotrpc" "github.com/lightningnetwork/lnd/lnrpc/autopilotrpc"

View file

@ -1,7 +1,7 @@
//go:build !autopilotrpc //go:build !autopilotrpc
// +build !autopilotrpc // +build !autopilotrpc
package main package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,7 +1,7 @@
//go:build chainrpc //go:build chainrpc
// +build chainrpc // +build chainrpc
package main package commands
import ( import (
"bytes" "bytes"

View file

@ -1,7 +1,7 @@
//go:build !chainrpc //go:build !chainrpc
// +build !chainrpc // +build !chainrpc
package main package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"encoding/hex" "encoding/hex"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"bytes" "bytes"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"encoding/hex" "encoding/hex"
@ -9,7 +9,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
var addInvoiceCommand = cli.Command{ var AddInvoiceCommand = cli.Command{
Name: "addinvoice", Name: "addinvoice",
Category: "Invoices", Category: "Invoices",
Usage: "Add a new invoice.", Usage: "Add a new invoice.",
@ -408,7 +408,7 @@ func decodePayReq(ctx *cli.Context) error {
} }
resp, err := client.DecodePayReq(ctxc, &lnrpc.PayReqString{ resp, err := client.DecodePayReq(ctxc, &lnrpc.PayReqString{
PayReq: stripPrefix(payreq), PayReq: StripPrefix(payreq),
}) })
if err != nil { if err != nil {
return err return err

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"bytes" "bytes"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"fmt" "fmt"
@ -265,6 +265,7 @@ func setCfg(ctx *cli.Context) error {
Config: mcCfg.Config, Config: mcCfg.Config,
}, },
) )
return err return err
} }
@ -366,5 +367,6 @@ func resetMissionControl(ctx *cli.Context) error {
req := &routerrpc.ResetMissionControlRequest{} req := &routerrpc.ResetMissionControlRequest{}
_, err := client.ResetMissionControl(ctxc, req) _, err := client.ResetMissionControl(ctxc, req)
return err return err
} }

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"bytes" "bytes"

View file

@ -1,4 +1,4 @@
package main package commands
import ( import (
"bytes" "bytes"
@ -25,6 +25,7 @@ import (
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/urfave/cli" "github.com/urfave/cli"
"google.golang.org/grpc"
) )
const ( const (
@ -152,8 +153,8 @@ var (
} }
) )
// paymentFlags returns common flags for sendpayment and payinvoice. // PaymentFlags returns common flags for sendpayment and payinvoice.
func paymentFlags() []cli.Flag { func PaymentFlags() []cli.Flag {
return []cli.Flag{ return []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "pay_req", Name: "pay_req",
@ -202,7 +203,7 @@ func paymentFlags() []cli.Flag {
} }
} }
var sendPaymentCommand = cli.Command{ var SendPaymentCommand = cli.Command{
Name: "sendpayment", Name: "sendpayment",
Category: "Payments", Category: "Payments",
Usage: "Send a payment over lightning.", Usage: "Send a payment over lightning.",
@ -226,7 +227,7 @@ var sendPaymentCommand = cli.Command{
`, `,
ArgsUsage: "dest amt payment_hash final_cltv_delta pay_addr | " + ArgsUsage: "dest amt payment_hash final_cltv_delta pay_addr | " +
"--pay_req=R [--pay_addr=H]", "--pay_req=R [--pay_addr=H]",
Flags: append(paymentFlags(), Flags: append(PaymentFlags(),
cli.StringFlag{ cli.StringFlag{
Name: "dest, d", Name: "dest, d",
Usage: "the compressed identity pubkey of the " + Usage: "the compressed identity pubkey of the " +
@ -253,7 +254,7 @@ var sendPaymentCommand = cli.Command{
Usage: "will generate a pre-image and encode it in the sphinx packet, a dest must be set [experimental]", Usage: "will generate a pre-image and encode it in the sphinx packet, a dest must be set [experimental]",
}, },
), ),
Action: sendPayment, Action: SendPayment,
} }
// retrieveFeeLimit retrieves the fee limit based on the different fee limit // retrieveFeeLimit retrieves the fee limit based on the different fee limit
@ -324,20 +325,23 @@ func parsePayAddr(ctx *cli.Context, args cli.Args) ([]byte, error) {
return payAddr, nil return payAddr, nil
} }
func sendPayment(ctx *cli.Context) error { func SendPayment(ctx *cli.Context) error {
// Show command help if no arguments provided // Show command help if no arguments provided
if ctx.NArg() == 0 && ctx.NumFlags() == 0 { if ctx.NArg() == 0 && ctx.NumFlags() == 0 {
_ = cli.ShowCommandHelp(ctx, "sendpayment") _ = cli.ShowCommandHelp(ctx, "sendpayment")
return nil return nil
} }
conn := getClientConn(ctx, false)
defer conn.Close()
args := ctx.Args() args := ctx.Args()
// If a payment request was provided, we can exit early since all of the // If a payment request was provided, we can exit early since all of the
// details of the payment are encoded within the request. // details of the payment are encoded within the request.
if ctx.IsSet("pay_req") { if ctx.IsSet("pay_req") {
req := &routerrpc.SendPaymentRequest{ req := &routerrpc.SendPaymentRequest{
PaymentRequest: stripPrefix(ctx.String("pay_req")), PaymentRequest: StripPrefix(ctx.String("pay_req")),
Amt: ctx.Int64("amt"), Amt: ctx.Int64("amt"),
DestCustomRecords: make(map[uint64][]byte), DestCustomRecords: make(map[uint64][]byte),
Amp: ctx.Bool(ampFlag.Name), Amp: ctx.Bool(ampFlag.Name),
@ -357,7 +361,9 @@ func sendPayment(ctx *cli.Context) error {
req.PaymentAddr = payAddr req.PaymentAddr = payAddr
return sendPaymentRequest(ctx, req) return SendPaymentRequest(
ctx, req, conn, conn, routerRPCSendPayment,
)
} }
var ( var (
@ -466,19 +472,29 @@ func sendPayment(ctx *cli.Context) error {
req.PaymentAddr = payAddr req.PaymentAddr = payAddr
return sendPaymentRequest(ctx, req) return SendPaymentRequest(ctx, req, conn, conn, routerRPCSendPayment)
} }
func sendPaymentRequest(ctx *cli.Context, // SendPaymentFn is a function type that abstracts the SendPaymentV2 call of the
req *routerrpc.SendPaymentRequest) error { // router client.
type SendPaymentFn func(ctx context.Context, payConn grpc.ClientConnInterface,
req *routerrpc.SendPaymentRequest) (PaymentResultStream, error)
// routerRPCSendPayment is the default implementation of the SendPaymentFn type
// that uses the lnd routerrpc.SendPaymentV2 call.
func routerRPCSendPayment(ctx context.Context, payConn grpc.ClientConnInterface,
req *routerrpc.SendPaymentRequest) (PaymentResultStream, error) {
return routerrpc.NewRouterClient(payConn).SendPaymentV2(ctx, req)
}
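Because the sender is injected, external tooling can wrap or replace it; a hedged sketch of an alternative SendPaymentFn (hypothetical helper, not part of this change) that logs before delegating to the stock RPC:

	// customSendPayment prints the requested amount and then falls through
	// to the regular SendPaymentV2 call.
	func customSendPayment(ctx context.Context, payConn grpc.ClientConnInterface,
		req *routerrpc.SendPaymentRequest) (PaymentResultStream, error) {

		fmt.Printf("dispatching payment for %d sat\n", req.Amt)
		return routerrpc.NewRouterClient(payConn).SendPaymentV2(ctx, req)
	}

	// It can then be passed wherever routerRPCSendPayment is used, e.g.:
	//   SendPaymentRequest(ctx, req, conn, conn, customSendPayment)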
func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest,
lnConn, paymentConn grpc.ClientConnInterface,
callSendPayment SendPaymentFn) error {
ctxc := getContext() ctxc := getContext()
conn := getClientConn(ctx, false) lnClient := lnrpc.NewLightningClient(lnConn)
defer conn.Close()
client := lnrpc.NewLightningClient(conn)
routerClient := routerrpc.NewRouterClient(conn)
outChan := ctx.Int64Slice("outgoing_chan_id") outChan := ctx.Int64Slice("outgoing_chan_id")
if len(outChan) != 0 { if len(outChan) != 0 {
@ -558,7 +574,7 @@ func sendPaymentRequest(ctx *cli.Context,
if req.PaymentRequest != "" { if req.PaymentRequest != "" {
// Decode payment request to find out the amount. // Decode payment request to find out the amount.
decodeReq := &lnrpc.PayReqString{PayReq: req.PaymentRequest} decodeReq := &lnrpc.PayReqString{PayReq: req.PaymentRequest}
decodeResp, err := client.DecodePayReq(ctxc, decodeReq) decodeResp, err := lnClient.DecodePayReq(ctxc, decodeReq)
if err != nil { if err != nil {
return err return err
} }
@ -602,14 +618,12 @@ func sendPaymentRequest(ctx *cli.Context,
printJSON := ctx.Bool(jsonFlag.Name) printJSON := ctx.Bool(jsonFlag.Name)
req.NoInflightUpdates = !ctx.Bool(inflightUpdatesFlag.Name) && printJSON req.NoInflightUpdates = !ctx.Bool(inflightUpdatesFlag.Name) && printJSON
stream, err := routerClient.SendPaymentV2(ctxc, req) stream, err := callSendPayment(ctxc, paymentConn, req)
if err != nil { if err != nil {
return err return err
} }
finalState, err := printLivePayment( finalState, err := PrintLivePayment(ctxc, stream, lnClient, printJSON)
ctxc, stream, client, printJSON,
)
if err != nil { if err != nil {
return err return err
} }
@ -667,24 +681,29 @@ func trackPayment(ctx *cli.Context) error {
} }
client := lnrpc.NewLightningClient(conn) client := lnrpc.NewLightningClient(conn)
_, err = printLivePayment(ctxc, stream, client, ctx.Bool(jsonFlag.Name)) _, err = PrintLivePayment(ctxc, stream, client, ctx.Bool(jsonFlag.Name))
return err return err
} }
// printLivePayment receives payment updates from the given stream and either // PaymentResultStream is an interface that abstracts the Recv method of the
// SendPaymentV2 or TrackPaymentV2 client stream.
type PaymentResultStream interface {
Recv() (*lnrpc.Payment, error)
}
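Both generated router streams satisfy this interface, which is what lets SendPaymentRequest and trackPayment share PrintLivePayment. A compile-time sketch (assuming the generated stream types expose Recv as above):

	var (
		_ PaymentResultStream = routerrpc.Router_SendPaymentV2Client(nil)
		_ PaymentResultStream = routerrpc.Router_TrackPaymentV2Client(nil)
	)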
// PrintLivePayment receives payment updates from the given stream and either
// outputs them as json or as a more user-friendly formatted table. The table // outputs them as json or as a more user-friendly formatted table. The table
// option uses terminal control codes to rewrite the output. This call // option uses terminal control codes to rewrite the output. This call
// terminates when the payment reaches a final state. // terminates when the payment reaches a final state.
func printLivePayment(ctxc context.Context, func PrintLivePayment(ctxc context.Context, stream PaymentResultStream,
stream routerrpc.Router_TrackPaymentV2Client, lnClient lnrpc.LightningClient, json bool) (*lnrpc.Payment, error) {
client lnrpc.LightningClient, json bool) (*lnrpc.Payment, error) {
// Terminal escape codes aren't supported on Windows, fall back to json. // Terminal escape codes aren't supported on Windows, fall back to json.
if !json && runtime.GOOS == "windows" { if !json && runtime.GOOS == "windows" {
json = true json = true
} }
aliases := newAliasCache(client) aliases := newAliasCache(lnClient)
first := true first := true
var lastLineCount int var lastLineCount int
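Since both generated stream clients already expose Recv() (*lnrpc.Payment, error), they satisfy the new PaymentResultStream interface without any adapter. A compile-time assertion (illustrative only, not part of the diff) makes that explicit:

```go
var (
	_ PaymentResultStream = routerrpc.Router_SendPaymentV2Client(nil)
	_ PaymentResultStream = routerrpc.Router_TrackPaymentV2Client(nil)
)
```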
@ -706,17 +725,17 @@ func printLivePayment(ctxc context.Context,
 			// Write raw json to stdout.
 			printRespJSON(payment)
 		} else {
-			table := formatPayment(ctxc, payment, aliases)
+			resultTable := formatPayment(ctxc, payment, aliases)
 
 			// Clear all previously written lines and print the
 			// updated table.
 			clearLines(lastLineCount)
-			fmt.Print(table)
+			fmt.Print(resultTable)
 
 			// Store the number of lines written for the next update
 			// pass.
 			lastLineCount = 0
-			for _, b := range table {
+			for _, b := range resultTable {
 				if b == '\n' {
 					lastLineCount++
 				}
@ -874,7 +893,7 @@ var payInvoiceCommand = cli.Command{
 	This command is a shortcut for 'sendpayment --pay_req='.
 	`,
 	ArgsUsage: "pay_req",
-	Flags: append(paymentFlags(),
+	Flags: append(PaymentFlags(),
 		cli.Int64Flag{
 			Name: "amt",
 			Usage: "(optional) number of satoshis to fulfill the " +
@ -885,6 +904,9 @@ var payInvoiceCommand = cli.Command{
 }
 
 func payInvoice(ctx *cli.Context) error {
+	conn := getClientConn(ctx, false)
+	defer conn.Close()
+
 	args := ctx.Args()
 	var payReq string
@ -898,14 +920,14 @@ func payInvoice(ctx *cli.Context) error {
 	}
 
 	req := &routerrpc.SendPaymentRequest{
-		PaymentRequest:    stripPrefix(payReq),
+		PaymentRequest:    StripPrefix(payReq),
 		Amt:               ctx.Int64("amt"),
 		DestCustomRecords: make(map[uint64][]byte),
 		Amp:               ctx.Bool(ampFlag.Name),
 		Cancelable:        ctx.Bool(cancelableFlag.Name),
 	}
 
-	return sendPaymentRequest(ctx, req)
+	return SendPaymentRequest(ctx, req, conn, conn, routerRPCSendPayment)
 }
 var sendToRouteCommand = cli.Command{
@ -1900,7 +1922,7 @@ func estimateRouteFee(ctx *cli.Context) error {
 		req.AmtSat = amtSat
 
 	case ctx.IsSet("pay_req"):
-		req.PaymentRequest = stripPrefix(ctx.String("pay_req"))
+		req.PaymentRequest = StripPrefix(ctx.String("pay_req"))
 		if ctx.IsSet("timeout") {
 			req.Timeout = uint32(ctx.Duration("timeout").Seconds())
 		}

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"fmt" "fmt"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"errors" "errors"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"fmt" "fmt"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"bufio" "bufio"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"bufio" "bufio"
@ -11,6 +11,7 @@ import (
 	"io"
 	"math"
 	"os"
+	"regexp"
 	"strconv"
 	"strings"
 	"sync"
@ -41,8 +42,49 @@ const (
 	defaultUtxoMinConf = 1
 )
 
-var errBadChanPoint = errors.New("expecting chan_point to be in format of: " +
-	"txid:index")
+var (
+	errBadChanPoint = errors.New(
+		"expecting chan_point to be in format of: txid:index",
+	)
+
+	customDataPattern = regexp.MustCompile(
+		`"custom_channel_data":\s*"([0-9a-f]+)"`,
+	)
+)
// replaceCustomData replaces the custom channel data hex string with the
// decoded custom channel data in the JSON response.
func replaceCustomData(jsonBytes []byte) []byte {
// If there's nothing to replace, return the original JSON.
if !customDataPattern.Match(jsonBytes) {
return jsonBytes
}
replacedBytes := customDataPattern.ReplaceAllFunc(
jsonBytes, func(match []byte) []byte {
encoded := customDataPattern.FindStringSubmatch(
string(match),
)[1]
decoded, err := hex.DecodeString(encoded)
if err != nil {
return match
}
return []byte("\"custom_channel_data\":" +
string(decoded))
},
)
var buf bytes.Buffer
err := json.Indent(&buf, replacedBytes, "", " ")
if err != nil {
// If we can't indent the JSON, it likely means the replacement
// data wasn't correct, so we return the original JSON.
return jsonBytes
}
return buf.Bytes()
}
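For illustration, here is roughly what the helper does to a response that carries a hex-encoded custom_channel_data field. This is a sketch only, assuming it sits in the same package as replaceCustomData; the exact indentation comes from json.Indent:

```go
func exampleReplaceCustomData() {
	inner := []byte(`{"asset_id":"abc"}`)
	raw := []byte(`{"custom_channel_data":"` +
		hex.EncodeToString(inner) + `"}`)

	// The hex blob is decoded in place and the whole document re-indented,
	// roughly:
	//
	//   "custom_channel_data": {
	//     "asset_id": "abc"
	//   }
	fmt.Println(string(replaceCustomData(raw)))
}
```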
func getContext() context.Context { func getContext() context.Context {
shutdownInterceptor, err := signal.Intercept() shutdownInterceptor, err := signal.Intercept()
@ -66,9 +108,9 @@ func printJSON(resp interface{}) {
 	}
 
 	var out bytes.Buffer
-	json.Indent(&out, b, "", "\t")
-	out.WriteString("\n")
-	out.WriteTo(os.Stdout)
+	_ = json.Indent(&out, b, "", " ")
+	_, _ = out.WriteString("\n")
+	_, _ = out.WriteTo(os.Stdout)
 }
 
 func printRespJSON(resp proto.Message) {
@ -78,7 +120,9 @@ func printRespJSON(resp proto.Message) {
 		return
 	}
 
-	fmt.Printf("%s\n", jsonBytes)
+	jsonBytesReplaced := replaceCustomData(jsonBytes)
+
+	fmt.Printf("%s\n", jsonBytesReplaced)
 }
// actionDecorator is used to add additional information and error handling // actionDecorator is used to add additional information and error handling
@ -1442,15 +1486,15 @@ func walletBalance(ctx *cli.Context) error {
 	return nil
 }
 
-var channelBalanceCommand = cli.Command{
+var ChannelBalanceCommand = cli.Command{
 	Name:     "channelbalance",
 	Category: "Channels",
 	Usage: "Returns the sum of the total available channel balance across " +
 		"all open channels.",
-	Action: actionDecorator(channelBalance),
+	Action: actionDecorator(ChannelBalance),
 }
 
-func channelBalance(ctx *cli.Context) error {
+func ChannelBalance(ctx *cli.Context) error {
 	ctxc := getContext()
 	client, cleanUp := getClient(ctx)
 	defer cleanUp()
@ -1575,7 +1619,7 @@ func pendingChannels(ctx *cli.Context) error {
 	return nil
 }
 
-var listChannelsCommand = cli.Command{
+var ListChannelsCommand = cli.Command{
 	Name:     "listchannels",
 	Category: "Channels",
 	Usage:    "List all open channels.",
@ -1608,7 +1652,7 @@ var listChannelsCommand = cli.Command{
"order to improve performance", "order to improve performance",
}, },
}, },
Action: actionDecorator(listChannels), Action: actionDecorator(ListChannels),
} }
 var listAliasesCommand = cli.Command{
@ -1616,10 +1660,10 @@ var listAliasesCommand = cli.Command{
 	Category: "Channels",
 	Usage:    "List all aliases.",
 	Flags:    []cli.Flag{},
-	Action:   actionDecorator(listaliases),
+	Action:   actionDecorator(listAliases),
 }
 
-func listaliases(ctx *cli.Context) error {
+func listAliases(ctx *cli.Context) error {
 	ctxc := getContext()
 	client, cleanUp := getClient(ctx)
 	defer cleanUp()
@ -1636,7 +1680,7 @@ func listaliases(ctx *cli.Context) error {
 	return nil
 }
 
-func listChannels(ctx *cli.Context) error {
+func ListChannels(ctx *cli.Context) error {
 	ctxc := getContext()
 	client, cleanUp := getClient(ctx)
 	defer cleanUp()

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"encoding/hex" "encoding/hex"
@ -120,3 +120,74 @@ func TestParseTimeLockDelta(t *testing.T) {
} }
} }
} }
// TestReplaceCustomData tests that hex encoded custom data can be formatted as
// JSON in the console output.
func TestReplaceCustomData(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
data string
replaceData string
expected string
}{
{
name: "no replacement necessary",
data: "foo",
expected: "foo",
},
{
name: "valid json with replacement",
data: "{\"foo\":\"bar\",\"custom_channel_data\":\"" +
hex.EncodeToString([]byte(
"{\"bar\":\"baz\"}",
)) + "\"}",
expected: `{
"foo": "bar",
"custom_channel_data": {
"bar": "baz"
}
}`,
},
{
name: "valid json with replacement and space",
data: "{\"foo\":\"bar\",\"custom_channel_data\": \"" +
hex.EncodeToString([]byte(
"{\"bar\":\"baz\"}",
)) + "\"}",
expected: `{
"foo": "bar",
"custom_channel_data": {
"bar": "baz"
}
}`,
},
{
name: "doesn't match pattern, returned identical",
data: "this ain't even json, and no custom data " +
"either",
expected: "this ain't even json, and no custom data " +
"either",
},
{
name: "invalid json",
data: "this ain't json, " +
"\"custom_channel_data\":\"a\"",
expected: "this ain't json, " +
"\"custom_channel_data\":\"a\"",
},
{
name: "valid json, invalid hex, just formatted",
data: "{\"custom_channel_data\":\"f\"}",
expected: "{\n \"custom_channel_data\": \"f\"\n}",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := replaceCustomData([]byte(tc.data))
require.Equal(t, tc.expected, string(result))
})
}
}

View file

@ -1,7 +1,7 @@
//go:build dev //go:build dev
// +build dev // +build dev
-package main
+package commands
import ( import (
"fmt" "fmt"

View file

@ -1,7 +1,7 @@
//go:build !dev //go:build !dev
// +build !dev // +build !dev
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,7 +1,7 @@
//go:build invoicesrpc //go:build invoicesrpc
// +build invoicesrpc // +build invoicesrpc
-package main
+package commands
import ( import (
"encoding/hex" "encoding/hex"

View file

@ -1,7 +1,7 @@
//go:build !invoicesrpc //go:build !invoicesrpc
// +build !invoicesrpc // +build !invoicesrpc
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"encoding/base64" "encoding/base64"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"encoding/hex" "encoding/hex"

cmd/commands/main.go (new file, 601 lines)
View file

@ -0,0 +1,601 @@
// Copyright (c) 2013-2017 The btcsuite developers
// Copyright (c) 2015-2016 The Decred developers
// Copyright (C) 2015-2024 The Lightning Network Developers
package commands
import (
"context"
"crypto/tls"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/tor"
"github.com/urfave/cli"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
)
const (
defaultDataDir = "data"
defaultChainSubDir = "chain"
defaultTLSCertFilename = "tls.cert"
defaultMacaroonFilename = "admin.macaroon"
defaultRPCPort = "10009"
defaultRPCHostPort = "localhost:" + defaultRPCPort
envVarRPCServer = "LNCLI_RPCSERVER"
envVarLNDDir = "LNCLI_LNDDIR"
envVarSOCKSProxy = "LNCLI_SOCKSPROXY"
envVarTLSCertPath = "LNCLI_TLSCERTPATH"
envVarChain = "LNCLI_CHAIN"
envVarNetwork = "LNCLI_NETWORK"
envVarMacaroonPath = "LNCLI_MACAROONPATH"
envVarMacaroonTimeout = "LNCLI_MACAROONTIMEOUT"
envVarMacaroonIP = "LNCLI_MACAROONIP"
envVarProfile = "LNCLI_PROFILE"
envVarMacFromJar = "LNCLI_MACFROMJAR"
)
var (
DefaultLndDir = btcutil.AppDataDir("lnd", false)
defaultTLSCertPath = filepath.Join(
DefaultLndDir, defaultTLSCertFilename,
)
// maxMsgRecvSize is the largest message our client will receive. We
// set this to 200MiB atm.
maxMsgRecvSize = grpc.MaxCallRecvMsgSize(lnrpc.MaxGrpcMsgSize)
)
func fatal(err error) {
fmt.Fprintf(os.Stderr, "[lncli] %v\n", err)
os.Exit(1)
}
func getWalletUnlockerClient(ctx *cli.Context) (lnrpc.WalletUnlockerClient,
func()) {
conn := getClientConn(ctx, true)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewWalletUnlockerClient(conn), cleanUp
}
func getStateServiceClient(ctx *cli.Context) (lnrpc.StateClient, func()) {
conn := getClientConn(ctx, true)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewStateClient(conn), cleanUp
}
func getClient(ctx *cli.Context) (lnrpc.LightningClient, func()) {
conn := getClientConn(ctx, false)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewLightningClient(conn), cleanUp
}
func getClientConn(ctx *cli.Context, skipMacaroons bool) *grpc.ClientConn {
// First, we'll get the selected stored profile or an ephemeral one
// created from the global options in the CLI context.
profile, err := getGlobalOptions(ctx, skipMacaroons)
if err != nil {
fatal(fmt.Errorf("could not load global options: %w", err))
}
// Create a dial options array.
opts := []grpc.DialOption{
grpc.WithUnaryInterceptor(
addMetadataUnaryInterceptor(profile.Metadata),
),
grpc.WithStreamInterceptor(
addMetaDataStreamInterceptor(profile.Metadata),
),
}
if profile.Insecure {
opts = append(opts, grpc.WithInsecure())
} else {
// Load the specified TLS certificate.
certPool, err := profile.cert()
if err != nil {
fatal(fmt.Errorf("could not create cert pool: %w", err))
}
// Build transport credentials from the certificate pool. If
// there is no certificate pool, we expect the server to use a
// non-self-signed certificate such as a certificate obtained
// from Let's Encrypt.
var creds credentials.TransportCredentials
if certPool != nil {
creds = credentials.NewClientTLSFromCert(certPool, "")
} else {
// Fallback to the system pool. Using an empty tls
// config is an alternative to x509.SystemCertPool().
// That call is not supported on Windows.
creds = credentials.NewTLS(&tls.Config{})
}
opts = append(opts, grpc.WithTransportCredentials(creds))
}
// Only process macaroon credentials if --no-macaroons isn't set and
// if we're not skipping macaroon processing.
if !profile.NoMacaroons && !skipMacaroons {
// Find out which macaroon to load.
macName := profile.Macaroons.Default
if ctx.GlobalIsSet("macfromjar") {
macName = ctx.GlobalString("macfromjar")
}
var macEntry *macaroonEntry
for _, entry := range profile.Macaroons.Jar {
if entry.Name == macName {
macEntry = entry
break
}
}
if macEntry == nil {
fatal(fmt.Errorf("macaroon with name '%s' not found "+
"in profile", macName))
}
// Get and possibly decrypt the specified macaroon.
//
// TODO(guggero): Make it possible to cache the password so we
// don't need to ask for it every time.
mac, err := macEntry.loadMacaroon(readPassword)
if err != nil {
fatal(fmt.Errorf("could not load macaroon: %w", err))
}
macConstraints := []macaroons.Constraint{
// We add a time-based constraint to prevent replay of
// the macaroon. It's good for 60 seconds by default to
// make up for any discrepancy between client and server
// clocks, but leaking the macaroon before it becomes
// invalid makes it possible for an attacker to reuse
// the macaroon. In addition, the validity time of the
// macaroon is extended by the time the server clock is
// behind the client clock, or shortened by the time the
// server clock is ahead of the client clock (or invalid
// altogether if, in the latter case, this time is more
// than 60 seconds).
// TODO(aakselrod): add better anti-replay protection.
macaroons.TimeoutConstraint(profile.Macaroons.Timeout),
// Lock macaroon down to a specific IP address.
macaroons.IPLockConstraint(profile.Macaroons.IP),
// ... Add more constraints if needed.
}
// Apply constraints to the macaroon.
constrainedMac, err := macaroons.AddConstraints(
mac, macConstraints...,
)
if err != nil {
fatal(err)
}
// Now we append the macaroon credentials to the dial options.
cred, err := macaroons.NewMacaroonCredential(constrainedMac)
if err != nil {
fatal(fmt.Errorf("error cloning mac: %w", err))
}
opts = append(opts, grpc.WithPerRPCCredentials(cred))
}
// If a socksproxy server is specified we use a tor dialer
// to connect to the grpc server.
if ctx.GlobalIsSet("socksproxy") {
socksProxy := ctx.GlobalString("socksproxy")
torDialer := func(_ context.Context, addr string) (net.Conn,
error) {
return tor.Dial(
addr, socksProxy, false, false,
tor.DefaultConnTimeout,
)
}
opts = append(opts, grpc.WithContextDialer(torDialer))
} else {
// We need to use a custom dialer so we can also connect to
// unix sockets and not just TCP addresses.
genericDialer := lncfg.ClientAddressDialer(defaultRPCPort)
opts = append(opts, grpc.WithContextDialer(genericDialer))
}
opts = append(opts, grpc.WithDefaultCallOptions(maxMsgRecvSize))
conn, err := grpc.Dial(profile.RPCServer, opts...)
if err != nil {
fatal(fmt.Errorf("unable to connect to RPC server: %w", err))
}
return conn
}
// addMetadataUnaryInterceptor returns a grpc client side interceptor that
// appends any key-value metadata strings to the outgoing context of a grpc
// unary call.
func addMetadataUnaryInterceptor(
md map[string]string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
outCtx := contextWithMetadata(ctx, md)
return invoker(outCtx, method, req, reply, cc, opts...)
}
}
// addMetaDataStreamInterceptor returns a grpc client side interceptor that
// appends any key-value metadata strings to the outgoing context of a grpc
// stream call.
func addMetaDataStreamInterceptor(
md map[string]string) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc,
cc *grpc.ClientConn, method string, streamer grpc.Streamer,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
outCtx := contextWithMetadata(ctx, md)
return streamer(outCtx, desc, cc, method, opts...)
}
}
// contextWithMetaData appends the given metadata key-value pairs to the given
// context.
func contextWithMetadata(ctx context.Context,
md map[string]string) context.Context {
kvPairs := make([]string, 0, 2*len(md))
for k, v := range md {
kvPairs = append(kvPairs, k, v)
}
return metadata.AppendToOutgoingContext(ctx, kvPairs...)
}
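The two interceptors above are what make the --metadata flag (and a profile's metadata map) work: every key/value pair is attached to the outgoing context of each unary and stream call. A small sketch of the effect; the header name is illustrative:

```go
md := map[string]string{"x-custom-info": "demo"}
outCtx := contextWithMetadata(context.Background(), md)

// outCtx now carries the pair as outgoing gRPC metadata, equivalent to
// metadata.AppendToOutgoingContext(ctx, "x-custom-info", "demo"), so the
// server side of the connection sees it as a request header.
_ = outCtx
```

From the command line this corresponds to something like `lncli --metadata "x-custom-info:demo" getinfo`, matching the documented "key:value" format of the flag.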
// extractPathArgs parses the TLS certificate and macaroon paths from the
// command.
func extractPathArgs(ctx *cli.Context) (string, string, error) {
network := strings.ToLower(ctx.GlobalString("network"))
switch network {
case "mainnet", "testnet", "regtest", "simnet", "signet":
default:
return "", "", fmt.Errorf("unknown network: %v", network)
}
// We'll now fetch the lnddir so we can make a decision on how to
// properly read the macaroons (if needed) and also the cert. This will
// either be the default, or will have been overwritten by the end
// user.
lndDir := lncfg.CleanAndExpandPath(ctx.GlobalString("lnddir"))
// If the macaroon path as been manually provided, then we'll only
// target the specified file.
var macPath string
if ctx.GlobalString("macaroonpath") != "" {
macPath = lncfg.CleanAndExpandPath(ctx.GlobalString(
"macaroonpath",
))
} else {
// Otherwise, we'll go into the path:
// lnddir/data/chain/<chain>/<network> in order to fetch the
// macaroon that we need.
macPath = filepath.Join(
lndDir, defaultDataDir, defaultChainSubDir,
lnd.BitcoinChainName, network, defaultMacaroonFilename,
)
}
tlsCertPath := lncfg.CleanAndExpandPath(ctx.GlobalString("tlscertpath"))
// If a custom lnd directory was set, we'll also check if custom paths
// for the TLS cert and macaroon file were set as well. If not, we'll
// override their paths so they can be found within the custom lnd
// directory set. This allows us to set a custom lnd directory, along
// with custom paths to the TLS cert and macaroon file.
if lndDir != DefaultLndDir {
tlsCertPath = filepath.Join(lndDir, defaultTLSCertFilename)
}
return tlsCertPath, macPath, nil
}
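In other words, when --macaroonpath is not given, the macaroon is looked up underneath the (possibly overridden) lnd directory. A sketch of the default that falls out of the constants above, with values shown for bitcoin/mainnet; the resulting path is an example, not something this diff prints:

```go
macPath := filepath.Join(
	DefaultLndDir, defaultDataDir, defaultChainSubDir,
	"bitcoin", "mainnet", defaultMacaroonFilename,
)
// On Linux this typically resolves to:
//   ~/.lnd/data/chain/bitcoin/mainnet/admin.macaroon
_ = macPath
```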
// checkNotBothSet accepts two flag names, a and b, and checks that only flag a
// or flag b can be set, but not both. It returns the name of the flag or an
// error.
func checkNotBothSet(ctx *cli.Context, a, b string) (string, error) {
if ctx.IsSet(a) && ctx.IsSet(b) {
return "", fmt.Errorf(
"either %s or %s should be set, but not both", a, b,
)
}
if ctx.IsSet(a) {
return a, nil
}
return b, nil
}
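A short illustration of how checkNotBothSet is meant to be used for two mutually exclusive flags; the flag names below are only an example:

```go
which, err := checkNotBothSet(ctx, "dest", "pay_req")
if err != nil {
	return err
}

switch which {
case "dest":
	// Build the request from an explicit destination.
case "pay_req":
	// Build the request from a payment request string.
}
```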
func Main() {
app := cli.NewApp()
app.Name = "lncli"
app.Version = build.Version() + " commit=" + build.Commit
app.Usage = "control plane for your Lightning Network Daemon (lnd)"
app.Flags = []cli.Flag{
cli.StringFlag{
Name: "rpcserver",
Value: defaultRPCHostPort,
Usage: "The host:port of LN daemon.",
EnvVar: envVarRPCServer,
},
cli.StringFlag{
Name: "lnddir",
Value: DefaultLndDir,
Usage: "The path to lnd's base directory.",
TakesFile: true,
EnvVar: envVarLNDDir,
},
cli.StringFlag{
Name: "socksproxy",
Usage: "The host:port of a SOCKS proxy through " +
"which all connections to the LN " +
"daemon will be established over.",
EnvVar: envVarSOCKSProxy,
},
cli.StringFlag{
Name: "tlscertpath",
Value: defaultTLSCertPath,
Usage: "The path to lnd's TLS certificate.",
TakesFile: true,
EnvVar: envVarTLSCertPath,
},
cli.StringFlag{
Name: "chain, c",
Usage: "The chain lnd is running on, e.g. bitcoin.",
Value: "bitcoin",
EnvVar: envVarChain,
},
cli.StringFlag{
Name: "network, n",
Usage: "The network lnd is running on, e.g. mainnet, " +
"testnet, etc.",
Value: "mainnet",
EnvVar: envVarNetwork,
},
cli.BoolFlag{
Name: "no-macaroons",
Usage: "Disable macaroon authentication.",
},
cli.StringFlag{
Name: "macaroonpath",
Usage: "The path to macaroon file.",
TakesFile: true,
EnvVar: envVarMacaroonPath,
},
cli.Int64Flag{
Name: "macaroontimeout",
Value: 60,
Usage: "Anti-replay macaroon validity time in " +
"seconds.",
EnvVar: envVarMacaroonTimeout,
},
cli.StringFlag{
Name: "macaroonip",
Usage: "If set, lock macaroon to specific IP address.",
EnvVar: envVarMacaroonIP,
},
cli.StringFlag{
Name: "profile, p",
Usage: "Instead of reading settings from command " +
"line parameters or using the default " +
"profile, use a specific profile. If " +
"a default profile is set, this flag can be " +
"set to an empty string to disable reading " +
"values from the profiles file.",
EnvVar: envVarProfile,
},
cli.StringFlag{
Name: "macfromjar",
Usage: "Use this macaroon from the profile's " +
"macaroon jar instead of the default one. " +
"Can only be used if profiles are defined.",
EnvVar: envVarMacFromJar,
},
cli.StringSliceFlag{
Name: "metadata",
Usage: "This flag can be used to specify a key-value " +
"pair that should be appended to the " +
"outgoing context before the request is sent " +
"to lnd. This flag may be specified multiple " +
"times. The format is: \"key:value\".",
},
cli.BoolFlag{
Name: "insecure",
Usage: "Connect to the rpc server without TLS " +
"authentication",
Hidden: true,
},
}
app.Commands = []cli.Command{
createCommand,
createWatchOnlyCommand,
unlockCommand,
changePasswordCommand,
newAddressCommand,
estimateFeeCommand,
sendManyCommand,
sendCoinsCommand,
listUnspentCommand,
connectCommand,
disconnectCommand,
openChannelCommand,
batchOpenChannelCommand,
closeChannelCommand,
closeAllChannelsCommand,
abandonChannelCommand,
listPeersCommand,
walletBalanceCommand,
ChannelBalanceCommand,
getInfoCommand,
getDebugInfoCommand,
encryptDebugPackageCommand,
decryptDebugPackageCommand,
getRecoveryInfoCommand,
pendingChannelsCommand,
SendPaymentCommand,
payInvoiceCommand,
sendToRouteCommand,
AddInvoiceCommand,
lookupInvoiceCommand,
listInvoicesCommand,
ListChannelsCommand,
closedChannelsCommand,
listPaymentsCommand,
describeGraphCommand,
getNodeMetricsCommand,
getChanInfoCommand,
getNodeInfoCommand,
queryRoutesCommand,
getNetworkInfoCommand,
debugLevelCommand,
decodePayReqCommand,
listChainTxnsCommand,
stopCommand,
signMessageCommand,
verifyMessageCommand,
feeReportCommand,
updateChannelPolicyCommand,
forwardingHistoryCommand,
exportChanBackupCommand,
verifyChanBackupCommand,
restoreChanBackupCommand,
bakeMacaroonCommand,
listMacaroonIDsCommand,
deleteMacaroonIDCommand,
listPermissionsCommand,
printMacaroonCommand,
constrainMacaroonCommand,
trackPaymentCommand,
versionCommand,
profileSubCommand,
getStateCommand,
deletePaymentsCommand,
sendCustomCommand,
subscribeCustomCommand,
fishCompletionCommand,
listAliasesCommand,
estimateRouteFeeCommand,
generateManPageCommand,
}
// Add any extra commands determined by build flags.
app.Commands = append(app.Commands, autopilotCommands()...)
app.Commands = append(app.Commands, invoicesCommands()...)
app.Commands = append(app.Commands, neutrinoCommands()...)
app.Commands = append(app.Commands, routerCommands()...)
app.Commands = append(app.Commands, walletCommands()...)
app.Commands = append(app.Commands, watchtowerCommands()...)
app.Commands = append(app.Commands, wtclientCommands()...)
app.Commands = append(app.Commands, devCommands()...)
app.Commands = append(app.Commands, peersCommands()...)
app.Commands = append(app.Commands, chainCommands()...)
if err := app.Run(os.Args); err != nil {
fatal(err)
}
}
// readPassword reads a password from the terminal. This requires there to be an
// actual TTY so passing in a password from stdin won't work.
func readPassword(text string) ([]byte, error) {
fmt.Print(text)
// The variable syscall.Stdin is of a different type in the Windows API
// that's why we need the explicit cast. And of course the linter
// doesn't like it either.
pw, err := term.ReadPassword(int(syscall.Stdin)) //nolint:unconvert
fmt.Println()
return pw, err
}
// networkParams parses the global network flag into a chaincfg.Params.
func networkParams(ctx *cli.Context) (*chaincfg.Params, error) {
network := strings.ToLower(ctx.GlobalString("network"))
switch network {
case "mainnet":
return &chaincfg.MainNetParams, nil
case "testnet":
return &chaincfg.TestNet3Params, nil
case "regtest":
return &chaincfg.RegressionNetParams, nil
case "simnet":
return &chaincfg.SimNetParams, nil
case "signet":
return &chaincfg.SigNetParams, nil
default:
return nil, fmt.Errorf("unknown network: %v", network)
}
}
// parseCoinSelectionStrategy parses a coin selection strategy string
// from the CLI to its lnrpc.CoinSelectionStrategy counterpart proto type.
func parseCoinSelectionStrategy(ctx *cli.Context) (
lnrpc.CoinSelectionStrategy, error) {
strategy := ctx.String(coinSelectionStrategyFlag.Name)
if !ctx.IsSet(coinSelectionStrategyFlag.Name) {
return lnrpc.CoinSelectionStrategy_STRATEGY_USE_GLOBAL_CONFIG,
nil
}
switch strategy {
case "global-config":
return lnrpc.CoinSelectionStrategy_STRATEGY_USE_GLOBAL_CONFIG,
nil
case "largest":
return lnrpc.CoinSelectionStrategy_STRATEGY_LARGEST, nil
case "random":
return lnrpc.CoinSelectionStrategy_STRATEGY_RANDOM, nil
default:
return 0, fmt.Errorf("unknown coin selection strategy "+
"%v", strategy)
}
}
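Illustrative usage of the helper when building an on-chain send request; the SendCoinsRequest fields shown (and the addr/amt variables) are an assumption about the surrounding call site, not part of this diff:

```go
strategy, err := parseCoinSelectionStrategy(ctx)
if err != nil {
	return err
}

req := &lnrpc.SendCoinsRequest{
	Addr:                  addr,
	Amount:                amt,
	CoinSelectionStrategy: strategy,
}
```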

View file

@ -1,7 +1,7 @@
//go:build neutrinorpc //go:build neutrinorpc
// +build neutrinorpc // +build neutrinorpc
-package main
+package commands
import ( import (
"github.com/lightningnetwork/lnd/lnrpc/neutrinorpc" "github.com/lightningnetwork/lnd/lnrpc/neutrinorpc"

View file

@ -1,7 +1,7 @@
//go:build !neutrinorpc //go:build !neutrinorpc
// +build !neutrinorpc // +build !neutrinorpc
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,7 +1,7 @@
//go:build peersrpc //go:build peersrpc
// +build peersrpc // +build peersrpc
-package main
+package commands
import ( import (
"fmt" "fmt"

View file

@ -1,7 +1,7 @@
//go:build !peersrpc //go:build !peersrpc
// +build !peersrpc // +build !peersrpc
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"bytes" "bytes"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"encoding/hex" "encoding/hex"

View file

@ -1,7 +1,7 @@
//go:build walletrpc //go:build walletrpc
// +build walletrpc // +build walletrpc
-package main
+package commands
import ( import (
"bytes" "bytes"

View file

@ -1,7 +1,7 @@
//go:build !walletrpc //go:build !walletrpc
// +build !walletrpc // +build !walletrpc
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import "github.com/lightningnetwork/lnd/lnrpc/walletrpc" import "github.com/lightningnetwork/lnd/lnrpc/walletrpc"

View file

@ -1,7 +1,7 @@
//go:build watchtowerrpc //go:build watchtowerrpc
// +build watchtowerrpc // +build watchtowerrpc
-package main
+package commands
import ( import (
"github.com/lightningnetwork/lnd/lnrpc/watchtowerrpc" "github.com/lightningnetwork/lnd/lnrpc/watchtowerrpc"

View file

@ -1,7 +1,7 @@
//go:build !watchtowerrpc //go:build !watchtowerrpc
// +build !watchtowerrpc // +build !watchtowerrpc
-package main
+package commands
import "github.com/urfave/cli" import "github.com/urfave/cli"

View file

@ -1,4 +1,4 @@
-package main
+package commands
import ( import (
"encoding/hex" "encoding/hex"

View file

@ -1,594 +1,11 @@
 // Copyright (c) 2013-2017 The btcsuite developers
 // Copyright (c) 2015-2016 The Decred developers
-// Copyright (C) 2015-2022 The Lightning Network Developers
+// Copyright (C) 2015-2024 The Lightning Network Developers
 
 package main
 
-import (
+import "github.com/lightningnetwork/lnd/cmd/commands"
"context"
"crypto/tls"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/tor"
"github.com/urfave/cli"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
)
const (
defaultDataDir = "data"
defaultChainSubDir = "chain"
defaultTLSCertFilename = "tls.cert"
defaultMacaroonFilename = "admin.macaroon"
defaultRPCPort = "10009"
defaultRPCHostPort = "localhost:" + defaultRPCPort
envVarRPCServer = "LNCLI_RPCSERVER"
envVarLNDDir = "LNCLI_LNDDIR"
envVarSOCKSProxy = "LNCLI_SOCKSPROXY"
envVarTLSCertPath = "LNCLI_TLSCERTPATH"
envVarChain = "LNCLI_CHAIN"
envVarNetwork = "LNCLI_NETWORK"
envVarMacaroonPath = "LNCLI_MACAROONPATH"
envVarMacaroonTimeout = "LNCLI_MACAROONTIMEOUT"
envVarMacaroonIP = "LNCLI_MACAROONIP"
envVarProfile = "LNCLI_PROFILE"
envVarMacFromJar = "LNCLI_MACFROMJAR"
)
var (
defaultLndDir = btcutil.AppDataDir("lnd", false)
defaultTLSCertPath = filepath.Join(defaultLndDir, defaultTLSCertFilename)
// maxMsgRecvSize is the largest message our client will receive. We
// set this to 200MiB atm.
maxMsgRecvSize = grpc.MaxCallRecvMsgSize(lnrpc.MaxGrpcMsgSize)
)
func fatal(err error) {
fmt.Fprintf(os.Stderr, "[lncli] %v\n", err)
os.Exit(1)
}
func getWalletUnlockerClient(ctx *cli.Context) (lnrpc.WalletUnlockerClient, func()) {
conn := getClientConn(ctx, true)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewWalletUnlockerClient(conn), cleanUp
}
func getStateServiceClient(ctx *cli.Context) (lnrpc.StateClient, func()) {
conn := getClientConn(ctx, true)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewStateClient(conn), cleanUp
}
func getClient(ctx *cli.Context) (lnrpc.LightningClient, func()) {
conn := getClientConn(ctx, false)
cleanUp := func() {
conn.Close()
}
return lnrpc.NewLightningClient(conn), cleanUp
}
func getClientConn(ctx *cli.Context, skipMacaroons bool) *grpc.ClientConn {
// First, we'll get the selected stored profile or an ephemeral one
// created from the global options in the CLI context.
profile, err := getGlobalOptions(ctx, skipMacaroons)
if err != nil {
fatal(fmt.Errorf("could not load global options: %w", err))
}
// Create a dial options array.
opts := []grpc.DialOption{
grpc.WithUnaryInterceptor(
addMetadataUnaryInterceptor(profile.Metadata),
),
grpc.WithStreamInterceptor(
addMetaDataStreamInterceptor(profile.Metadata),
),
}
if profile.Insecure {
opts = append(opts, grpc.WithInsecure())
} else {
// Load the specified TLS certificate.
certPool, err := profile.cert()
if err != nil {
fatal(fmt.Errorf("could not create cert pool: %w", err))
}
// Build transport credentials from the certificate pool. If
// there is no certificate pool, we expect the server to use a
// non-self-signed certificate such as a certificate obtained
// from Let's Encrypt.
var creds credentials.TransportCredentials
if certPool != nil {
creds = credentials.NewClientTLSFromCert(certPool, "")
} else {
// Fallback to the system pool. Using an empty tls
// config is an alternative to x509.SystemCertPool().
// That call is not supported on Windows.
creds = credentials.NewTLS(&tls.Config{})
}
opts = append(opts, grpc.WithTransportCredentials(creds))
}
// Only process macaroon credentials if --no-macaroons isn't set and
// if we're not skipping macaroon processing.
if !profile.NoMacaroons && !skipMacaroons {
// Find out which macaroon to load.
macName := profile.Macaroons.Default
if ctx.GlobalIsSet("macfromjar") {
macName = ctx.GlobalString("macfromjar")
}
var macEntry *macaroonEntry
for _, entry := range profile.Macaroons.Jar {
if entry.Name == macName {
macEntry = entry
break
}
}
if macEntry == nil {
fatal(fmt.Errorf("macaroon with name '%s' not found "+
"in profile", macName))
}
// Get and possibly decrypt the specified macaroon.
//
// TODO(guggero): Make it possible to cache the password so we
// don't need to ask for it every time.
mac, err := macEntry.loadMacaroon(readPassword)
if err != nil {
fatal(fmt.Errorf("could not load macaroon: %w", err))
}
macConstraints := []macaroons.Constraint{
// We add a time-based constraint to prevent replay of
// the macaroon. It's good for 60 seconds by default to
// make up for any discrepancy between client and server
// clocks, but leaking the macaroon before it becomes
// invalid makes it possible for an attacker to reuse
// the macaroon. In addition, the validity time of the
// macaroon is extended by the time the server clock is
// behind the client clock, or shortened by the time the
// server clock is ahead of the client clock (or invalid
// altogether if, in the latter case, this time is more
// than 60 seconds).
// TODO(aakselrod): add better anti-replay protection.
macaroons.TimeoutConstraint(profile.Macaroons.Timeout),
// Lock macaroon down to a specific IP address.
macaroons.IPLockConstraint(profile.Macaroons.IP),
// ... Add more constraints if needed.
}
// Apply constraints to the macaroon.
constrainedMac, err := macaroons.AddConstraints(
mac, macConstraints...,
)
if err != nil {
fatal(err)
}
// Now we append the macaroon credentials to the dial options.
cred, err := macaroons.NewMacaroonCredential(constrainedMac)
if err != nil {
fatal(fmt.Errorf("error cloning mac: %w", err))
}
opts = append(opts, grpc.WithPerRPCCredentials(cred))
}
// If a socksproxy server is specified we use a tor dialer
// to connect to the grpc server.
if ctx.GlobalIsSet("socksproxy") {
socksProxy := ctx.GlobalString("socksproxy")
torDialer := func(_ context.Context, addr string) (net.Conn,
error) {
return tor.Dial(
addr, socksProxy, false, false,
tor.DefaultConnTimeout,
)
}
opts = append(opts, grpc.WithContextDialer(torDialer))
} else {
// We need to use a custom dialer so we can also connect to
// unix sockets and not just TCP addresses.
genericDialer := lncfg.ClientAddressDialer(defaultRPCPort)
opts = append(opts, grpc.WithContextDialer(genericDialer))
}
opts = append(opts, grpc.WithDefaultCallOptions(maxMsgRecvSize))
conn, err := grpc.Dial(profile.RPCServer, opts...)
if err != nil {
fatal(fmt.Errorf("unable to connect to RPC server: %w", err))
}
return conn
}
// addMetadataUnaryInterceptor returns a grpc client side interceptor that
// appends any key-value metadata strings to the outgoing context of a grpc
// unary call.
func addMetadataUnaryInterceptor(
md map[string]string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
outCtx := contextWithMetadata(ctx, md)
return invoker(outCtx, method, req, reply, cc, opts...)
}
}
// addMetaDataStreamInterceptor returns a grpc client side interceptor that
// appends any key-value metadata strings to the outgoing context of a grpc
// stream call.
func addMetaDataStreamInterceptor(
md map[string]string) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc,
cc *grpc.ClientConn, method string, streamer grpc.Streamer,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
outCtx := contextWithMetadata(ctx, md)
return streamer(outCtx, desc, cc, method, opts...)
}
}
// contextWithMetaData appends the given metadata key-value pairs to the given
// context.
func contextWithMetadata(ctx context.Context,
md map[string]string) context.Context {
kvPairs := make([]string, 0, 2*len(md))
for k, v := range md {
kvPairs = append(kvPairs, k, v)
}
return metadata.AppendToOutgoingContext(ctx, kvPairs...)
}
// extractPathArgs parses the TLS certificate and macaroon paths from the
// command.
func extractPathArgs(ctx *cli.Context) (string, string, error) {
network := strings.ToLower(ctx.GlobalString("network"))
switch network {
case "mainnet", "testnet", "regtest", "simnet", "signet":
default:
return "", "", fmt.Errorf("unknown network: %v", network)
}
// We'll now fetch the lnddir so we can make a decision on how to
// properly read the macaroons (if needed) and also the cert. This will
// either be the default, or will have been overwritten by the end
// user.
lndDir := lncfg.CleanAndExpandPath(ctx.GlobalString("lnddir"))
// If the macaroon path as been manually provided, then we'll only
// target the specified file.
var macPath string
if ctx.GlobalString("macaroonpath") != "" {
macPath = lncfg.CleanAndExpandPath(ctx.GlobalString("macaroonpath"))
} else {
// Otherwise, we'll go into the path:
// lnddir/data/chain/<chain>/<network> in order to fetch the
// macaroon that we need.
macPath = filepath.Join(
lndDir, defaultDataDir, defaultChainSubDir,
lnd.BitcoinChainName, network, defaultMacaroonFilename,
)
}
tlsCertPath := lncfg.CleanAndExpandPath(ctx.GlobalString("tlscertpath"))
// If a custom lnd directory was set, we'll also check if custom paths
// for the TLS cert and macaroon file were set as well. If not, we'll
// override their paths so they can be found within the custom lnd
// directory set. This allows us to set a custom lnd directory, along
// with custom paths to the TLS cert and macaroon file.
if lndDir != defaultLndDir {
tlsCertPath = filepath.Join(lndDir, defaultTLSCertFilename)
}
return tlsCertPath, macPath, nil
}
// checkNotBothSet accepts two flag names, a and b, and checks that only flag a
// or flag b can be set, but not both. It returns the name of the flag or an
// error.
func checkNotBothSet(ctx *cli.Context, a, b string) (string, error) {
if ctx.IsSet(a) && ctx.IsSet(b) {
return "", fmt.Errorf(
"either %s or %s should be set, but not both", a, b,
)
}
if ctx.IsSet(a) {
return a, nil
}
return b, nil
}
 func main() {
-	app := cli.NewApp()
+	commands.Main()
app.Name = "lncli"
app.Version = build.Version() + " commit=" + build.Commit
app.Usage = "control plane for your Lightning Network Daemon (lnd)"
app.Flags = []cli.Flag{
cli.StringFlag{
Name: "rpcserver",
Value: defaultRPCHostPort,
Usage: "The host:port of LN daemon.",
EnvVar: envVarRPCServer,
},
cli.StringFlag{
Name: "lnddir",
Value: defaultLndDir,
Usage: "The path to lnd's base directory.",
TakesFile: true,
EnvVar: envVarLNDDir,
},
cli.StringFlag{
Name: "socksproxy",
Usage: "The host:port of a SOCKS proxy through " +
"which all connections to the LN " +
"daemon will be established over.",
EnvVar: envVarSOCKSProxy,
},
cli.StringFlag{
Name: "tlscertpath",
Value: defaultTLSCertPath,
Usage: "The path to lnd's TLS certificate.",
TakesFile: true,
EnvVar: envVarTLSCertPath,
},
cli.StringFlag{
Name: "chain, c",
Usage: "The chain lnd is running on, e.g. bitcoin.",
Value: "bitcoin",
EnvVar: envVarChain,
},
cli.StringFlag{
Name: "network, n",
Usage: "The network lnd is running on, e.g. mainnet, " +
"testnet, etc.",
Value: "mainnet",
EnvVar: envVarNetwork,
},
cli.BoolFlag{
Name: "no-macaroons",
Usage: "Disable macaroon authentication.",
},
cli.StringFlag{
Name: "macaroonpath",
Usage: "The path to macaroon file.",
TakesFile: true,
EnvVar: envVarMacaroonPath,
},
cli.Int64Flag{
Name: "macaroontimeout",
Value: 60,
Usage: "Anti-replay macaroon validity time in " +
"seconds.",
EnvVar: envVarMacaroonTimeout,
},
cli.StringFlag{
Name: "macaroonip",
Usage: "If set, lock macaroon to specific IP address.",
EnvVar: envVarMacaroonIP,
},
cli.StringFlag{
Name: "profile, p",
Usage: "Instead of reading settings from command " +
"line parameters or using the default " +
"profile, use a specific profile. If " +
"a default profile is set, this flag can be " +
"set to an empty string to disable reading " +
"values from the profiles file.",
EnvVar: envVarProfile,
},
cli.StringFlag{
Name: "macfromjar",
Usage: "Use this macaroon from the profile's " +
"macaroon jar instead of the default one. " +
"Can only be used if profiles are defined.",
EnvVar: envVarMacFromJar,
},
cli.StringSliceFlag{
Name: "metadata",
Usage: "This flag can be used to specify a key-value " +
"pair that should be appended to the " +
"outgoing context before the request is sent " +
"to lnd. This flag may be specified multiple " +
"times. The format is: \"key:value\".",
},
cli.BoolFlag{
Name: "insecure",
Usage: "Connect to the rpc server without TLS " +
"authentication",
Hidden: true,
},
}
app.Commands = []cli.Command{
createCommand,
createWatchOnlyCommand,
unlockCommand,
changePasswordCommand,
newAddressCommand,
estimateFeeCommand,
sendManyCommand,
sendCoinsCommand,
listUnspentCommand,
connectCommand,
disconnectCommand,
openChannelCommand,
batchOpenChannelCommand,
closeChannelCommand,
closeAllChannelsCommand,
abandonChannelCommand,
listPeersCommand,
walletBalanceCommand,
channelBalanceCommand,
getInfoCommand,
getDebugInfoCommand,
encryptDebugPackageCommand,
decryptDebugPackageCommand,
getRecoveryInfoCommand,
pendingChannelsCommand,
sendPaymentCommand,
payInvoiceCommand,
sendToRouteCommand,
addInvoiceCommand,
lookupInvoiceCommand,
listInvoicesCommand,
listChannelsCommand,
closedChannelsCommand,
listPaymentsCommand,
describeGraphCommand,
getNodeMetricsCommand,
getChanInfoCommand,
getNodeInfoCommand,
queryRoutesCommand,
getNetworkInfoCommand,
debugLevelCommand,
decodePayReqCommand,
listChainTxnsCommand,
stopCommand,
signMessageCommand,
verifyMessageCommand,
feeReportCommand,
updateChannelPolicyCommand,
forwardingHistoryCommand,
exportChanBackupCommand,
verifyChanBackupCommand,
restoreChanBackupCommand,
bakeMacaroonCommand,
listMacaroonIDsCommand,
deleteMacaroonIDCommand,
listPermissionsCommand,
printMacaroonCommand,
constrainMacaroonCommand,
trackPaymentCommand,
versionCommand,
profileSubCommand,
getStateCommand,
deletePaymentsCommand,
sendCustomCommand,
subscribeCustomCommand,
fishCompletionCommand,
listAliasesCommand,
estimateRouteFeeCommand,
generateManPageCommand,
}
// Add any extra commands determined by build flags.
app.Commands = append(app.Commands, autopilotCommands()...)
app.Commands = append(app.Commands, invoicesCommands()...)
app.Commands = append(app.Commands, neutrinoCommands()...)
app.Commands = append(app.Commands, routerCommands()...)
app.Commands = append(app.Commands, walletCommands()...)
app.Commands = append(app.Commands, watchtowerCommands()...)
app.Commands = append(app.Commands, wtclientCommands()...)
app.Commands = append(app.Commands, devCommands()...)
app.Commands = append(app.Commands, peersCommands()...)
app.Commands = append(app.Commands, chainCommands()...)
if err := app.Run(os.Args); err != nil {
fatal(err)
}
}
// readPassword reads a password from the terminal. This requires there to be an
// actual TTY so passing in a password from stdin won't work.
func readPassword(text string) ([]byte, error) {
fmt.Print(text)
// The variable syscall.Stdin is of a different type in the Windows API
// that's why we need the explicit cast. And of course the linter
// doesn't like it either.
pw, err := term.ReadPassword(int(syscall.Stdin)) // nolint:unconvert
fmt.Println()
return pw, err
}
// networkParams parses the global network flag into a chaincfg.Params.
func networkParams(ctx *cli.Context) (*chaincfg.Params, error) {
network := strings.ToLower(ctx.GlobalString("network"))
switch network {
case "mainnet":
return &chaincfg.MainNetParams, nil
case "testnet":
return &chaincfg.TestNet3Params, nil
case "regtest":
return &chaincfg.RegressionNetParams, nil
case "simnet":
return &chaincfg.SimNetParams, nil
case "signet":
return &chaincfg.SigNetParams, nil
default:
return nil, fmt.Errorf("unknown network: %v", network)
}
}
// parseCoinSelectionStrategy parses a coin selection strategy string
// from the CLI to its lnrpc.CoinSelectionStrategy counterpart proto type.
func parseCoinSelectionStrategy(ctx *cli.Context) (
lnrpc.CoinSelectionStrategy, error) {
strategy := ctx.String(coinSelectionStrategyFlag.Name)
if !ctx.IsSet(coinSelectionStrategyFlag.Name) {
return lnrpc.CoinSelectionStrategy_STRATEGY_USE_GLOBAL_CONFIG,
nil
}
switch strategy {
case "global-config":
return lnrpc.CoinSelectionStrategy_STRATEGY_USE_GLOBAL_CONFIG,
nil
case "largest":
return lnrpc.CoinSelectionStrategy_STRATEGY_LARGEST, nil
case "random":
return lnrpc.CoinSelectionStrategy_STRATEGY_RANDOM, nil
default:
return 0, fmt.Errorf("unknown coin selection strategy "+
"%v", strategy)
}
} }

View file

@ -33,6 +33,8 @@ import (
"github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chainreg"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
@ -40,11 +42,15 @@ import (
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/btcwallet" "github.com/lightningnetwork/lnd/lnwallet/btcwallet"
"github.com/lightningnetwork/lnd/lnwallet/chancloser"
"github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/lnwallet/rpcwallet"
"github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/msgmux"
"github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/rpcperms"
"github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/walletunlocker"
"github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtclient"
@ -103,7 +109,7 @@ type DatabaseBuilder interface {
 type WalletConfigBuilder interface {
 	// BuildWalletConfig is responsible for creating or unlocking and then
 	// fully initializing a wallet.
-	BuildWalletConfig(context.Context, *DatabaseInstances,
+	BuildWalletConfig(context.Context, *DatabaseInstances, *AuxComponents,
 		*rpcperms.InterceptorChain,
 		[]*ListenerWithSignal) (*chainreg.PartialChainControl,
 		*btcwallet.Config, func(), error)
@ -144,6 +150,52 @@ type ImplementationCfg struct {
 	// ChainControlBuilder is a type that can provide a custom wallet
 	// implementation.
 	ChainControlBuilder ChainControlBuilder
+
+	// AuxComponents is a set of auxiliary components that can be used by
+	// lnd for certain custom channel types.
+	AuxComponents
 }
// AuxComponents is a set of auxiliary components that can be used by lnd for
// certain custom channel types.
type AuxComponents struct {
// AuxLeafStore is an optional data source that can be used by custom
// channels to fetch+store various data.
AuxLeafStore fn.Option[lnwallet.AuxLeafStore]
// TrafficShaper is an optional traffic shaper that can be used to
// control the outgoing channel of a payment.
TrafficShaper fn.Option[routing.TlvTrafficShaper]
// MsgRouter is an optional message router that if set will be used in
// place of a new blank default message router.
MsgRouter fn.Option[msgmux.Router]
// AuxFundingController is an optional controller that can be used to
// modify the way we handle certain custom channel types. It's also
// able to automatically handle new custom protocol messages related to
// the funding process.
AuxFundingController fn.Option[funding.AuxFundingController]
// AuxSigner is an optional signer that can be used to sign auxiliary
// leaves for certain custom channel types.
AuxSigner fn.Option[lnwallet.AuxSigner]
// AuxDataParser is an optional data parser that can be used to parse
// auxiliary data for certain custom channel types.
AuxDataParser fn.Option[AuxDataParser]
// AuxChanCloser is an optional channel closer that can be used to
// modify the way a coop-close transaction is constructed.
AuxChanCloser fn.Option[chancloser.AuxChanCloser]
// AuxSweeper is an optional interface that can be used to modify the
// way sweep transaction are generated.
AuxSweeper fn.Option[sweep.AuxSweeper]
// AuxContractResolver is an optional interface that can be used to
// modify the way contracts are resolved.
AuxContractResolver fn.Option[lnwallet.AuxContractResolver]
} }
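All of these hooks are fn.Option values, so a stock build can leave them at their zero value (fn.None) while a custom build fills in only what it needs. A minimal sketch, where myLeafStore and myMsgRouter are hypothetical implementations of the respective interfaces:

```go
aux := lnd.AuxComponents{
	AuxLeafStore: fn.Some[lnwallet.AuxLeafStore](myLeafStore),
	MsgRouter:    fn.Some[msgmux.Router](myMsgRouter),
	// Every other hook defaults to fn.None and is simply ignored.
}
```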
// DefaultWalletImpl is the default implementation of our normal, btcwallet // DefaultWalletImpl is the default implementation of our normal, btcwallet
@ -228,7 +280,8 @@ func (d *DefaultWalletImpl) Permissions() map[string][]bakery.Op {
 //
 // NOTE: This is part of the WalletConfigBuilder interface.
 func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context,
-	dbs *DatabaseInstances, interceptorChain *rpcperms.InterceptorChain,
+	dbs *DatabaseInstances, aux *AuxComponents,
+	interceptorChain *rpcperms.InterceptorChain,
 	grpcListeners []*ListenerWithSignal) (*chainreg.PartialChainControl,
 	*btcwallet.Config, func(), error) {
@ -548,6 +601,8 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context,
 		HeightHintDB:    dbs.HeightHintDB,
 		ChanStateDB:     dbs.ChanStateDB.ChannelStateDB(),
 		NeutrinoCS:      neutrinoCS,
+		AuxLeafStore:    aux.AuxLeafStore,
+		AuxSigner:       aux.AuxSigner,
 		ActiveNetParams: d.cfg.ActiveNetParams,
 		FeeURL:          d.cfg.FeeURL,
 		Fee: &lncfg.Fee{
@ -611,8 +666,9 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context,
 // proxyBlockEpoch proxies a block epoch subsections to the underlying neutrino
 // rebroadcaster client.
-func proxyBlockEpoch(notifier chainntnfs.ChainNotifier,
-) func() (*blockntfns.Subscription, error) {
+func proxyBlockEpoch(
+	notifier chainntnfs.ChainNotifier) func() (*blockntfns.Subscription,
+	error) {
 
 	return func() (*blockntfns.Subscription, error) {
 		blockEpoch, err := notifier.RegisterBlockEpochNtfn(
@ -703,6 +759,8 @@ func (d *DefaultWalletImpl) BuildChainControl(
 		ChainIO:               walletController,
 		NetParams:             *walletConfig.NetParams,
 		CoinSelectionStrategy: walletConfig.CoinSelectionStrategy,
+		AuxLeafStore:          partialChainControl.Cfg.AuxLeafStore,
+		AuxSigner:             partialChainControl.Cfg.AuxSigner,
 	}
 
 	// The broadcast is already always active for neutrino nodes, so we

View file

@ -15,6 +15,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/labels"
@ -22,6 +23,8 @@ import (
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/tlv"
) )
const ( const (
@ -147,7 +150,7 @@ type BreachConfig struct {
 	Estimator chainfee.Estimator
 
 	// GenSweepScript generates the receiving scripts for swept outputs.
-	GenSweepScript func() ([]byte, error)
+	GenSweepScript func() fn.Result[lnwallet.AddrWithKey]
 
 	// Notifier provides a publish/subscribe interface for event driven
 	// notifications regarding the confirmation of txids.
@ -172,6 +175,10 @@ type BreachConfig struct {
 	// breached channels. This is used in conjunction with DB to recover
 	// from crashes, restarts, or other failures.
 	Store RetributionStorer
+
+	// AuxSweeper is an optional interface that can be used to modify the
+	// way sweep transaction are generated.
+	AuxSweeper fn.Option[sweep.AuxSweeper]
 }
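Note the GenSweepScript contract change above: instead of raw script bytes, the closure now returns an fn.Result wrapping an lnwallet.AddrWithKey. A sketch of an adapted callback, where newSweepAddr is a hypothetical helper returning (lnwallet.AddrWithKey, error):

```go
cfg.GenSweepScript = func() fn.Result[lnwallet.AddrWithKey] {
	addr, err := newSweepAddr()
	if err != nil {
		return fn.Err[lnwallet.AddrWithKey](err)
	}

	return fn.Ok(addr)
}
```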
// BreachArbitrator is a special subsystem which is responsible for watching and // BreachArbitrator is a special subsystem which is responsible for watching and
@ -735,10 +742,28 @@ justiceTxBroadcast:
 	brarLog.Debugf("Broadcasting justice tx: %v", lnutils.SpewLogClosure(
 		finalTx))
// As we're about to broadcast our breach transaction, we'll notify the
// aux sweeper of our broadcast attempt first.
err = fn.MapOptionZ(b.cfg.AuxSweeper, func(aux sweep.AuxSweeper) error {
bumpReq := sweep.BumpRequest{
Inputs: finalTx.inputs,
DeliveryAddress: finalTx.sweepAddr,
ExtraTxOut: finalTx.extraTxOut,
}
return aux.NotifyBroadcast(
&bumpReq, finalTx.justiceTx, finalTx.fee, nil,
)
})
if err != nil {
brarLog.Errorf("unable to notify broadcast: %w", err)
return
}
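fn.MapOptionZ only runs the closure when the option is populated; for fn.None it returns the zero value of the result type, so nodes without an AuxSweeper skip the notification and fall straight through with a nil error. A tiny standalone sketch of that behaviour:

```go
err := fn.MapOptionZ(fn.None[int](), func(int) error {
	return fmt.Errorf("never reached for fn.None")
})
// err == nil here, because the option is empty.
```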
 	// We'll now attempt to broadcast the transaction which finalized the
 	// channel's retribution against the cheating counter party.
 	label := labels.MakeLabel(labels.LabelTypeJusticeTransaction, nil)
-	err = b.cfg.PublishTransaction(finalTx, label)
+	err = b.cfg.PublishTransaction(finalTx.justiceTx, label)
 	if err != nil {
 		brarLog.Errorf("Unable to broadcast justice tx: %v", err)
 	}
@ -858,7 +883,9 @@ Loop:
"spending commitment outs: %v", "spending commitment outs: %v",
lnutils.SpewLogClosure(tx)) lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(tx, label) err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil { if err != nil {
brarLog.Warnf("Unable to broadcast "+ brarLog.Warnf("Unable to broadcast "+
"commit out spending justice "+ "commit out spending justice "+
@ -873,7 +900,9 @@ Loop:
"spending HTLC outs: %v", "spending HTLC outs: %v",
lnutils.SpewLogClosure(tx)) lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(tx, label) err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil { if err != nil {
brarLog.Warnf("Unable to broadcast "+ brarLog.Warnf("Unable to broadcast "+
"HTLC out spending justice "+ "HTLC out spending justice "+
@ -888,7 +917,9 @@ Loop:
"spending second-level HTLC output: %v", "spending second-level HTLC output: %v",
lnutils.SpewLogClosure(tx)) lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(tx, label) err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil { if err != nil {
brarLog.Warnf("Unable to broadcast "+ brarLog.Warnf("Unable to broadcast "+
"second-level HTLC out "+ "second-level HTLC out "+
@ -1067,15 +1098,18 @@ type breachedOutput struct {
secondLevelTapTweak [32]byte secondLevelTapTweak [32]byte
witnessFunc input.WitnessGenerator witnessFunc input.WitnessGenerator
resolutionBlob fn.Option[tlv.Blob]
// TODO(roasbeef): function opt and hook into brar
} }
// makeBreachedOutput assembles a new breachedOutput that can be used by the // makeBreachedOutput assembles a new breachedOutput that can be used by the
// breach arbiter to construct a justice or sweep transaction. // breach arbiter to construct a justice or sweep transaction.
func makeBreachedOutput(outpoint *wire.OutPoint, func makeBreachedOutput(outpoint *wire.OutPoint,
witnessType input.StandardWitnessType, witnessType input.StandardWitnessType, secondLevelScript []byte,
secondLevelScript []byte, signDescriptor *input.SignDescriptor, confHeight uint32,
signDescriptor *input.SignDescriptor, resolutionBlob fn.Option[tlv.Blob]) breachedOutput {
confHeight uint32) breachedOutput {
amount := signDescriptor.Output.Value amount := signDescriptor.Output.Value
@ -1086,6 +1120,7 @@ func makeBreachedOutput(outpoint *wire.OutPoint,
witnessType: witnessType, witnessType: witnessType,
signDesc: *signDescriptor, signDesc: *signDescriptor,
confHeight: confHeight, confHeight: confHeight,
resolutionBlob: resolutionBlob,
} }
} }
@ -1125,6 +1160,11 @@ func (bo *breachedOutput) SignDesc() *input.SignDescriptor {
return &bo.signDesc return &bo.signDesc
} }
// Preimage returns the preimage that was used to create the breached output.
func (bo *breachedOutput) Preimage() fn.Option[lntypes.Preimage] {
return fn.None[lntypes.Preimage]()
}
// CraftInputScript computes a valid witness that allows us to spend from the // CraftInputScript computes a valid witness that allows us to spend from the
// breached output. It does so by first generating and memoizing the witness // breached output. It does so by first generating and memoizing the witness
// generation function, which parameterized primarily by the witness type and // generation function, which parameterized primarily by the witness type and
@ -1174,6 +1214,12 @@ func (bo *breachedOutput) UnconfParent() *input.TxInfo {
return nil return nil
} }
// ResolutionBlob returns a special opaque blob to be used to sweep/resolve this
// input.
func (bo *breachedOutput) ResolutionBlob() fn.Option[tlv.Blob] {
return bo.resolutionBlob
}
// Add compile-time constraint ensuring breachedOutput implements the Input // Add compile-time constraint ensuring breachedOutput implements the Input
// interface. // interface.
var _ input.Input = (*breachedOutput)(nil) var _ input.Input = (*breachedOutput)(nil)
@ -1258,6 +1304,7 @@ func newRetributionInfo(chanPoint *wire.OutPoint,
nil, nil,
breachInfo.LocalOutputSignDesc, breachInfo.LocalOutputSignDesc,
breachInfo.BreachHeight, breachInfo.BreachHeight,
breachInfo.LocalResolutionBlob,
) )
breachedOutputs = append(breachedOutputs, localOutput) breachedOutputs = append(breachedOutputs, localOutput)
@ -1284,6 +1331,7 @@ func newRetributionInfo(chanPoint *wire.OutPoint,
nil, nil,
breachInfo.RemoteOutputSignDesc, breachInfo.RemoteOutputSignDesc,
breachInfo.BreachHeight, breachInfo.BreachHeight,
breachInfo.RemoteResolutionBlob,
) )
breachedOutputs = append(breachedOutputs, remoteOutput) breachedOutputs = append(breachedOutputs, remoteOutput)
@ -1318,6 +1366,7 @@ func newRetributionInfo(chanPoint *wire.OutPoint,
breachInfo.HtlcRetributions[i].SecondLevelWitnessScript, breachInfo.HtlcRetributions[i].SecondLevelWitnessScript,
&breachInfo.HtlcRetributions[i].SignDesc, &breachInfo.HtlcRetributions[i].SignDesc,
breachInfo.BreachHeight, breachInfo.BreachHeight,
breachInfo.HtlcRetributions[i].ResolutionBlob,
) )
// For taproot outputs, we also need to hold onto the second // For taproot outputs, we also need to hold onto the second
@ -1357,10 +1406,10 @@ func newRetributionInfo(chanPoint *wire.OutPoint,
// spend the to_local output and commitment level HTLC outputs separately, // spend the to_local output and commitment level HTLC outputs separately,
// before the CSV locks expire. // before the CSV locks expire.
type justiceTxVariants struct { type justiceTxVariants struct {
spendAll *wire.MsgTx spendAll *justiceTxCtx
spendCommitOuts *wire.MsgTx spendCommitOuts *justiceTxCtx
spendHTLCs *wire.MsgTx spendHTLCs *justiceTxCtx
spendSecondLevelHTLCs []*wire.MsgTx spendSecondLevelHTLCs []*justiceTxCtx
} }
// createJusticeTx creates transactions which exact "justice" by sweeping ALL // createJusticeTx creates transactions which exact "justice" by sweeping ALL
@ -1424,7 +1473,9 @@ func (b *BreachArbitrator) createJusticeTx(
err) err)
} }
secondLevelSweeps := make([]*wire.MsgTx, 0, len(secondLevelInputs)) // TODO(roasbeef): only register one of them?
secondLevelSweeps := make([]*justiceTxCtx, 0, len(secondLevelInputs))
for _, input := range secondLevelInputs { for _, input := range secondLevelInputs {
sweepTx, err := b.createSweepTx(input) sweepTx, err := b.createSweepTx(input)
if err != nil { if err != nil {
@ -1441,9 +1492,23 @@ func (b *BreachArbitrator) createJusticeTx(
return txs, nil return txs, nil
} }
// justiceTxCtx contains the justice transaction along with other related meta
// data.
type justiceTxCtx struct {
justiceTx *wire.MsgTx
sweepAddr lnwallet.AddrWithKey
extraTxOut fn.Option[sweep.SweepOutput]
fee btcutil.Amount
inputs []input.Input
}
// createSweepTx creates a tx that sweeps the passed inputs back to our wallet. // createSweepTx creates a tx that sweeps the passed inputs back to our wallet.
func (b *BreachArbitrator) createSweepTx(inputs ...input.Input) (*wire.MsgTx, func (b *BreachArbitrator) createSweepTx(
error) { inputs ...input.Input) (*justiceTxCtx, error) {
if len(inputs) == 0 { if len(inputs) == 0 {
return nil, nil return nil, nil
@ -1466,6 +1531,18 @@ func (b *BreachArbitrator) createSweepTx(inputs ...input.Input) (*wire.MsgTx,
// nLockTime, and output are already included in the TxWeightEstimator. // nLockTime, and output are already included in the TxWeightEstimator.
weightEstimate.AddP2TROutput() weightEstimate.AddP2TROutput()
// If any of our inputs has a resolution blob, then we'll add another
// P2TR _output_, since we'll want to separate the custom channel
// outputs from the regular, BTC only outputs. So we only need one such
// output, which'll carry the custom channel "valuables" from both the
// breached commitment and HTLC outputs.
hasBlobs := fn.Any(func(i input.Input) bool {
return i.ResolutionBlob().IsSome()
}, inputs)
if hasBlobs {
weightEstimate.AddP2TROutput()
}
// Next, we iterate over the breached outputs contained in the // Next, we iterate over the breached outputs contained in the
// retribution info. For each, we switch over the witness type such // retribution info. For each, we switch over the witness type such
// that we contribute the appropriate weight for each input and // that we contribute the appropriate weight for each input and
@ -1499,13 +1576,13 @@ func (b *BreachArbitrator) createSweepTx(inputs ...input.Input) (*wire.MsgTx,
// sweepSpendableOutputsTxn creates a signed transaction from a sequence of // sweepSpendableOutputsTxn creates a signed transaction from a sequence of
// spendable outputs by sweeping the funds into a single p2wkh output. // spendable outputs by sweeping the funds into a single p2wkh output.
func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit, func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
inputs ...input.Input) (*wire.MsgTx, error) { inputs ...input.Input) (*justiceTxCtx, error) {
// First, we obtain a new public key script from the wallet which we'll // First, we obtain a new public key script from the wallet which we'll
// sweep the funds to. // sweep the funds to.
// TODO(roasbeef): possibly create many outputs to minimize change in // TODO(roasbeef): possibly create many outputs to minimize change in
// the future? // the future?
pkScript, err := b.cfg.GenSweepScript() pkScript, err := b.cfg.GenSweepScript().Unpack()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1524,6 +1601,18 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
} }
txFee := feePerKw.FeeForWeight(txWeight) txFee := feePerKw.FeeForWeight(txWeight)
// At this point, we'll check to see if we have any extra outputs to
// add from the aux sweeper.
extraChangeOut := fn.MapOptionZ(
b.cfg.AuxSweeper,
func(aux sweep.AuxSweeper) fn.Result[sweep.SweepOutput] {
return aux.DeriveSweepAddr(inputs, pkScript)
},
)
if err := extraChangeOut.Err(); err != nil {
return nil, err
}
// TODO(roasbeef): already start to siphon their funds into fees // TODO(roasbeef): already start to siphon their funds into fees
sweepAmt := int64(totalAmt - txFee) sweepAmt := int64(totalAmt - txFee)
@ -1531,12 +1620,24 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
// information gathered above and the provided retribution information. // information gathered above and the provided retribution information.
txn := wire.NewMsgTx(2) txn := wire.NewMsgTx(2)
// We begin by adding the output to which our funds will be deposited. // First, we'll add the extra sweep output if it exists, subtracting the
// amount from the sweep amt.
if b.cfg.AuxSweeper.IsSome() {
extraChangeOut.WhenResult(func(o sweep.SweepOutput) {
sweepAmt -= o.Value
txn.AddTxOut(&o.TxOut)
})
}
// Next, we'll add the output to which our funds will be deposited.
txn.AddTxOut(&wire.TxOut{ txn.AddTxOut(&wire.TxOut{
PkScript: pkScript, PkScript: pkScript.DeliveryAddress,
Value: sweepAmt, Value: sweepAmt,
}) })
// TODO(roasbeef): add other output change modify sweep amt
// Next, we add all of the spendable outputs as inputs to the // Next, we add all of the spendable outputs as inputs to the
// transaction. // transaction.
for _, inp := range inputs { for _, inp := range inputs {
@ -1592,7 +1693,13 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
} }
} }
return txn, nil return &justiceTxCtx{
justiceTx: txn,
sweepAddr: pkScript,
extraTxOut: extraChangeOut.Option(),
fee: txFee,
inputs: inputs,
}, nil
} }
// RetributionStore handles persistence of retribution states to disk and is // RetributionStore handles persistence of retribution states to disk and is
@ -1622,13 +1729,29 @@ func taprootBriefcaseFromRetInfo(retInfo *retributionInfo) *taprootBriefcase {
// commitment, we'll need to stash the control block. // commitment, we'll need to stash the control block.
case input.TaprootRemoteCommitSpend: case input.TaprootRemoteCommitSpend:
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.CommitSweepCtrlBlock = bo.signDesc.ControlBlock tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = bo.signDesc.ControlBlock
bo.resolutionBlob.WhenSome(func(blob tlv.Blob) {
tapCase.SettledCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType2](
blob,
),
)
})
// To spend the revoked output again, we'll store the same // To spend the revoked output again, we'll store the same
// control block value as above, but in a different place. // control block value as above, but in a different place.
case input.TaprootCommitmentRevoke: case input.TaprootCommitmentRevoke:
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.RevokeSweepCtrlBlock = bo.signDesc.ControlBlock tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock = bo.signDesc.ControlBlock
bo.resolutionBlob.WhenSome(func(blob tlv.Blob) {
tapCase.BreachedCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType3](
blob,
),
)
})
// For spending the HTLC outputs, we'll store the first and // For spending the HTLC outputs, we'll store the first and
// second level tweak values. // second level tweak values.
@ -1642,10 +1765,10 @@ func taprootBriefcaseFromRetInfo(retInfo *retributionInfo) *taprootBriefcase {
secondLevelTweak := bo.secondLevelTapTweak secondLevelTweak := bo.secondLevelTapTweak
//nolint:lll //nolint:lll
tapCase.TapTweaks.BreachedHtlcTweaks[resID] = firstLevelTweak tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID] = firstLevelTweak
//nolint:lll //nolint:lll
tapCase.TapTweaks.BreachedSecondLevelHltcTweaks[resID] = secondLevelTweak tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID] = secondLevelTweak
} }
} }
@ -1665,13 +1788,25 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase,
// commitment, we'll apply the control block. // commitment, we'll apply the control block.
case input.TaprootRemoteCommitSpend: case input.TaprootRemoteCommitSpend:
//nolint:lll //nolint:lll
bo.signDesc.ControlBlock = tapCase.CtrlBlocks.CommitSweepCtrlBlock bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock
tapCase.SettledCommitBlob.WhenSomeV(
func(blob tlv.Blob) {
bo.resolutionBlob = fn.Some(blob)
},
)
// To spend the revoked output again, we'll apply the same // To spend the revoked output again, we'll apply the same
// control block value as above, but to a different place. // control block value as above, but to a different place.
case input.TaprootCommitmentRevoke: case input.TaprootCommitmentRevoke:
//nolint:lll //nolint:lll
bo.signDesc.ControlBlock = tapCase.CtrlBlocks.RevokeSweepCtrlBlock bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock
tapCase.BreachedCommitBlob.WhenSomeV(
func(blob tlv.Blob) {
bo.resolutionBlob = fn.Some(blob)
},
)
// For spending the HTLC outputs, we'll apply the first and // For spending the HTLC outputs, we'll apply the first and
// second level tweak values. // second level tweak values.
@ -1680,7 +1815,8 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase,
case input.TaprootHtlcOfferedRevoke: case input.TaprootHtlcOfferedRevoke:
resID := newResolverID(bo.OutPoint()) resID := newResolverID(bo.OutPoint())
tap1, ok := tapCase.TapTweaks.BreachedHtlcTweaks[resID] //nolint:lll
tap1, ok := tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID]
if !ok { if !ok {
return fmt.Errorf("unable to find taproot "+ return fmt.Errorf("unable to find taproot "+
"tweak for: %v", bo.OutPoint()) "tweak for: %v", bo.OutPoint())
@ -1688,7 +1824,7 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase,
bo.signDesc.TapTweak = tap1[:] bo.signDesc.TapTweak = tap1[:]
//nolint:lll //nolint:lll
tap2, ok := tapCase.TapTweaks.BreachedSecondLevelHltcTweaks[resID] tap2, ok := tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID]
if !ok { if !ok {
return fmt.Errorf("unable to find taproot "+ return fmt.Errorf("unable to find taproot "+
"tweak for: %v", bo.OutPoint()) "tweak for: %v", bo.OutPoint())

View file

@ -22,6 +22,7 @@ import (
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lntest/channels"
@ -1198,6 +1199,8 @@ func TestBreachCreateJusticeTx(t *testing.T) {
input.HtlcSecondLevelRevoke, input.HtlcSecondLevelRevoke,
} }
rBlob := fn.Some([]byte{0x01})
breachedOutputs := make([]breachedOutput, len(outputTypes)) breachedOutputs := make([]breachedOutput, len(outputTypes))
for i, wt := range outputTypes { for i, wt := range outputTypes {
// Create a fake breached output for each type, ensuring they // Create a fake breached output for each type, ensuring they
@ -1216,6 +1219,7 @@ func TestBreachCreateJusticeTx(t *testing.T) {
nil, nil,
signDesc, signDesc,
1, 1,
rBlob,
) )
} }
@ -1226,16 +1230,16 @@ func TestBreachCreateJusticeTx(t *testing.T) {
// The spendAll tx should be spending all the outputs. This is the // The spendAll tx should be spending all the outputs. This is the
// "regular" justice transaction type. // "regular" justice transaction type.
require.Len(t, justiceTxs.spendAll.TxIn, len(breachedOutputs)) require.Len(t, justiceTxs.spendAll.justiceTx.TxIn, len(breachedOutputs))
// The spendCommitOuts tx should be spending the 4 types of commit outs // The spendCommitOuts tx should be spending the 4 types of commit outs
// (note that in practice there will be at most two commit outputs per // (note that in practice there will be at most two commit outputs per
// commit, but we test all 4 types here). // commit, but we test all 4 types here).
require.Len(t, justiceTxs.spendCommitOuts.TxIn, 4) require.Len(t, justiceTxs.spendCommitOuts.justiceTx.TxIn, 4)
// Check that the spendHTLCs tx is spending the two revoked commitment // Check that the spendHTLCs tx is spending the two revoked commitment
// level HTLC output types. // level HTLC output types.
require.Len(t, justiceTxs.spendHTLCs.TxIn, 2) require.Len(t, justiceTxs.spendHTLCs.justiceTx.TxIn, 2)
// Finally, check that the spendSecondLevelHTLCs txs are spending the // Finally, check that the spendSecondLevelHTLCs txs are spending the
// second level type. // second level type.
@ -1590,6 +1594,10 @@ func testBreachSpends(t *testing.T, test breachTest) {
// Notify the breach arbiter about the breach. // Notify the breach arbiter about the breach.
retribution, err := lnwallet.NewBreachRetribution( retribution, err := lnwallet.NewBreachRetribution(
alice.State(), height, 1, forceCloseTx, alice.State(), height, 1, forceCloseTx,
fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}),
fn.Some[lnwallet.AuxContractResolver](
&lnwallet.MockAuxContractResolver{},
),
) )
require.NoError(t, err, "unable to create breach retribution") require.NoError(t, err, "unable to create breach retribution")
@ -1799,6 +1807,10 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) {
// Notify the breach arbiter about the breach. // Notify the breach arbiter about the breach.
retribution, err := lnwallet.NewBreachRetribution( retribution, err := lnwallet.NewBreachRetribution(
alice.State(), height, uint32(blockHeight), forceCloseTx, alice.State(), height, uint32(blockHeight), forceCloseTx,
fn.Some[lnwallet.AuxLeafStore](&lnwallet.MockAuxLeafStore{}),
fn.Some[lnwallet.AuxContractResolver](
&lnwallet.MockAuxContractResolver{},
),
) )
require.NoError(t, err, "unable to create breach retribution") require.NoError(t, err, "unable to create breach retribution")
@ -2126,15 +2138,19 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
// Assemble our test arbiter. // Assemble our test arbiter.
notifier := mock.MakeMockSpendNotifier() notifier := mock.MakeMockSpendNotifier()
ba := NewBreachArbitrator(&BreachConfig{ ba := NewBreachArbitrator(&BreachConfig{
CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {}, CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {},
DB: db.ChannelStateDB(), DB: db.ChannelStateDB(),
Estimator: chainfee.NewStaticEstimator(12500, 0), Estimator: chainfee.NewStaticEstimator(12500, 0),
GenSweepScript: func() ([]byte, error) { return nil, nil }, GenSweepScript: func() fn.Result[lnwallet.AddrWithKey] {
ContractBreaches: contractBreaches, return fn.Ok(lnwallet.AddrWithKey{})
Signer: signer, },
Notifier: notifier, ContractBreaches: contractBreaches,
PublishTransaction: func(_ *wire.MsgTx, _ string) error { return nil }, Signer: signer,
Store: store, Notifier: notifier,
PublishTransaction: func(_ *wire.MsgTx, _ string) error {
return nil
},
Store: store,
}) })
if err := ba.Start(); err != nil { if err := ba.Start(); err != nil {
@ -2357,9 +2373,12 @@ func createInitChannels(t *testing.T) (
) )
bobSigner := input.NewMockSigner([]*btcec.PrivateKey{bobKeyPriv}, nil) bobSigner := input.NewMockSigner([]*btcec.PrivateKey{bobKeyPriv}, nil)
signerMock := lnwallet.NewDefaultAuxSignerMock(t)
alicePool := lnwallet.NewSigPool(1, aliceSigner) alicePool := lnwallet.NewSigPool(1, aliceSigner)
channelAlice, err := lnwallet.NewLightningChannel( channelAlice, err := lnwallet.NewLightningChannel(
aliceSigner, aliceChannelState, alicePool, aliceSigner, aliceChannelState, alicePool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -2372,6 +2391,8 @@ func createInitChannels(t *testing.T) (
bobPool := lnwallet.NewSigPool(1, bobSigner) bobPool := lnwallet.NewSigPool(1, bobSigner)
channelBob, err := lnwallet.NewLightningChannel( channelBob, err := lnwallet.NewLightningChannel(
bobSigner, bobChannelState, bobPool, bobSigner, bobChannelState, bobPool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
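The test config above now builds GenSweepScript as a closure returning a Result instead of the usual (value, error) pair. A rough sketch of that shape, with a local Result type and a hypothetical AddrWithKey placeholder standing in for lnd's fn.Result and lnwallet.AddrWithKey:

package main

import (
	"errors"
	"fmt"
)

// Result is a simplified stand-in for fn.Result: a value or an error.
type Result[A any] struct {
	value A
	err   error
}

func Ok[A any](a A) Result[A]          { return Result[A]{value: a} }
func Err[A any](e error) Result[A]     { return Result[A]{err: e} }
func (r Result[A]) Unpack() (A, error) { return r.value, r.err }

// AddrWithKey is a hypothetical placeholder for the wallet address struct
// returned by the real GenSweepScript.
type AddrWithKey struct {
	DeliveryAddress []byte
}

// genSweepScript mimics the closure wired into the breach arbiter config.
func genSweepScript(fail bool) func() Result[AddrWithKey] {
	return func() Result[AddrWithKey] {
		if fail {
			return Err[AddrWithKey](errors.New("wallet unavailable"))
		}
		return Ok(AddrWithKey{DeliveryAddress: []byte{0x51}})
	}
}

func main() {
	addr, err := genSweepScript(false)().Unpack()
	fmt.Println(addr, err)

	_, err = genSweepScript(true)().Unpack()
	fmt.Println("expected failure:", err)
}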

View file

@ -10,9 +10,11 @@ import (
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/tlv"
) )
// ContractResolutions is a wrapper struct around the two forms of resolutions // ContractResolutions is a wrapper struct around the two forms of resolutions
@ -1553,9 +1555,16 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
commitResolution := c.CommitResolution commitResolution := c.CommitResolution
commitSignDesc := commitResolution.SelfOutputSignDesc commitSignDesc := commitResolution.SelfOutputSignDesc
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.CommitSweepCtrlBlock = commitSignDesc.ControlBlock tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = commitSignDesc.ControlBlock
c.CommitResolution.ResolutionBlob.WhenSome(func(b []byte) {
tapCase.SettledCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType2](b),
)
})
} }
htlcBlobs := newAuxHtlcBlobs()
for _, htlc := range c.HtlcResolutions.IncomingHTLCs { for _, htlc := range c.HtlcResolutions.IncomingHTLCs {
htlc := htlc htlc := htlc
@ -1566,12 +1575,13 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
continue continue
} }
var resID resolverID
if htlc.SignedSuccessTx != nil { if htlc.SignedSuccessTx != nil {
resID := newResolverID( resID = newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint, htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
) )
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs we need to go to the second level for, we // For HTLCs we need to go to the second level for, we
// also need to store the control block needed to // also need to store the control block needed to
@ -1580,13 +1590,17 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
//nolint:lll //nolint:lll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
} }
} else { } else {
resID := newResolverID(htlc.ClaimOutpoint) resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = ctrlBlock tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = ctrlBlock
} }
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
} }
for _, htlc := range c.HtlcResolutions.OutgoingHTLCs { for _, htlc := range c.HtlcResolutions.OutgoingHTLCs {
htlc := htlc htlc := htlc
@ -1598,12 +1612,13 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
continue continue
} }
var resID resolverID
if htlc.SignedTimeoutTx != nil { if htlc.SignedTimeoutTx != nil {
resID := newResolverID( resID = newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint, htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
) )
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs we need to go to the second level for, we // For HTLCs we need to go to the second level for, we
// also need to store the control block needed to // also need to store the control block needed to
@ -1614,18 +1629,28 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
//nolint:lll //nolint:lll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
} }
} else { } else {
resID := newResolverID(htlc.ClaimOutpoint) resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll //nolint:lll
tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock
} }
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
} }
if c.AnchorResolution != nil { if c.AnchorResolution != nil {
anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor
tapCase.TapTweaks.AnchorTweak = anchorSignDesc.TapTweak tapCase.TapTweaks.Val.AnchorTweak = anchorSignDesc.TapTweak
}
if len(htlcBlobs) != 0 {
tapCase.HtlcBlobs = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](htlcBlobs),
)
} }
return tapCase.Encode(w) return tapCase.Encode(w)
@ -1639,9 +1664,15 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
if c.CommitResolution != nil { if c.CommitResolution != nil {
c.CommitResolution.SelfOutputSignDesc.ControlBlock = c.CommitResolution.SelfOutputSignDesc.ControlBlock =
tapCase.CtrlBlocks.CommitSweepCtrlBlock tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock
tapCase.SettledCommitBlob.WhenSomeV(func(b []byte) {
c.CommitResolution.ResolutionBlob = fn.Some(b)
})
} }
htlcBlobs := tapCase.HtlcBlobs.ValOpt().UnwrapOr(newAuxHtlcBlobs())
for i := range c.HtlcResolutions.IncomingHTLCs { for i := range c.HtlcResolutions.IncomingHTLCs {
htlc := c.HtlcResolutions.IncomingHTLCs[i] htlc := c.HtlcResolutions.IncomingHTLCs[i]
@ -1652,23 +1683,28 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
) )
//nolint:lll //nolint:lll
ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:lll //nolint:lll
if htlc.SignDetails != nil { if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] bridgeCtrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
} }
} else { } else {
resID = newResolverID(htlc.ClaimOutpoint) resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll //nolint:lll
ctrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] ctrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock htlc.SweepSignDesc.ControlBlock = ctrlBlock
} }
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.IncomingHTLCs[i] = htlc c.HtlcResolutions.IncomingHTLCs[i] = htlc
} }
for i := range c.HtlcResolutions.OutgoingHTLCs { for i := range c.HtlcResolutions.OutgoingHTLCs {
htlc := c.HtlcResolutions.OutgoingHTLCs[i] htlc := c.HtlcResolutions.OutgoingHTLCs[i]
@ -1680,28 +1716,32 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
) )
//nolint:lll //nolint:lll
ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:lll //nolint:lll
if htlc.SignDetails != nil { if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] bridgeCtrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
} }
} else { } else {
resID = newResolverID(htlc.ClaimOutpoint) resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll //nolint:lll
ctrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] ctrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock htlc.SweepSignDesc.ControlBlock = ctrlBlock
} }
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.OutgoingHTLCs[i] = htlc c.HtlcResolutions.OutgoingHTLCs[i] = htlc
} }
if c.AnchorResolution != nil { if c.AnchorResolution != nil {
c.AnchorResolution.AnchorSignDescriptor.TapTweak = c.AnchorResolution.AnchorSignDescriptor.TapTweak =
tapCase.TapTweaks.AnchorTweak tapCase.TapTweaks.Val.AnchorTweak
} }
return nil return nil
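The encode/decode hunks above persist the new resolution blobs as optional TLV records that are only written when present. Below is a small sketch of that optional-record idea using a toy type-length-value layout; it is not the BOLT TLV encoding or lnd's tlv.RecordT machinery, just the encode-when-some / absent-means-none behaviour.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// encodeOptionalBlob writes a single (type, length, value) record, but only
// when the blob is present, mirroring how the diff only emits the record for
// a populated resolution blob.
func encodeOptionalBlob(w *bytes.Buffer, recType uint8, blob []byte, present bool) {
	if !present {
		return
	}
	w.WriteByte(recType)
	var l [2]byte
	binary.BigEndian.PutUint16(l[:], uint16(len(blob)))
	w.Write(l[:])
	w.Write(blob)
}

// decodeOptionalBlob scans the buffer for the record and reports whether it
// was found, leaving absent records as "none".
func decodeOptionalBlob(r *bytes.Reader, wantType uint8) ([]byte, bool, error) {
	for r.Len() > 0 {
		t, err := r.ReadByte()
		if err != nil {
			return nil, false, err
		}
		var l [2]byte
		if _, err := io.ReadFull(r, l[:]); err != nil {
			return nil, false, err
		}
		val := make([]byte, binary.BigEndian.Uint16(l[:]))
		if _, err := io.ReadFull(r, val); err != nil {
			return nil, false, err
		}
		if t == wantType {
			return val, true, nil
		}
	}
	return nil, false, nil
}

func main() {
	var buf bytes.Buffer
	encodeOptionalBlob(&buf, 2, []byte("settled-commit-blob"), true)
	encodeOptionalBlob(&buf, 3, nil, false) // absent: nothing written

	blob, ok, err := decodeOptionalBlob(bytes.NewReader(buf.Bytes()), 2)
	fmt.Printf("type 2: %q found=%v err=%v\n", blob, ok, err)

	_, ok, _ = decodeOptionalBlob(bytes.NewReader(buf.Bytes()), 3)
	fmt.Println("type 3 found:", ok)
}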

View file

@ -217,6 +217,18 @@ type ChainArbitratorConfig struct {
// meanwhile, turn `PaymentCircuit` into an interface or bring it to a // meanwhile, turn `PaymentCircuit` into an interface or bring it to a
// lower package. // lower package.
QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey
// AuxLeafStore is an optional store that can be used to store auxiliary
// leaves for certain custom channel types.
AuxLeafStore fn.Option[lnwallet.AuxLeafStore]
// AuxSigner is an optional signer that can be used to sign auxiliary
// leaves for certain custom channel types.
AuxSigner fn.Option[lnwallet.AuxSigner]
// AuxResolver is an optional interface that can be used to modify the
// way contracts are resolved.
AuxResolver fn.Option[lnwallet.AuxContractResolver]
} }
// ChainArbitrator is a sub-system that oversees the on-chain resolution of all // ChainArbitrator is a sub-system that oversees the on-chain resolution of all
@ -299,8 +311,19 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
return nil, err return nil, err
} }
var chanOpts []lnwallet.ChannelOpt
a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) {
chanOpts = append(chanOpts, lnwallet.WithLeafStore(s))
})
a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) {
chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s))
})
a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})
chanMachine, err := lnwallet.NewLightningChannel( chanMachine, err := lnwallet.NewLightningChannel(
a.c.cfg.Signer, channel, nil, a.c.cfg.Signer, channel, nil, chanOpts...,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -312,11 +335,10 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
// ForceCloseChan should force close the contract that this attendant is // ForceCloseChan should force close the contract that this attendant is
// watching over. We'll use this when we decide that we need to go to chain. It // watching over. We'll use this when we decide that we need to go to chain. It
// should in addition tell the switch to remove the corresponding link, such // should in addition tell the switch to remove the corresponding link, such
// that we won't accept any new updates. The returned summary contains all items // that we won't accept any new updates.
// needed to eventually resolve all outputs on chain.
// //
// NOTE: Part of the ArbChannel interface. // NOTE: Part of the ArbChannel interface.
func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) { func (a *arbChannel) ForceCloseChan() (*wire.MsgTx, error) {
// First, we mark the channel as borked, this ensures // First, we mark the channel as borked, this ensures
// that no new state transitions can happen, and also // that no new state transitions can happen, and also
// that the link won't be loaded into the switch. // that the link won't be loaded into the switch.
@ -344,15 +366,34 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error)
return nil, err return nil, err
} }
var chanOpts []lnwallet.ChannelOpt
a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) {
chanOpts = append(chanOpts, lnwallet.WithLeafStore(s))
})
a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) {
chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s))
})
a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})
// Finally, we'll force close the channel completing // Finally, we'll force close the channel completing
// the force close workflow. // the force close workflow.
chanMachine, err := lnwallet.NewLightningChannel( chanMachine, err := lnwallet.NewLightningChannel(
a.c.cfg.Signer, channel, nil, a.c.cfg.Signer, channel, nil, chanOpts...,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return chanMachine.ForceClose()
closeSummary, err := chanMachine.ForceClose(
lnwallet.WithSkipContractResolutions(),
)
if err != nil {
return nil, err
}
return closeSummary.CloseTx, nil
} }
// newActiveChannelArbitrator creates a new instance of an active channel // newActiveChannelArbitrator creates a new instance of an active channel
@ -557,6 +598,8 @@ func (c *ChainArbitrator) Start() error {
isOurAddr: c.cfg.IsOurAddress, isOurAddr: c.cfg.IsOurAddress,
contractBreach: breachClosure, contractBreach: breachClosure,
extractStateNumHint: lnwallet.GetStateNumHint, extractStateNumHint: lnwallet.GetStateNumHint,
auxLeafStore: c.cfg.AuxLeafStore,
auxResolver: c.cfg.AuxResolver,
}, },
) )
if err != nil { if err != nil {
@ -1186,6 +1229,8 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
) )
}, },
extractStateNumHint: lnwallet.GetStateNumHint, extractStateNumHint: lnwallet.GetStateNumHint,
auxLeafStore: c.cfg.AuxLeafStore,
auxResolver: c.cfg.AuxResolver,
}, },
) )
if err != nil { if err != nil {
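Both NewAnchorResolutions and ForceCloseChan above assemble a slice of channel options from whichever aux components happen to be configured. A compact sketch of that append-only-when-present pattern, with local stand-ins for fn.Option and lnwallet's functional channel options:

package main

import "fmt"

// Option is a simplified stand-in for fn.Option.
type Option[A any] struct {
	value A
	some  bool
}

func Some[A any](a A) Option[A] { return Option[A]{value: a, some: true} }
func None[A any]() Option[A]    { return Option[A]{} }

// WhenSome runs f only if the option carries a value.
func (o Option[A]) WhenSome(f func(A)) {
	if o.some {
		f(o.value)
	}
}

// channelOpt is a hypothetical functional option applied to a channel config.
type channelOpt func(*channelCfg)

type channelCfg struct {
	leafStore string
	auxSigner string
}

func withLeafStore(s string) channelOpt { return func(c *channelCfg) { c.leafStore = s } }
func withAuxSigner(s string) channelOpt { return func(c *channelCfg) { c.auxSigner = s } }

// buildChanOpts mirrors the pattern in the diff: only configured aux
// components contribute a channel option.
func buildChanOpts(leaf Option[string], signer Option[string]) []channelOpt {
	var opts []channelOpt
	leaf.WhenSome(func(s string) { opts = append(opts, withLeafStore(s)) })
	signer.WhenSome(func(s string) { opts = append(opts, withAuxSigner(s)) })
	return opts
}

func main() {
	cfg := &channelCfg{}
	for _, opt := range buildChanOpts(Some("tapd-leaf-store"), None[string]()) {
		opt(cfg)
	}
	fmt.Printf("%+v\n", *cfg)
}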

View file

@ -193,6 +193,12 @@ type chainWatcherConfig struct {
// obfuscater. This is used by the chain watcher to identify which // obfuscater. This is used by the chain watcher to identify which
// state was broadcast and confirmed on-chain. // state was broadcast and confirmed on-chain.
extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64
// auxLeafStore can be used to fetch information for custom channels.
auxLeafStore fn.Option[lnwallet.AuxLeafStore]
// auxResolver is used to supplement contract resolution.
auxResolver fn.Option[lnwallet.AuxContractResolver]
} }
// chainWatcher is a system that's assigned to every active channel. The duty // chainWatcher is a system that's assigned to every active channel. The duty
@ -308,7 +314,7 @@ func (c *chainWatcher) Start() error {
) )
if chanState.ChanType.IsTaproot() { if chanState.ChanType.IsTaproot() {
c.fundingPkScript, _, err = input.GenTaprootFundingScript( c.fundingPkScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, 0, localKey, remoteKey, 0, chanState.TapscriptRoot,
) )
if err != nil { if err != nil {
return err return err
@ -423,15 +429,37 @@ func (c *chainWatcher) handleUnknownLocalState(
&c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg,
) )
auxResult, err := fn.MapOptionZ(
c.cfg.auxLeafStore,
//nolint:lll
func(s lnwallet.AuxLeafStore) fn.Result[lnwallet.CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
lnwallet.NewAuxChanState(c.cfg.chanState),
c.cfg.chanState.LocalCommitment, *commitKeyRing,
lntypes.Local,
)
},
).Unpack()
if err != nil {
return false, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// With the keys derived, we'll construct the remote script that'll be // With the keys derived, we'll construct the remote script that'll be
// present if they have a non-dust balance on the commitment. // present if they have a non-dust balance on the commitment.
var leaseExpiry uint32 var leaseExpiry uint32
if c.cfg.chanState.ChanType.HasLeaseExpiration() { if c.cfg.chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = c.cfg.chanState.ThawHeight leaseExpiry = c.cfg.chanState.ThawHeight
} }
remoteAuxLeaf := fn.ChainOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
)(auxResult.AuxLeaves)
remoteScript, _, err := lnwallet.CommitScriptToRemote( remoteScript, _, err := lnwallet.CommitScriptToRemote(
c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator,
commitKeyRing.ToRemoteKey, leaseExpiry, commitKeyRing.ToRemoteKey, leaseExpiry,
remoteAuxLeaf,
) )
if err != nil { if err != nil {
return false, err return false, err
@ -440,10 +468,16 @@ func (c *chainWatcher) handleUnknownLocalState(
// Next, we'll derive our script that includes the revocation base for // Next, we'll derive our script that includes the revocation base for
// the remote party allowing them to claim this output before the CSV // the remote party allowing them to claim this output before the CSV
// delay if we breach. // delay if we breach.
localAuxLeaf := fn.ChainOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
)(auxResult.AuxLeaves)
localScript, err := lnwallet.CommitScriptToSelf( localScript, err := lnwallet.CommitScriptToSelf(
c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator,
commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey,
uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry,
localAuxLeaf,
) )
if err != nil { if err != nil {
return false, err return false, err
@ -866,7 +900,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail,
spendHeight := uint32(commitSpend.SpendingHeight) spendHeight := uint32(commitSpend.SpendingHeight)
retribution, err := lnwallet.NewBreachRetribution( retribution, err := lnwallet.NewBreachRetribution(
c.cfg.chanState, broadcastStateNum, spendHeight, c.cfg.chanState, broadcastStateNum, spendHeight,
commitSpend.SpendingTx, commitSpend.SpendingTx, c.cfg.auxLeafStore, c.cfg.auxResolver,
) )
switch { switch {
@ -1116,8 +1150,8 @@ func (c *chainWatcher) dispatchLocalForceClose(
"detected", c.cfg.chanState.FundingOutpoint) "detected", c.cfg.chanState.FundingOutpoint)
forceClose, err := lnwallet.NewLocalForceCloseSummary( forceClose, err := lnwallet.NewLocalForceCloseSummary(
c.cfg.chanState, c.cfg.signer, c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum,
commitSpend.SpendingTx, stateNum, c.cfg.auxLeafStore, c.cfg.auxResolver,
) )
if err != nil { if err != nil {
return err return err
@ -1141,16 +1175,29 @@ func (c *chainWatcher) dispatchLocalForceClose(
LocalChanConfig: c.cfg.chanState.LocalChanCfg, LocalChanConfig: c.cfg.chanState.LocalChanCfg,
} }
resolutions, err := forceClose.ContractResolutions.UnwrapOrErr(
fmt.Errorf("resolutions not found"),
)
if err != nil {
return err
}
// If our commitment output isn't dust or we have active HTLC's on the // If our commitment output isn't dust or we have active HTLC's on the
// commitment transaction, then we'll populate the balances on the // commitment transaction, then we'll populate the balances on the
// close channel summary. // close channel summary.
if forceClose.CommitResolution != nil { if resolutions.CommitResolution != nil {
closeSummary.SettledBalance = chanSnapshot.LocalBalance.ToSatoshis() localBalance := chanSnapshot.LocalBalance.ToSatoshis()
closeSummary.TimeLockedBalance = chanSnapshot.LocalBalance.ToSatoshis() closeSummary.SettledBalance = localBalance
closeSummary.TimeLockedBalance = localBalance
} }
for _, htlc := range forceClose.HtlcResolutions.OutgoingHTLCs {
htlcValue := btcutil.Amount(htlc.SweepSignDesc.Output.Value) if resolutions.HtlcResolutions != nil {
closeSummary.TimeLockedBalance += htlcValue for _, htlc := range resolutions.HtlcResolutions.OutgoingHTLCs {
htlcValue := btcutil.Amount(
htlc.SweepSignDesc.Output.Value,
)
closeSummary.TimeLockedBalance += htlcValue
}
} }
// Attempt to add a channel sync message to the close summary. // Attempt to add a channel sync message to the close summary.
@ -1209,8 +1256,8 @@ func (c *chainWatcher) dispatchRemoteForceClose(
// materials required to let each subscriber sweep the funds in the // materials required to let each subscriber sweep the funds in the
// channel on-chain. // channel on-chain.
uniClose, err := lnwallet.NewUnilateralCloseSummary( uniClose, err := lnwallet.NewUnilateralCloseSummary(
c.cfg.chanState, c.cfg.signer, commitSpend, c.cfg.chanState, c.cfg.signer, commitSpend, remoteCommit,
remoteCommit, commitPoint, commitPoint, c.cfg.auxLeafStore, c.cfg.auxResolver,
) )
if err != nil { if err != nil {
return err return err
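handleUnknownLocalState above projects the per-side aux leaf out of an optional aux result before rebuilding the commitment scripts. The sketch below shows the same projection step; chainOption and the trimmed commitAuxLeaves struct are simplified stand-ins for fn.ChainOption and lnwallet.CommitAuxLeaves, not the real types.

package main

import "fmt"

type Option[A any] struct {
	value A
	some  bool
}

func Some[A any](a A) Option[A] { return Option[A]{value: a, some: true} }
func None[A any]() Option[A]    { return Option[A]{} }

// chainOption mimics fn.ChainOption: lift a func returning an Option into a
// func over Options (i.e. a flat-map).
func chainOption[A, B any](f func(A) Option[B]) func(Option[A]) Option[B] {
	return func(o Option[A]) Option[B] {
		if !o.some {
			return None[B]()
		}
		return f(o.value)
	}
}

// commitAuxLeaves is a hypothetical, trimmed version of the aux-leaf bundle:
// each side's leaf may or may not be present.
type commitAuxLeaves struct {
	localAuxLeaf  Option[string]
	remoteAuxLeaf Option[string]
}

func main() {
	auxLeaves := Some(commitAuxLeaves{
		remoteAuxLeaf: Some("remote-tapscript-leaf"),
	})

	// Project out each side's leaf; an absent outer or inner option both
	// collapse to None, which the script builder treats as "no aux leaf".
	remoteLeaf := chainOption(func(l commitAuxLeaves) Option[string] {
		return l.remoteAuxLeaf
	})(auxLeaves)

	localLeaf := chainOption(func(l commitAuxLeaves) Option[string] {
		return l.localAuxLeaf
	})(auxLeaves)

	fmt.Println("remote leaf present:", remoteLeaf.some)
	fmt.Println("local leaf present:", localLeaf.some)
}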

View file

@ -2,6 +2,7 @@ package contractcourt
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"testing" "testing"
@ -145,17 +146,15 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) {
// With the HTLC added, we'll now manually initiate a state transition // With the HTLC added, we'll now manually initiate a state transition
// from Alice to Bob. // from Alice to Bob.
_, err = aliceChannel.SignNextCommitment() testQuit, testQuitFunc := context.WithCancel(context.Background())
if err != nil { t.Cleanup(testQuitFunc)
t.Fatal(err) _, err = aliceChannel.SignNextCommitment(testQuit)
} require.NoError(t, err)
// At this point, we'll now Bob broadcasting this new pending unrevoked // At this point, we'll now Bob broadcasting this new pending unrevoked
// commitment. // commitment.
bobPendingCommit, err := aliceChannel.State().RemoteCommitChainTip() bobPendingCommit, err := aliceChannel.State().RemoteCommitChainTip()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// We'll craft a fake spend notification with Bob's actual commitment. // We'll craft a fake spend notification with Bob's actual commitment.
// The chain watcher should be able to detect that this is a pending // The chain watcher should be able to detect that this is a pending
@ -505,14 +504,24 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
// outputs. // outputs.
select { select {
case summary := <-chanEvents.LocalUnilateralClosure: case summary := <-chanEvents.LocalUnilateralClosure:
resOpt := summary.LocalForceCloseSummary.
ContractResolutions
resolutions, err := resOpt.UnwrapOrErr(
fmt.Errorf("resolutions not found"),
)
if err != nil {
t.Fatalf("unable to get resolutions: %v", err)
}
// Make sure we correctly extracted the commit // Make sure we correctly extracted the commit
// resolution if we had a local output. // resolution if we had a local output.
if remoteOutputOnly { if remoteOutputOnly {
if summary.CommitResolution != nil { if resolutions.CommitResolution != nil {
t.Fatalf("expected no commit resolution") t.Fatalf("expected no commit resolution")
} }
} else { } else {
if summary.CommitResolution == nil { if resolutions.CommitResolution == nil {
t.Fatalf("expected commit resolution") t.Fatalf("expected commit resolution")
} }
} }

View file

@ -98,7 +98,7 @@ type ArbChannel interface {
// corresponding link, such that we won't accept any new updates. The // corresponding link, such that we won't accept any new updates. The
// returned summary contains all items needed to eventually resolve all // returned summary contains all items needed to eventually resolve all
// outputs on chain. // outputs on chain.
ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) ForceCloseChan() (*wire.MsgTx, error)
// NewAnchorResolutions returns the anchor resolutions for currently // NewAnchorResolutions returns the anchor resolutions for currently
// valid commitment transactions. // valid commitment transactions.
@ -1058,7 +1058,7 @@ func (c *ChannelArbitrator) stateStep(
// We'll tell the switch that it should remove the link for // We'll tell the switch that it should remove the link for
// this channel, in addition to fetching the force close // this channel, in addition to fetching the force close
// summary needed to close this channel on chain. // summary needed to close this channel on chain.
closeSummary, err := c.cfg.Channel.ForceCloseChan() forceCloseTx, err := c.cfg.Channel.ForceCloseChan()
if err != nil { if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+ log.Errorf("ChannelArbitrator(%v): unable to "+
"force close: %v", c.cfg.ChanPoint, err) "force close: %v", c.cfg.ChanPoint, err)
@ -1078,7 +1078,7 @@ func (c *ChannelArbitrator) stateStep(
return StateError, closeTx, err return StateError, closeTx, err
} }
closeTx = closeSummary.CloseTx closeTx = forceCloseTx
// Before publishing the transaction, we store it to the // Before publishing the transaction, we store it to the
// database, such that we can re-publish later in case it // database, such that we can re-publish later in case it
@ -1982,9 +1982,11 @@ func (c *ChannelArbitrator) isPreimageAvailable(hash lntypes.Hash) (bool,
// have the incoming contest resolver decide that we don't want to // have the incoming contest resolver decide that we don't want to
// settle this invoice. // settle this invoice.
invoice, err := c.cfg.Registry.LookupInvoice(context.Background(), hash) invoice, err := c.cfg.Registry.LookupInvoice(context.Background(), hash)
switch err { switch {
case nil: case err == nil:
case invoices.ErrInvoiceNotFound, invoices.ErrNoInvoicesCreated: case errors.Is(err, invoices.ErrInvoiceNotFound) ||
errors.Is(err, invoices.ErrNoInvoicesCreated):
return false, nil return false, nil
default: default:
return false, err return false, err
@ -2869,11 +2871,36 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
} }
closeTx := closeInfo.CloseTx closeTx := closeInfo.CloseTx
resolutions, err := closeInfo.ContractResolutions.
UnwrapOrErr(
fmt.Errorf("resolutions not found"),
)
if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+
"get resolutions: %v", c.cfg.ChanPoint,
err)
return
}
// We make sure that the htlc resolutions are present; otherwise we
// would panic when dereferencing the pointer.
//
// TODO(ziggie): Refactor ContractResolutions to use
// options.
if resolutions.HtlcResolutions == nil {
log.Errorf("ChannelArbitrator(%v): htlc "+
"resolutions not found",
c.cfg.ChanPoint)
return
}
contractRes := &ContractResolutions{ contractRes := &ContractResolutions{
CommitHash: closeTx.TxHash(), CommitHash: closeTx.TxHash(),
CommitResolution: closeInfo.CommitResolution, CommitResolution: resolutions.CommitResolution,
HtlcResolutions: *closeInfo.HtlcResolutions, HtlcResolutions: *resolutions.HtlcResolutions,
AnchorResolution: closeInfo.AnchorResolution, AnchorResolution: resolutions.AnchorResolution,
} }
// When processing a unilateral close event, we'll // When processing a unilateral close event, we'll
@ -2882,7 +2909,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
// available to fetch in that state, we'll also write // available to fetch in that state, we'll also write
// the commit set so we can reconstruct our chain // the commit set so we can reconstruct our chain
// actions on restart. // actions on restart.
err := c.log.LogContractResolutions(contractRes) err = c.log.LogContractResolutions(contractRes)
if err != nil { if err != nil {
log.Errorf("Unable to write resolutions: %v", log.Errorf("Unable to write resolutions: %v",
err) err)
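The channel attendant above has to unwrap the now-optional contract resolutions and guard the HTLC resolutions pointer before using them. A small sketch of that unwrap-or-error step, with a local Option type standing in for fn.Option and a hypothetical contractResolutions struct:

package main

import (
	"errors"
	"fmt"
)

type Option[A any] struct {
	value A
	some  bool
}

func Some[A any](a A) Option[A] { return Option[A]{value: a, some: true} }
func None[A any]() Option[A]    { return Option[A]{} }

// UnwrapOrErr mimics the method used in the diff: return the inner value, or
// the supplied error when the option is empty.
func (o Option[A]) UnwrapOrErr(err error) (A, error) {
	if !o.some {
		var zero A
		return zero, err
	}
	return o.value, nil
}

// contractResolutions is a hypothetical stand-in for the lnwallet struct; the
// HTLC resolutions pointer may legitimately be nil, which the caller must
// also guard against before dereferencing.
type contractResolutions struct {
	htlcResolutions *struct{}
}

func handleLocalClose(res Option[contractResolutions]) error {
	resolutions, err := res.UnwrapOrErr(errors.New("resolutions not found"))
	if err != nil {
		return err
	}
	if resolutions.htlcResolutions == nil {
		return errors.New("htlc resolutions not found")
	}
	return nil
}

func main() {
	fmt.Println(handleLocalClose(None[contractResolutions]()))
	fmt.Println(handleLocalClose(Some(contractResolutions{
		htlcResolutions: &struct{}{},
	})))
}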

View file

@ -693,11 +693,15 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
chanArbCtx.AssertState(StateCommitmentBroadcasted) chanArbCtx.AssertState(StateCommitmentBroadcasted)
// Now notify about the local force close getting confirmed. // Now notify about the local force close getting confirmed.
//
//nolint:lll
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{}, SpendDetail: &chainntnfs.SpendDetail{},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: &wire.MsgTx{}, CloseTx: &wire.MsgTx{},
HtlcResolutions: &lnwallet.HtlcResolutions{}, ContractResolutions: fn.Some(lnwallet.ContractResolutions{
HtlcResolutions: &lnwallet.HtlcResolutions{},
}),
}, },
ChannelCloseSummary: &channeldb.ChannelCloseSummary{}, ChannelCloseSummary: &channeldb.ChannelCloseSummary{},
} }
@ -969,15 +973,18 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
}, },
} }
//nolint:lll
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{}, SpendDetail: &chainntnfs.SpendDetail{},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: closeTx, CloseTx: closeTx,
HtlcResolutions: &lnwallet.HtlcResolutions{ ContractResolutions: fn.Some(lnwallet.ContractResolutions{
OutgoingHTLCs: []lnwallet.OutgoingHtlcResolution{ HtlcResolutions: &lnwallet.HtlcResolutions{
outgoingRes, OutgoingHTLCs: []lnwallet.OutgoingHtlcResolution{
outgoingRes,
},
}, },
}, }),
}, },
ChannelCloseSummary: &channeldb.ChannelCloseSummary{}, ChannelCloseSummary: &channeldb.ChannelCloseSummary{},
CommitSet: CommitSet{ CommitSet: CommitSet{
@ -1611,12 +1618,15 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
}, },
{ {
closeType: channeldb.LocalForceClose, closeType: channeldb.LocalForceClose,
//nolint:lll
sendEvent: func(chanArb *ChannelArbitrator) { sendEvent: func(chanArb *ChannelArbitrator) {
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{}, SpendDetail: &chainntnfs.SpendDetail{},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: &wire.MsgTx{}, CloseTx: &wire.MsgTx{},
HtlcResolutions: &lnwallet.HtlcResolutions{}, ContractResolutions: fn.Some(lnwallet.ContractResolutions{
HtlcResolutions: &lnwallet.HtlcResolutions{},
}),
}, },
ChannelCloseSummary: &channeldb.ChannelCloseSummary{}, ChannelCloseSummary: &channeldb.ChannelCloseSummary{},
} }
@ -1944,11 +1954,15 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// being cancelled back. Also note that there are no HTLC // being cancelled back. Also note that there are no HTLC
// resolutions sent since we have none on our // resolutions sent since we have none on our
// commitment transaction. // commitment transaction.
//
//nolint:lll
uniCloseInfo := &LocalUnilateralCloseInfo{ uniCloseInfo := &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{}, SpendDetail: &chainntnfs.SpendDetail{},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: closeTx, CloseTx: closeTx,
HtlcResolutions: &lnwallet.HtlcResolutions{}, ContractResolutions: fn.Some(lnwallet.ContractResolutions{
HtlcResolutions: &lnwallet.HtlcResolutions{},
}),
}, },
ChannelCloseSummary: &channeldb.ChannelCloseSummary{}, ChannelCloseSummary: &channeldb.ChannelCloseSummary{},
CommitSet: CommitSet{ CommitSet: CommitSet{
@ -2754,12 +2768,15 @@ func TestChannelArbitratorAnchors(t *testing.T) {
}, },
} }
//nolint:lll
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{}, SpendDetail: &chainntnfs.SpendDetail{},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: closeTx, CloseTx: closeTx,
HtlcResolutions: &lnwallet.HtlcResolutions{}, ContractResolutions: fn.Some(lnwallet.ContractResolutions{
AnchorResolution: anchorResolution, HtlcResolutions: &lnwallet.HtlcResolutions{},
AnchorResolution: anchorResolution,
}),
}, },
ChannelCloseSummary: &channeldb.ChannelCloseSummary{}, ChannelCloseSummary: &channeldb.ChannelCloseSummary{},
CommitSet: CommitSet{ CommitSet: CommitSet{
@ -2993,14 +3010,10 @@ func (m *mockChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
return &lnwallet.AnchorResolutions{}, nil return &lnwallet.AnchorResolutions{}, nil
} }
func (m *mockChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) { func (m *mockChannel) ForceCloseChan() (*wire.MsgTx, error) {
if m.forceCloseErr != nil { if m.forceCloseErr != nil {
return nil, m.forceCloseErr return nil, m.forceCloseErr
} }
summary := &lnwallet.LocalForceCloseSummary{ return &wire.MsgTx{}, nil
CloseTx: &wire.MsgTx{},
HtlcResolutions: &lnwallet.HtlcResolutions{},
}
return summary, nil
} }

View file

@ -345,12 +345,18 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) {
&c.commitResolution.SelfOutputSignDesc, &c.commitResolution.SelfOutputSignDesc,
c.broadcastHeight, c.commitResolution.MaturityDelay, c.broadcastHeight, c.commitResolution.MaturityDelay,
c.leaseExpiry, c.leaseExpiry,
input.WithResolutionBlob(
c.commitResolution.ResolutionBlob,
),
) )
} else { } else {
inp = input.NewCsvInput( inp = input.NewCsvInput(
&c.commitResolution.SelfOutPoint, witnessType, &c.commitResolution.SelfOutPoint, witnessType,
&c.commitResolution.SelfOutputSignDesc, &c.commitResolution.SelfOutputSignDesc,
c.broadcastHeight, c.commitResolution.MaturityDelay, c.broadcastHeight, c.commitResolution.MaturityDelay,
input.WithResolutionBlob(
c.commitResolution.ResolutionBlob,
),
) )
} }

View file

@ -308,7 +308,7 @@ func (h *htlcIncomingContestResolver) Resolve(
resolution, err := h.Registry.NotifyExitHopHtlc( resolution, err := h.Registry.NotifyExitHopHtlc(
h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, currentHeight, h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, currentHeight,
circuitKey, hodlQueue.ChanIn(), payload, circuitKey, hodlQueue.ChanIn(), nil, payload,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -6,7 +6,9 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/tlv"
) )
// htlcLeaseResolver is a struct that houses the lease specific HTLC resolution // htlcLeaseResolver is a struct that houses the lease specific HTLC resolution
@ -52,8 +54,8 @@ func (h *htlcLeaseResolver) deriveWaitHeight(csvDelay uint32,
// send to the sweeper so the output can ultimately be swept. // send to the sweeper so the output can ultimately be swept.
func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint, func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint,
wType, cltvWtype input.StandardWitnessType, wType, cltvWtype input.StandardWitnessType,
signDesc *input.SignDescriptor, signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32,
csvDelay, broadcastHeight uint32, payHash [32]byte) *input.BaseInput { payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput {
if h.hasCLTV() { if h.hasCLTV() {
log.Infof("%T(%x): CSV and CLTV locks expired, offering "+ log.Infof("%T(%x): CSV and CLTV locks expired, offering "+
@ -63,13 +65,17 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint,
op, cltvWtype, signDesc, op, cltvWtype, signDesc,
broadcastHeight, csvDelay, broadcastHeight, csvDelay,
h.leaseExpiry, h.leaseExpiry,
input.WithResolutionBlob(resBlob),
) )
} }
log.Infof("%T(%x): CSV lock expired, offering second-layer output to "+ log.Infof("%T(%x): CSV lock expired, offering second-layer output to "+
"sweeper: %v", h, payHash, op) "sweeper: %v", h, payHash, op)
return input.NewCsvInput(op, wType, signDesc, broadcastHeight, csvDelay) return input.NewCsvInput(
op, wType, signDesc, broadcastHeight, csvDelay,
input.WithResolutionBlob(resBlob),
)
} }
// SupplementState allows the user of a ContractResolver to supplement it with // SupplementState allows the user of a ContractResolver to supplement it with

View file

@ -247,6 +247,9 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) (
h.htlcResolution.SignedSuccessTx, h.htlcResolution.SignedSuccessTx,
h.htlcResolution.SignDetails, h.htlcResolution.Preimage, h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
h.broadcastHeight, h.broadcastHeight,
input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
) )
} else { } else {
//nolint:lll //nolint:lll
@ -403,7 +406,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) (
input.LeaseHtlcAcceptedSuccessSecondLevel, input.LeaseHtlcAcceptedSuccessSecondLevel,
&h.htlcResolution.SweepSignDesc, &h.htlcResolution.SweepSignDesc,
h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
h.htlc.RHash, h.htlc.RHash, h.htlcResolution.ResolutionBlob,
) )
// Calculate the budget for this sweep. // Calculate the budget for this sweep.
@ -459,6 +462,9 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) (
h.htlcResolution.Preimage[:], h.htlcResolution.Preimage[:],
h.broadcastHeight, h.broadcastHeight,
h.htlcResolution.CsvDelay, h.htlcResolution.CsvDelay,
input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
)) ))
} else { } else {
inp = lnutils.Ptr(input.MakeHtlcSucceedInput( inp = lnutils.Ptr(input.MakeHtlcSucceedInput(

View file

@ -484,6 +484,9 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error {
h.htlcResolution.SignedTimeoutTx, h.htlcResolution.SignedTimeoutTx,
h.htlcResolution.SignDetails, h.htlcResolution.SignDetails,
h.broadcastHeight, h.broadcastHeight,
input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
)) ))
} else { } else {
inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutAnchorInput( inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutAnchorInput(
@ -538,7 +541,6 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error {
return err return err
} }
// TODO(yy): checkpoint here?
return err return err
} }
@ -562,6 +564,60 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error {
return h.Checkpoint(h) return h.Checkpoint(h)
} }
// sweepDirectHtlcOutput sends the direct spend of the HTLC output to the
// sweeper. This is used when the remote party goes on chain, and we're able to
// sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs
// are resolved via this path.
func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error {
var htlcWitnessType input.StandardWitnessType
if h.isTaproot() {
htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout
} else {
htlcWitnessType = input.HtlcOfferedRemoteTimeout
}
sweepInput := input.NewCsvInputWithCltv(
&h.htlcResolution.ClaimOutpoint, htlcWitnessType,
&h.htlcResolution.SweepSignDesc, h.broadcastHeight,
h.htlcResolution.CsvDelay, h.htlcResolution.Expiry,
input.WithResolutionBlob(h.htlcResolution.ResolutionBlob),
)
// Calculate the budget.
//
// TODO(yy): the budget is twice the output's value, which is needed as
// we don't force sweep the output now. To prevent cascading force
// closes, we use all its output value plus a wallet input as the
// budget. This is a temporary solution until we can optionally cancel
// the incoming HTLC, more details in,
// - https://github.com/lightningnetwork/lnd/issues/7969
budget := calculateBudget(
btcutil.Amount(sweepInput.SignDesc().Output.Value), 2, 0,
)
log.Infof("%T(%x): offering offered remote timeout HTLC output to "+
"sweeper with deadline %v and budget=%v at height=%v",
h, h.htlc.RHash[:], h.incomingHTLCExpiryHeight, budget,
h.broadcastHeight)
_, err := h.Sweeper.SweepInput(
sweepInput,
sweep.Params{
Budget: budget,
// This is an outgoing HTLC, so we want to make sure
// that we sweep it before the incoming HTLC expires.
DeadlineHeight: h.incomingHTLCExpiryHeight,
Immediate: immediate,
},
)
if err != nil {
return err
}
return nil
}
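// calculateBudgetSketch is an illustrative, self-contained sketch of the
// ratio-based budget heuristic used above. It is an assumed reading of
// calculateBudget's semantics rather than the package's actual helper:
// the budget is the output's value scaled by a ratio, optionally capped
// at a maximum (a zero max means "no cap").
func calculateBudgetSketch(value btcutil.Amount, ratio float64,
	max btcutil.Amount) btcutil.Amount {

	// Scale the swept output's value by the requested ratio.
	budget := value.MulF64(ratio)

	// Clamp the budget if a non-zero maximum was supplied.
	if max != 0 && budget > max {
		budget = max
	}

	return budget
}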
// spendHtlcOutput handles the initial spend of an HTLC output via the timeout // spendHtlcOutput handles the initial spend of an HTLC output via the timeout
// clause. If this is our local commitment, the second-level timeout TX will be // clause. If this is our local commitment, the second-level timeout TX will be
// used to spend the output into the next stage. If this is the remote // used to spend the output into the next stage. If this is the remote
@ -582,8 +638,18 @@ func (h *htlcTimeoutResolver) spendHtlcOutput(
return nil, err return nil, err
} }
// If we have no SignDetails, and we haven't already sent the output to // If this is a remote commitment there's no second level timeout txn,
// the utxo nursery, then we'll do so now. // and we can just send this directly to the sweeper.
case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating:
if err := h.sweepDirectHtlcOutput(immediate); err != nil {
log.Errorf("Sending direct spend to sweeper: %v", err)
return nil, err
}
// If we have a SignedTimeoutTx but no SignDetails, this is a local
// commitment for a non-anchor channel, so we'll send it to the utxo
// nursery.
case h.htlcResolution.SignDetails == nil && !h.outputIncubating: case h.htlcResolution.SignDetails == nil && !h.outputIncubating:
if err := h.sendSecondLevelTxLegacy(); err != nil { if err := h.sendSecondLevelTxLegacy(); err != nil {
log.Errorf("Sending timeout tx to nursery: %v", err) log.Errorf("Sending timeout tx to nursery: %v", err)
@ -690,6 +756,13 @@ func (h *htlcTimeoutResolver) handleCommitSpend(
) )
switch { switch {
// If we swept an HTLC directly off the remote party's commitment
// transaction, then we can exit here as there's no second level sweep
// to do.
case h.htlcResolution.SignedTimeoutTx == nil:
break
// If the sweeper is handling the second level transaction, wait for // If the sweeper is handling the second level transaction, wait for
// the CSV and possible CLTV lock to expire, before sweeping the output // the CSV and possible CLTV lock to expire, before sweeping the output
// on the second-level. // on the second-level.
@ -762,7 +835,9 @@ func (h *htlcTimeoutResolver) handleCommitSpend(
&h.htlcResolution.SweepSignDesc, &h.htlcResolution.SweepSignDesc,
h.htlcResolution.CsvDelay, h.htlcResolution.CsvDelay,
uint32(commitSpend.SpendingHeight), h.htlc.RHash, uint32(commitSpend.SpendingHeight), h.htlc.RHash,
h.htlcResolution.ResolutionBlob,
) )
// Calculate the budget for this sweep. // Calculate the budget for this sweep.
budget := calculateBudget( budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value), btcutil.Amount(inp.SignDesc().Output.Value),
@ -800,6 +875,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend(
case h.htlcResolution.SignedTimeoutTx != nil: case h.htlcResolution.SignedTimeoutTx != nil:
log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+ log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+
"delayed output", h, claimOutpoint) "delayed output", h, claimOutpoint)
sweepTx, err := waitForSpend( sweepTx, err := waitForSpend(
&claimOutpoint, &claimOutpoint,
h.htlcResolution.SweepSignDesc.Output.PkScript, h.htlcResolution.SweepSignDesc.Output.PkScript,
@ -866,9 +942,11 @@ func (h *htlcTimeoutResolver) IsResolved() bool {
// report returns a report on the resolution state of the contract. // report returns a report on the resolution state of the contract.
func (h *htlcTimeoutResolver) report() *ContractReport { func (h *htlcTimeoutResolver) report() *ContractReport {
// If the sign details are nil, the report will be created by handled // If we have a SignedTimeoutTx but no SignDetails, this is a local
// by the nursery. // commitment for a non-anchor channel, which was handled by the utxo
if h.htlcResolution.SignDetails == nil { // nursery.
if h.htlcResolution.SignDetails == nil && h.
htlcResolution.SignedTimeoutTx != nil {
return nil return nil
} }
@ -888,13 +966,20 @@ func (h *htlcTimeoutResolver) initReport() {
) )
} }
// If there's no timeout transaction, then we're already effectively in
// level two.
stage := uint32(1)
if h.htlcResolution.SignedTimeoutTx == nil {
stage = 2
}
h.currentReport = ContractReport{ h.currentReport = ContractReport{
Outpoint: h.htlcResolution.ClaimOutpoint, Outpoint: h.htlcResolution.ClaimOutpoint,
Type: ReportOutputOutgoingHtlc, Type: ReportOutputOutgoingHtlc,
Amount: finalAmt, Amount: finalAmt,
MaturityHeight: h.htlcResolution.Expiry, MaturityHeight: h.htlcResolution.Expiry,
LimboBalance: finalAmt, LimboBalance: finalAmt,
Stage: 1, Stage: stage,
} }
} }

View file

@ -69,11 +69,31 @@ func (m *mockWitnessBeacon) AddPreimages(preimages ...lntypes.Preimage) error {
return nil return nil
} }
// TestHtlcTimeoutResolver tests that the timeout resolver properly handles all type htlcTimeoutTestCase struct {
// variations of possible local+remote spends. // name is a human readable description of the test case.
func TestHtlcTimeoutResolver(t *testing.T) { name string
t.Parallel()
// remoteCommit denotes if the commitment broadcast was the remote
// commitment or not.
remoteCommit bool
// timeout denotes if the HTLC should be let timeout, or if the "remote"
// party should sweep it on-chain. This also affects what type of
// resolution message we expect.
timeout bool
// txToBroadcast is a function closure that should generate the
// transaction that should spend the HTLC output. Test authors can use
// this to customize the witness used when spending to trigger various
// redemption cases.
txToBroadcast func() (*wire.MsgTx, error)
// outcome is the resolver outcome that we expect to be reported once
// the contract is fully resolved.
outcome channeldb.ResolverOutcome
}
func genHtlcTimeoutTestCases() []htlcTimeoutTestCase {
fakePreimageBytes := bytes.Repeat([]byte{1}, lntypes.HashSize) fakePreimageBytes := bytes.Repeat([]byte{1}, lntypes.HashSize)
var ( var (
@ -105,29 +125,7 @@ func TestHtlcTimeoutResolver(t *testing.T) {
}, },
} }
testCases := []struct { return []htlcTimeoutTestCase{
// name is a human readable description of the test case.
name string
// remoteCommit denotes if the commitment broadcast was the
// remote commitment or not.
remoteCommit bool
// timeout denotes if the HTLC should be let timeout, or if the
// "remote" party should sweep it on-chain. This also affects
// what type of resolution message we expect.
timeout bool
// txToBroadcast is a function closure that should generate the
// transaction that should spend the HTLC output. Test authors
// can use this to customize the witness used when spending to
// trigger various redemption cases.
txToBroadcast func() (*wire.MsgTx, error)
// outcome is the resolver outcome that we expect to be reported
// once the contract is fully resolved.
outcome channeldb.ResolverOutcome
}{
// Remote commitment is broadcast, we time out the HTLC on // Remote commitment is broadcast, we time out the HTLC on
// chain, and should expect a fail HTLC resolution. // chain, and should expect a fail HTLC resolution.
{ {
@ -149,7 +147,8 @@ func TestHtlcTimeoutResolver(t *testing.T) {
// immediately if the witness is already set // immediately if the witness is already set
// correctly. // correctly.
if reflect.DeepEqual( if reflect.DeepEqual(
templateTx.TxIn[0].Witness, witness, templateTx.TxIn[0].Witness,
witness,
) { ) {
return templateTx, nil return templateTx, nil
@ -219,7 +218,8 @@ func TestHtlcTimeoutResolver(t *testing.T) {
// immediately if the witness is already set // immediately if the witness is already set
// correctly. // correctly.
if reflect.DeepEqual( if reflect.DeepEqual(
templateTx.TxIn[0].Witness, witness, templateTx.TxIn[0].Witness,
witness,
) { ) {
return templateTx, nil return templateTx, nil
@ -253,7 +253,8 @@ func TestHtlcTimeoutResolver(t *testing.T) {
// immediately if the witness is already set // immediately if the witness is already set
// correctly. // correctly.
if reflect.DeepEqual( if reflect.DeepEqual(
templateTx.TxIn[0].Witness, witness, templateTx.TxIn[0].Witness,
witness,
) { ) {
return templateTx, nil return templateTx, nil
@ -265,243 +266,280 @@ func TestHtlcTimeoutResolver(t *testing.T) {
outcome: channeldb.ResolverOutcomeClaimed, outcome: channeldb.ResolverOutcomeClaimed,
}, },
} }
}
func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
fakePreimageBytes := bytes.Repeat([]byte{1}, lntypes.HashSize)
var fakePreimage lntypes.Preimage
fakeSignDesc := &input.SignDescriptor{
Output: &wire.TxOut{},
}
copy(fakePreimage[:], fakePreimageBytes)
notifier := &mock.ChainNotifier{ notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch), EpochChan: make(chan *chainntnfs.BlockEpoch),
SpendChan: make(chan *chainntnfs.SpendDetail), SpendChan: make(chan *chainntnfs.SpendDetail),
ConfChan: make(chan *chainntnfs.TxConfirmation), ConfChan: make(chan *chainntnfs.TxConfirmation),
} }
witnessBeacon := newMockWitnessBeacon() witnessBeacon := newMockWitnessBeacon()
checkPointChan := make(chan struct{}, 1)
incubateChan := make(chan struct{}, 1)
resolutionChan := make(chan ResolutionMsg, 1)
reportChan := make(chan *channeldb.ResolverReport)
//nolint:lll
chainCfg := ChannelArbitratorConfig{
ChainArbitratorConfig: ChainArbitratorConfig{
Notifier: notifier,
Sweeper: newMockSweeper(),
PreimageDB: witnessBeacon,
IncubateOutputs: func(wire.OutPoint,
fn.Option[lnwallet.OutgoingHtlcResolution],
fn.Option[lnwallet.IncomingHtlcResolution],
uint32, fn.Option[int32]) error {
incubateChan <- struct{}{}
return nil
},
DeliverResolutionMsg: func(msgs ...ResolutionMsg) error {
if len(msgs) != 1 {
return fmt.Errorf("expected 1 "+
"resolution msg, instead got %v",
len(msgs))
}
resolutionChan <- msgs[0]
return nil
},
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(circuit models.CircuitKey,
) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {
return nil
},
}
cfg := ResolverConfig{
ChannelArbitratorConfig: chainCfg,
Checkpoint: func(_ ContractResolver,
reports ...*channeldb.ResolverReport) error {
checkPointChan <- struct{}{}
// Send all of our reports into the channel.
for _, report := range reports {
reportChan <- report
}
return nil
},
}
resolver := &htlcTimeoutResolver{
htlcResolution: lnwallet.OutgoingHtlcResolution{
ClaimOutpoint: testChanPoint2,
SweepSignDesc: *fakeSignDesc,
},
contractResolverKit: *newContractResolverKit(
cfg,
),
htlc: channeldb.HTLC{
Amt: testHtlcAmt,
},
}
var reports []*channeldb.ResolverReport
// If the test case needs the remote commitment to be
// broadcast, then we'll set the timeout commit to a fake
// transaction to force the code path.
if !testCase.remoteCommit {
timeoutTx, err := testCase.txToBroadcast()
require.NoError(t, err)
resolver.htlcResolution.SignedTimeoutTx = timeoutTx
if testCase.timeout {
timeoutTxID := timeoutTx.TxHash()
report := &channeldb.ResolverReport{
OutPoint: timeoutTx.TxIn[0].PreviousOutPoint, //nolint:lll
Amount: testHtlcAmt.ToSatoshis(),
ResolverType: channeldb.ResolverTypeOutgoingHtlc, //nolint:lll
ResolverOutcome: channeldb.ResolverOutcomeFirstStage, //nolint:lll
SpendTxID: &timeoutTxID,
}
reports = append(reports, report)
}
}
// With all the setup above complete, we can initiate the
// resolution process, and the bulk of our test.
var wg sync.WaitGroup
resolveErr := make(chan error, 1)
wg.Add(1)
go func() {
defer wg.Done()
_, err := resolver.Resolve(false)
if err != nil {
resolveErr <- err
}
}()
// If this is a remote commit, then we expect the output to be offered
// to the sweeper, otherwise it should be sent to the nursery for
// incubation.
var sweepChan chan input.Input
if testCase.remoteCommit {
mockSweeper, ok := resolver.Sweeper.(*mockSweeper)
require.True(t, ok)
sweepChan = mockSweeper.sweptInputs
}
// The output should be offered to either the sweeper or
// the nursery.
select {
case <-incubateChan:
case <-sweepChan:
case err := <-resolveErr:
t.Fatalf("unable to resolve HTLC: %v", err)
case <-time.After(time.Second * 5):
t.Fatalf("failed to receive incubation request")
}
// Next, the resolver should request a spend notification for
// the direct HTLC output. We'll use the txToBroadcast closure
// for the test case to generate the transaction that we'll
// send to the resolver.
spendingTx, err := testCase.txToBroadcast()
if err != nil {
t.Fatalf("unable to generate tx: %v", err)
}
spendTxHash := spendingTx.TxHash()
select {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
}
if !testCase.timeout {
// If the resolver should settle now, then we'll
// expect the pre-image to be extracted and the
// resolution message to be sent.
select {
case newPreimage := <-witnessBeacon.newPreimages:
if newPreimage[0] != fakePreimage {
t.Fatalf("wrong pre-image: "+
"expected %v, got %v",
fakePreimage, newPreimage)
}
case <-time.After(time.Second * 5):
t.Fatalf("pre-image not added")
}
// Finally, we should get a resolution message with the
// pre-image set within the message.
select {
case resolutionMsg := <-resolutionChan:
// Once again, the pre-images should match up.
if *resolutionMsg.PreImage != fakePreimage {
t.Fatalf("wrong pre-image: "+
"expected %v, got %v",
fakePreimage, resolutionMsg.PreImage)
}
case <-time.After(time.Second * 5):
t.Fatalf("resolution not sent")
}
} else {
// Otherwise, the HTLC should now timeout. First, we
// should get a resolution message with a populated
// failure message.
select {
case resolutionMsg := <-resolutionChan:
if resolutionMsg.Failure == nil {
t.Fatalf("expected failure resolution msg")
}
case <-time.After(time.Second * 5):
t.Fatalf("resolution not sent")
}
// We should also get another request for the spend
// notification of the second-level transaction to
// indicate that it's been swept by the nursery, but
// only if this is a local commitment transaction.
if !testCase.remoteCommit {
select {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
}
}
}
// In any case, before the resolver exits, it should checkpoint
// its final state.
select {
case <-checkPointChan:
case err := <-resolveErr:
t.Fatalf("unable to resolve HTLC: %v", err)
case <-time.After(time.Second * 5):
t.Fatalf("check point not received")
}
// Add a report to our set of expected reports with the outcome
// that the test specifies (either success or timeout).
spendTxID := spendingTx.TxHash()
amt := btcutil.Amount(fakeSignDesc.Output.Value)
reports = append(reports, &channeldb.ResolverReport{
OutPoint: testChanPoint2,
Amount: amt,
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: testCase.outcome,
SpendTxID: &spendTxID,
})
for _, report := range reports {
assertResolverReport(t, reportChan, report)
}
wg.Wait()
// Finally, the resolver should be marked as resolved.
if !resolver.resolved {
t.Fatalf("resolver should be marked as resolved")
}
}
// TestHtlcTimeoutResolver tests that the timeout resolver properly handles all
// variations of possible local+remote spends.
func TestHtlcTimeoutResolver(t *testing.T) {
t.Parallel()
testCases := genHtlcTimeoutTestCases()
for _, testCase := range testCases { for _, testCase := range testCases {
t.Logf("Running test case: %v", testCase.name) t.Run(testCase.name, func(t *testing.T) {
testHtlcTimeoutResolver(t, testCase)
checkPointChan := make(chan struct{}, 1)
incubateChan := make(chan struct{}, 1)
resolutionChan := make(chan ResolutionMsg, 1)
reportChan := make(chan *channeldb.ResolverReport)
//nolint:lll
chainCfg := ChannelArbitratorConfig{
ChainArbitratorConfig: ChainArbitratorConfig{
Notifier: notifier,
PreimageDB: witnessBeacon,
IncubateOutputs: func(wire.OutPoint,
fn.Option[lnwallet.OutgoingHtlcResolution],
fn.Option[lnwallet.IncomingHtlcResolution],
uint32, fn.Option[int32]) error {
incubateChan <- struct{}{}
return nil
},
DeliverResolutionMsg: func(msgs ...ResolutionMsg) error {
if len(msgs) != 1 {
return fmt.Errorf("expected 1 "+
"resolution msg, instead got %v",
len(msgs))
}
resolutionChan <- msgs[0]
return nil
},
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {
return nil
},
}
cfg := ResolverConfig{
ChannelArbitratorConfig: chainCfg,
Checkpoint: func(_ ContractResolver,
reports ...*channeldb.ResolverReport) error {
checkPointChan <- struct{}{}
// Send all of our reports into the channel.
for _, report := range reports {
reportChan <- report
}
return nil
},
}
resolver := &htlcTimeoutResolver{
htlcResolution: lnwallet.OutgoingHtlcResolution{
ClaimOutpoint: testChanPoint2,
SweepSignDesc: *fakeSignDesc,
},
contractResolverKit: *newContractResolverKit(
cfg,
),
htlc: channeldb.HTLC{
Amt: testHtlcAmt,
},
}
var reports []*channeldb.ResolverReport
// If the test case needs the remote commitment to be
// broadcast, then we'll set the timeout commit to a fake
// transaction to force the code path.
if !testCase.remoteCommit {
timeoutTx, err := testCase.txToBroadcast()
require.NoError(t, err)
resolver.htlcResolution.SignedTimeoutTx = timeoutTx
if testCase.timeout {
timeoutTxID := timeoutTx.TxHash()
reports = append(reports, &channeldb.ResolverReport{
OutPoint: timeoutTx.TxIn[0].PreviousOutPoint,
Amount: testHtlcAmt.ToSatoshis(),
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
SpendTxID: &timeoutTxID,
})
}
}
// With all the setup above complete, we can initiate the
// resolution process, and the bulk of our test.
var wg sync.WaitGroup
resolveErr := make(chan error, 1)
wg.Add(1)
go func() {
defer wg.Done()
_, err := resolver.Resolve(false)
if err != nil {
resolveErr <- err
}
}()
// At the output isn't yet in the nursery, we expect that we
// should receive an incubation request.
select {
case <-incubateChan:
case err := <-resolveErr:
t.Fatalf("unable to resolve HTLC: %v", err)
case <-time.After(time.Second * 5):
t.Fatalf("failed to receive incubation request")
}
// Next, the resolver should request a spend notification for
// the direct HTLC output. We'll use the txToBroadcast closure
// for the test case to generate the transaction that we'll
// send to the resolver.
spendingTx, err := testCase.txToBroadcast()
if err != nil {
t.Fatalf("unable to generate tx: %v", err)
}
spendTxHash := spendingTx.TxHash()
select {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
}
if !testCase.timeout {
// If the resolver should settle now, then we'll
// extract the pre-image to be extracted and the
// resolution message sent.
select {
case newPreimage := <-witnessBeacon.newPreimages:
if newPreimage[0] != fakePreimage {
t.Fatalf("wrong pre-image: "+
"expected %v, got %v",
fakePreimage, newPreimage)
}
case <-time.After(time.Second * 5):
t.Fatalf("pre-image not added")
}
// Finally, we should get a resolution message with the
// pre-image set within the message.
select {
case resolutionMsg := <-resolutionChan:
// Once again, the pre-images should match up.
if *resolutionMsg.PreImage != fakePreimage {
t.Fatalf("wrong pre-image: "+
"expected %v, got %v",
fakePreimage, resolutionMsg.PreImage)
}
case <-time.After(time.Second * 5):
t.Fatalf("resolution not sent")
}
} else {
// Otherwise, the HTLC should now timeout. First, we
// should get a resolution message with a populated
// failure message.
select {
case resolutionMsg := <-resolutionChan:
if resolutionMsg.Failure == nil {
t.Fatalf("expected failure resolution msg")
}
case <-time.After(time.Second * 5):
t.Fatalf("resolution not sent")
}
// We should also get another request for the spend
// notification of the second-level transaction to
// indicate that it's been swept by the nursery, but
// only if this is a local commitment transaction.
if !testCase.remoteCommit {
select {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
}
}
}
// In any case, before the resolver exits, it should checkpoint
// its final state.
select {
case <-checkPointChan:
case err := <-resolveErr:
t.Fatalf("unable to resolve HTLC: %v", err)
case <-time.After(time.Second * 5):
t.Fatalf("check point not received")
}
// Add a report to our set of expected reports with the outcome
// that the test specifies (either success or timeout).
spendTxID := spendingTx.TxHash()
amt := btcutil.Amount(fakeSignDesc.Output.Value)
reports = append(reports, &channeldb.ResolverReport{
OutPoint: testChanPoint2,
Amount: amt,
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: testCase.outcome,
SpendTxID: &spendTxID,
}) })
for _, report := range reports {
assertResolverReport(t, reportChan, report)
}
wg.Wait()
// Finally, the resolver should be marked as resolved.
if !resolver.resolved {
t.Fatalf("resolver should be marked as resolved")
}
} }
} }
@ -536,15 +574,12 @@ func TestHtlcTimeoutSingleStage(t *testing.T) {
} }
checkpoints := []checkpoint{ checkpoints := []checkpoint{
{
// Output should be handed off to the nursery.
incubating: true,
},
{ {
// We send a confirmation for the sweep tx published // We send a confirmation for the sweep tx published
// by the nursery. // by the nursery.
preCheckpoint: func(ctx *htlcResolverTestContext, preCheckpoint: func(ctx *htlcResolverTestContext,
_ bool) error { _ bool) error {
// The nursery will create and publish a sweep // The nursery will create and publish a sweep
// tx. // tx.
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
@ -570,7 +605,7 @@ func TestHtlcTimeoutSingleStage(t *testing.T) {
// After the sweep has confirmed, we expect the // After the sweep has confirmed, we expect the
// checkpoint to be resolved, and with the above // checkpoint to be resolved, and with the above
// report. // report.
incubating: true, incubating: false,
resolved: true, resolved: true,
reports: []*channeldb.ResolverReport{ reports: []*channeldb.ResolverReport{
claim, claim,
@ -653,6 +688,7 @@ func TestHtlcTimeoutSecondStage(t *testing.T) {
// that our sweep succeeded. // that our sweep succeeded.
preCheckpoint: func(ctx *htlcResolverTestContext, preCheckpoint: func(ctx *htlcResolverTestContext,
_ bool) error { _ bool) error {
// The nursery will publish the timeout tx. // The nursery will publish the timeout tx.
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: timeoutTx, SpendingTx: timeoutTx,
@ -824,9 +860,9 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) {
) )
} }
// TestHtlcTimeoutSecondStageRemoteSpend tests that when a remite commitment // TestHtlcTimeoutSecondStageRemoteSpend tests that when a remote commitment
// confirms, and the remote spends the output using the success tx, we // confirms, and the remote spends the output using the success tx, we properly
// properly detect this and extract the preimage. // detect this and extract the preimage.
func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) { func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 2} commitOutpoint := wire.OutPoint{Index: 2}
@ -870,10 +906,6 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) {
} }
checkpoints := []checkpoint{ checkpoints := []checkpoint{
{
// Output should be handed off to the nursery.
incubating: true,
},
{ {
// We send a confirmation for the remote's second layer // We send a confirmation for the remote's second layer
// success transaction. // success transaction.
@ -919,7 +951,7 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) {
// After the sweep has confirmed, we expect the // After the sweep has confirmed, we expect the
// checkpoint to be resolved, and with the above // checkpoint to be resolved, and with the above
// report. // report.
incubating: true, incubating: false,
resolved: true, resolved: true,
reports: []*channeldb.ResolverReport{ reports: []*channeldb.ResolverReport{
claim, claim,
@ -1298,6 +1330,8 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) {
func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution,
checkpoints []checkpoint) { checkpoints []checkpoint) {
t.Helper()
defer timeout()() defer timeout()()
// We first run the resolver from start to finish, ensuring it gets // We first run the resolver from start to finish, ensuring it gets

View file

@ -30,6 +30,7 @@ type Registry interface {
NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi, NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi,
expiry uint32, currentHeight int32, expiry uint32, currentHeight int32,
circuitKey models.CircuitKey, hodlChan chan<- interface{}, circuitKey models.CircuitKey, hodlChan chan<- interface{},
wireCustomRecords lnwire.CustomRecords,
payload invoices.Payload) (invoices.HtlcResolution, error) payload invoices.Payload) (invoices.HtlcResolution, error)
// HodlUnsubscribeAll unsubscribes from all htlc resolutions. // HodlUnsubscribeAll unsubscribes from all htlc resolutions.

View file

@ -10,5 +10,6 @@ type mockHTLCNotifier struct {
} }
func (m *mockHTLCNotifier) NotifyFinalHtlcEvent(key models.CircuitKey, func (m *mockHTLCNotifier) NotifyFinalHtlcEvent(key models.CircuitKey,
info channeldb.FinalHtlcInfo) { //nolint:whitespace info channeldb.FinalHtlcInfo) {
} }

View file

@ -26,6 +26,7 @@ type mockRegistry struct {
func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash, func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash,
paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32, paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
circuitKey models.CircuitKey, hodlChan chan<- interface{}, circuitKey models.CircuitKey, hodlChan chan<- interface{},
wireCustomRecords lnwire.CustomRecords,
payload invoices.Payload) (invoices.HtlcResolution, error) { payload invoices.Payload) (invoices.HtlcResolution, error) {
r.notifyChan <- notifyExitHopData{ r.notifyChan <- notifyExitHopData{

View file

@ -8,9 +8,6 @@ import (
) )
const ( const (
taprootCtrlBlockType tlv.Type = 0
taprootTapTweakType tlv.Type = 1
commitCtrlBlockType tlv.Type = 0 commitCtrlBlockType tlv.Type = 0
revokeCtrlBlockType tlv.Type = 1 revokeCtrlBlockType tlv.Type = 1
outgoingHtlcCtrlBlockType tlv.Type = 2 outgoingHtlcCtrlBlockType tlv.Type = 2
@ -26,36 +23,67 @@ const (
// information we need to sweep taproot outputs. // information we need to sweep taproot outputs.
type taprootBriefcase struct { type taprootBriefcase struct {
// CtrlBlocks is the set of control blocks for the taproot outputs. // CtrlBlocks is the set of control blocks for the taproot outputs.
CtrlBlocks *ctrlBlocks CtrlBlocks tlv.RecordT[tlv.TlvType0, ctrlBlocks]
// TapTweaks is the set of taproot tweaks for the taproot outputs that // TapTweaks is the set of taproot tweaks for the taproot outputs that
// are to be spent via a keyspend path. This includes anchors, and any // are to be spent via a keyspend path. This includes anchors, and any
// revocation paths. // revocation paths.
TapTweaks *tapTweaks TapTweaks tlv.RecordT[tlv.TlvType1, tapTweaks]
// SettledCommitBlob is an optional record that contains an opaque blob
// that may be used to properly sweep commitment outputs on a force
// close transaction.
SettledCommitBlob tlv.OptionalRecordT[tlv.TlvType2, tlv.Blob]
// BreachedCommitBlob is an optional record that contains an opaque blob
// used to sweep a remote party's breached output.
BreachedCommitBlob tlv.OptionalRecordT[tlv.TlvType3, tlv.Blob]
// HtlcBlobs is an optional record that contains the opaque blobs for
// the set of active HTLCs on the commitment transaction.
HtlcBlobs tlv.OptionalRecordT[tlv.TlvType4, htlcAuxBlobs]
} }
// TODO(roasbeef): morph into new tlv record
// newTaprootBriefcase returns a new instance of the taproot specific briefcase // newTaprootBriefcase returns a new instance of the taproot specific briefcase
// variant. // variant.
func newTaprootBriefcase() *taprootBriefcase { func newTaprootBriefcase() *taprootBriefcase {
return &taprootBriefcase{ return &taprootBriefcase{
CtrlBlocks: newCtrlBlocks(), CtrlBlocks: tlv.NewRecordT[tlv.TlvType0](newCtrlBlocks()),
TapTweaks: newTapTweaks(), TapTweaks: tlv.NewRecordT[tlv.TlvType1](newTapTweaks()),
} }
} }
// EncodeRecords returns a slice of TLV records that should be encoded. // EncodeRecords returns a slice of TLV records that should be encoded.
func (t *taprootBriefcase) EncodeRecords() []tlv.Record { func (t *taprootBriefcase) EncodeRecords() []tlv.Record {
return []tlv.Record{ records := []tlv.Record{
newCtrlBlocksRecord(&t.CtrlBlocks), t.CtrlBlocks.Record(),
newTapTweaksRecord(&t.TapTweaks), t.TapTweaks.Record(),
} }
t.SettledCommitBlob.WhenSome(
func(r tlv.RecordT[tlv.TlvType2, tlv.Blob]) {
records = append(records, r.Record())
},
)
t.BreachedCommitBlob.WhenSome(
func(r tlv.RecordT[tlv.TlvType3, tlv.Blob]) {
records = append(records, r.Record())
},
)
t.HtlcBlobs.WhenSome(func(r tlv.RecordT[tlv.TlvType4, htlcAuxBlobs]) {
records = append(records, r.Record())
})
return records
} }
// DecodeRecords returns a slice of TLV records that should be decoded. // DecodeRecords returns a slice of TLV records that should be decoded.
func (t *taprootBriefcase) DecodeRecords() []tlv.Record { func (t *taprootBriefcase) DecodeRecords() []tlv.Record {
return []tlv.Record{ return []tlv.Record{
newCtrlBlocksRecord(&t.CtrlBlocks), t.CtrlBlocks.Record(),
newTapTweaksRecord(&t.TapTweaks), t.TapTweaks.Record(),
} }
} }
@ -71,12 +99,35 @@ func (t *taprootBriefcase) Encode(w io.Writer) error {
// Decode decodes the given reader into the target struct. // Decode decodes the given reader into the target struct.
func (t *taprootBriefcase) Decode(r io.Reader) error { func (t *taprootBriefcase) Decode(r io.Reader) error {
stream, err := tlv.NewStream(t.DecodeRecords()...) settledCommitBlob := t.SettledCommitBlob.Zero()
breachedCommitBlob := t.BreachedCommitBlob.Zero()
htlcBlobs := t.HtlcBlobs.Zero()
records := append(
t.DecodeRecords(), settledCommitBlob.Record(),
breachedCommitBlob.Record(), htlcBlobs.Record(),
)
stream, err := tlv.NewStream(records...)
if err != nil { if err != nil {
return err return err
} }
return stream.Decode(r) typeMap, err := stream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if val, ok := typeMap[t.SettledCommitBlob.TlvType()]; ok && val == nil {
t.SettledCommitBlob = tlv.SomeRecordT(settledCommitBlob)
}
if v, ok := typeMap[t.BreachedCommitBlob.TlvType()]; ok && v == nil {
t.BreachedCommitBlob = tlv.SomeRecordT(breachedCommitBlob)
}
if v, ok := typeMap[t.HtlcBlobs.TlvType()]; ok && v == nil {
t.HtlcBlobs = tlv.SomeRecordT(htlcBlobs)
}
return nil
} }
// resolverCtrlBlocks is a map of resolver IDs to their corresponding control // resolverCtrlBlocks is a map of resolver IDs to their corresponding control
@ -216,8 +267,8 @@ type ctrlBlocks struct {
} }
// newCtrlBlocks returns a new instance of the ctrlBlocks struct. // newCtrlBlocks returns a new instance of the ctrlBlocks struct.
func newCtrlBlocks() *ctrlBlocks { func newCtrlBlocks() ctrlBlocks {
return &ctrlBlocks{ return ctrlBlocks{
OutgoingHtlcCtrlBlocks: newResolverCtrlBlocks(), OutgoingHtlcCtrlBlocks: newResolverCtrlBlocks(),
IncomingHtlcCtrlBlocks: newResolverCtrlBlocks(), IncomingHtlcCtrlBlocks: newResolverCtrlBlocks(),
SecondLevelCtrlBlocks: newResolverCtrlBlocks(), SecondLevelCtrlBlocks: newResolverCtrlBlocks(),
@ -260,7 +311,7 @@ func varBytesDecoder(r io.Reader, val any, buf *[8]byte, l uint64) error {
// ctrlBlockEncoder is a custom TLV encoder for the ctrlBlocks struct. // ctrlBlockEncoder is a custom TLV encoder for the ctrlBlocks struct.
func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error { func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(**ctrlBlocks); ok { if t, ok := val.(*ctrlBlocks); ok {
return (*t).Encode(w) return (*t).Encode(w)
} }
@ -269,7 +320,7 @@ func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error {
// ctrlBlockDecoder is a custom TLV decoder for the ctrlBlocks struct. // ctrlBlockDecoder is a custom TLV decoder for the ctrlBlocks struct.
func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
if typ, ok := val.(**ctrlBlocks); ok { if typ, ok := val.(*ctrlBlocks); ok {
ctrlReader := io.LimitReader(r, int64(l)) ctrlReader := io.LimitReader(r, int64(l))
var ctrlBlocks ctrlBlocks var ctrlBlocks ctrlBlocks
@ -278,7 +329,7 @@ func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
return err return err
} }
*typ = &ctrlBlocks *typ = ctrlBlocks
return nil return nil
} }
@ -286,28 +337,6 @@ func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "ctrlBlocks", l, l) return tlv.NewTypeForDecodingErr(val, "ctrlBlocks", l, l)
} }
// newCtrlBlocksRecord returns a new TLV record that can be used to
// encode/decode the set of cotrol blocks for the taproot outputs for a
// channel.
func newCtrlBlocksRecord(blks **ctrlBlocks) tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := ctrlBlockEncoder(&b, blks, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
taprootCtrlBlockType, blks, recordSize, ctrlBlockEncoder,
ctrlBlockDecoder,
)
}
// EncodeRecords returns the set of TLV records that encode the control block // EncodeRecords returns the set of TLV records that encode the control block
// for the commitment transaction. // for the commitment transaction.
func (c *ctrlBlocks) EncodeRecords() []tlv.Record { func (c *ctrlBlocks) EncodeRecords() []tlv.Record {
@ -382,7 +411,21 @@ func (c *ctrlBlocks) DecodeRecords() []tlv.Record {
// Record returns a TLV record that can be used to encode/decode the control // Record returns a TLV record that can be used to encode/decode the control
// blocks type from a given TLV stream. // blocks type from a given TLV stream.
func (c *ctrlBlocks) Record() tlv.Record { func (c *ctrlBlocks) Record() tlv.Record {
return tlv.MakePrimitiveRecord(commitCtrlBlockType, c) recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := ctrlBlockEncoder(&b, c, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, c, recordSize, ctrlBlockEncoder, ctrlBlockDecoder,
)
} }
// Encode encodes the set of control blocks. // Encode encodes the set of control blocks.
@ -530,8 +573,8 @@ type tapTweaks struct {
} }
// newTapTweaks returns a new tapTweaks struct. // newTapTweaks returns a new tapTweaks struct.
func newTapTweaks() *tapTweaks { func newTapTweaks() tapTweaks {
return &tapTweaks{ return tapTweaks{
BreachedHtlcTweaks: make(htlcTapTweaks), BreachedHtlcTweaks: make(htlcTapTweaks),
BreachedSecondLevelHltcTweaks: make(htlcTapTweaks), BreachedSecondLevelHltcTweaks: make(htlcTapTweaks),
} }
@ -539,7 +582,7 @@ func newTapTweaks() *tapTweaks {
// tapTweaksEncoder is a custom TLV encoder for the tapTweaks struct. // tapTweaksEncoder is a custom TLV encoder for the tapTweaks struct.
func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error { func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(**tapTweaks); ok { if t, ok := val.(*tapTweaks); ok {
return (*t).Encode(w) return (*t).Encode(w)
} }
@ -548,7 +591,7 @@ func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error {
// tapTweaksDecoder is a custom TLV decoder for the tapTweaks struct. // tapTweaksDecoder is a custom TLV decoder for the tapTweaks struct.
func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
if typ, ok := val.(**tapTweaks); ok { if typ, ok := val.(*tapTweaks); ok {
tweakReader := io.LimitReader(r, int64(l)) tweakReader := io.LimitReader(r, int64(l))
var tapTweaks tapTweaks var tapTweaks tapTweaks
@ -557,7 +600,7 @@ func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
return err return err
} }
*typ = &tapTweaks *typ = tapTweaks
return nil return nil
} }
@ -565,27 +608,6 @@ func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "tapTweaks", l, l) return tlv.NewTypeForDecodingErr(val, "tapTweaks", l, l)
} }
// newTapTweaksRecord returns a new TLV record that can be used to
// encode/decode the tap tweak structs.
func newTapTweaksRecord(tweaks **tapTweaks) tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := tapTweaksEncoder(&b, tweaks, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
taprootTapTweakType, tweaks, recordSize, tapTweaksEncoder,
tapTweaksDecoder,
)
}
// EncodeRecords returns the set of TLV records that encode the tweaks. // EncodeRecords returns the set of TLV records that encode the tweaks.
func (t *tapTweaks) EncodeRecords() []tlv.Record { func (t *tapTweaks) EncodeRecords() []tlv.Record {
var records []tlv.Record var records []tlv.Record
@ -637,7 +659,21 @@ func (t *tapTweaks) DecodeRecords() []tlv.Record {
// Record returns a TLV record that can be used to encode/decode the tap // Record returns a TLV record that can be used to encode/decode the tap
// tweaks. // tweaks.
func (t *tapTweaks) Record() tlv.Record { func (t *tapTweaks) Record() tlv.Record {
return tlv.MakePrimitiveRecord(taprootTapTweakType, t) recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := tapTweaksEncoder(&b, t, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, t, recordSize, tapTweaksEncoder, tapTweaksDecoder,
)
} }
// Encode encodes the set of tap tweaks. // Encode encodes the set of tap tweaks.
@ -659,3 +695,110 @@ func (t *tapTweaks) Decode(r io.Reader) error {
return stream.Decode(r) return stream.Decode(r)
} }
// htlcAuxBlobs is a map of resolver IDs to their corresponding HTLC blobs.
// This is used to store the resolution blobs for HTLCs that are not yet
// resolved.
type htlcAuxBlobs map[resolverID]tlv.Blob
// newAuxHtlcBlobs returns a new instance of the htlcAuxBlobs struct.
func newAuxHtlcBlobs() htlcAuxBlobs {
return make(htlcAuxBlobs)
}
// Encode encodes the set of HTLC blobs into the target writer.
func (h *htlcAuxBlobs) Encode(w io.Writer) error {
var buf [8]byte
numBlobs := uint64(len(*h))
if err := tlv.WriteVarInt(w, numBlobs, &buf); err != nil {
return err
}
for id, blob := range *h {
if _, err := w.Write(id[:]); err != nil {
return err
}
if err := varBytesEncoder(w, &blob, &buf); err != nil {
return err
}
}
return nil
}
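// Informally, the serialization produced above is:
//
//	varint(num_blobs) || (resolver_id || varint(blob_len) || blob)*
//
// where resolver_id is the fixed-size resolverID array defined earlier in
// this file, and each per-blob length prefix and payload is written by
// varBytesEncoder.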
// Decode decodes the set of HTLC blobs from the target reader.
func (h *htlcAuxBlobs) Decode(r io.Reader) error {
var buf [8]byte
numBlobs, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
for i := uint64(0); i < numBlobs; i++ {
var id resolverID
if _, err := io.ReadFull(r, id[:]); err != nil {
return err
}
var blob tlv.Blob
if err := varBytesDecoder(r, &blob, &buf, 0); err != nil {
return err
}
(*h)[id] = blob
}
return nil
}
// htlcAuxBlobsEncoder is a custom TLV encoder for the htlcAuxBlobs struct.
func htlcAuxBlobsEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(*htlcAuxBlobs); ok {
return (*t).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "htlcAuxBlobs")
}
// htlcAuxBlobsDecoder is a custom TLV decoder for the htlcAuxBlobs struct.
func htlcAuxBlobsDecoder(r io.Reader, val any, _ *[8]byte,
l uint64) error {
if typ, ok := val.(*htlcAuxBlobs); ok {
blobReader := io.LimitReader(r, int64(l))
htlcBlobs := newAuxHtlcBlobs()
err := htlcBlobs.Decode(blobReader)
if err != nil {
return err
}
*typ = htlcBlobs
return nil
}
return tlv.NewTypeForDecodingErr(val, "htlcAuxBlobs", l, l)
}
// Record returns a tlv.Record for the htlcAuxBlobs struct.
func (h *htlcAuxBlobs) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := htlcAuxBlobsEncoder(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, htlcAuxBlobsEncoder, htlcAuxBlobsDecoder,
)
}
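// exampleOptionalCommitBlob is an illustrative sketch (not used by the
// package itself) showing how a caller can set and read back one of the
// optional records above, using the same tlv helpers the briefcase uses.
func exampleOptionalCommitBlob() {
	b := newTaprootBriefcase()

	// Populate the optional settled commitment blob.
	blob := []byte{0x01, 0x02, 0x03}
	b.SettledCommitBlob = tlv.SomeRecordT(
		tlv.NewPrimitiveRecord[tlv.TlvType2](blob),
	)

	// Read it back only if it was actually set.
	b.SettledCommitBlob.WhenSome(
		func(r tlv.RecordT[tlv.TlvType2, tlv.Blob]) {
			// r.Val holds the raw blob bytes.
			_ = r.Val
		},
	)
}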

View file

@ -5,7 +5,9 @@ import (
"math/rand" "math/rand"
"testing" "testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"pgregory.net/rapid"
) )
func randResolverCtrlBlocks(t *testing.T) resolverCtrlBlocks { func randResolverCtrlBlocks(t *testing.T) resolverCtrlBlocks {
@ -52,6 +54,25 @@ func randHtlcTweaks(t *testing.T) htlcTapTweaks {
return tweaks return tweaks
} }
func randHtlcAuxBlobs(t *testing.T) htlcAuxBlobs {
numBlobs := rand.Int() % 256
blobs := make(htlcAuxBlobs, numBlobs)
for i := 0; i < numBlobs; i++ {
var id resolverID
_, err := rand.Read(id[:])
require.NoError(t, err)
var blob [100]byte
_, err = rand.Read(blob[:])
require.NoError(t, err)
blobs[id] = blob[:]
}
return blobs
}
// TestTaprootBriefcase tests the encode/decode methods of the taproot // TestTaprootBriefcase tests the encode/decode methods of the taproot
// briefcase extension. // briefcase extension.
func TestTaprootBriefcase(t *testing.T) { func TestTaprootBriefcase(t *testing.T) {
@ -69,19 +90,32 @@ func TestTaprootBriefcase(t *testing.T) {
_, err = rand.Read(anchorTweak[:]) _, err = rand.Read(anchorTweak[:])
require.NoError(t, err) require.NoError(t, err)
var commitBlob [100]byte
_, err = rand.Read(commitBlob[:])
require.NoError(t, err)
testCase := &taprootBriefcase{ testCase := &taprootBriefcase{
CtrlBlocks: &ctrlBlocks{ CtrlBlocks: tlv.NewRecordT[tlv.TlvType0](ctrlBlocks{
CommitSweepCtrlBlock: sweepCtrlBlock[:], CommitSweepCtrlBlock: sweepCtrlBlock[:],
RevokeSweepCtrlBlock: revokeCtrlBlock[:], RevokeSweepCtrlBlock: revokeCtrlBlock[:],
OutgoingHtlcCtrlBlocks: randResolverCtrlBlocks(t), OutgoingHtlcCtrlBlocks: randResolverCtrlBlocks(t),
IncomingHtlcCtrlBlocks: randResolverCtrlBlocks(t), IncomingHtlcCtrlBlocks: randResolverCtrlBlocks(t),
SecondLevelCtrlBlocks: randResolverCtrlBlocks(t), SecondLevelCtrlBlocks: randResolverCtrlBlocks(t),
}, }),
TapTweaks: &tapTweaks{ TapTweaks: tlv.NewRecordT[tlv.TlvType1](tapTweaks{
AnchorTweak: anchorTweak[:], AnchorTweak: anchorTweak[:],
BreachedHtlcTweaks: randHtlcTweaks(t), BreachedHtlcTweaks: randHtlcTweaks(t),
BreachedSecondLevelHltcTweaks: randHtlcTweaks(t), BreachedSecondLevelHltcTweaks: randHtlcTweaks(t),
}, }),
SettledCommitBlob: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType2](commitBlob[:]),
),
BreachedCommitBlob: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType3](commitBlob[:]),
),
HtlcBlobs: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](randHtlcAuxBlobs(t)),
),
} }
var b bytes.Buffer var b bytes.Buffer
@ -92,3 +126,21 @@ func TestTaprootBriefcase(t *testing.T) {
require.Equal(t, testCase, &decodedCase) require.Equal(t, testCase, &decodedCase)
} }
// TestHtlcAuxBlobEncodeDecode tests the encode/decode methods of the HTLC aux
// blobs.
func TestHtlcAuxBlobEncodeDecode(t *testing.T) {
t.Parallel()
rapid.Check(t, func(t *rapid.T) {
htlcBlobs := rapid.Make[htlcAuxBlobs]().Draw(t, "htlcAuxBlobs")
var b bytes.Buffer
require.NoError(t, htlcBlobs.Encode(&b))
decodedBlobs := newAuxHtlcBlobs()
require.NoError(t, decodedBlobs.Decode(&b))
require.Equal(t, htlcBlobs, decodedBlobs)
})
}

View file

@ -0,0 +1,78 @@
# 2024/09/02 14:02:53.354676 [TestHtlcAuxBlobEncodeDecode] [rapid] draw htlcAuxBlobs: contractcourt.htlcAuxBlobs{contractcourt.resolverID{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}:[]uint8{}}
#
v0.4.8#15807814492030881602
0x5555555555555
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0
0x0

View file

@ -21,6 +21,7 @@ import (
"github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/tlv"
) )
// SUMMARY OF OUTPUT STATES // SUMMARY OF OUTPUT STATES
@ -1423,6 +1424,7 @@ func makeKidOutput(outpoint, originChanPoint *wire.OutPoint,
return kidOutput{ return kidOutput{
breachedOutput: makeBreachedOutput( breachedOutput: makeBreachedOutput(
outpoint, witnessType, nil, signDescriptor, heightHint, outpoint, witnessType, nil, signDescriptor, heightHint,
fn.None[tlv.Blob](),
), ),
isHtlc: isHtlc, isHtlc: isHtlc,
originChanPoint: *originChanPoint, originChanPoint: *originChanPoint,

View file

@ -3,7 +3,7 @@
# /make/builder.Dockerfile # /make/builder.Dockerfile
# /.github/workflows/main.yml # /.github/workflows/main.yml
# /.github/workflows/release.yml # /.github/workflows/release.yml
FROM golang:1.22.5-alpine as builder FROM golang:1.22.6-alpine as builder
LABEL maintainer="Olaoluwa Osuntokun <laolu@lightning.engineering>" LABEL maintainer="Olaoluwa Osuntokun <laolu@lightning.engineering>"

View file

@ -20,6 +20,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
@ -82,9 +83,10 @@ var (
// can provide that serve useful when processing a specific network // can provide that serve useful when processing a specific network
// announcement. // announcement.
type optionalMsgFields struct { type optionalMsgFields struct {
capacity *btcutil.Amount capacity *btcutil.Amount
channelPoint *wire.OutPoint channelPoint *wire.OutPoint
remoteAlias *lnwire.ShortChannelID remoteAlias *lnwire.ShortChannelID
tapscriptRoot fn.Option[chainhash.Hash]
} }
// apply applies the optional fields within the functional options. // apply applies the optional fields within the functional options.
@ -115,6 +117,14 @@ func ChannelPoint(op wire.OutPoint) OptionalMsgField {
} }
} }
// TapscriptRoot is an optional field that lets the gossiper know of the root of
// the tapscript tree for a custom channel.
func TapscriptRoot(root fn.Option[chainhash.Hash]) OptionalMsgField {
return func(f *optionalMsgFields) {
f.tapscriptRoot = root
}
}
// RemoteAlias is an optional field that lets the gossiper know that a locally // RemoteAlias is an optional field that lets the gossiper know that a locally
// sent channel update is actually an update for the peer that should replace // sent channel update is actually an update for the peer that should replace
// the ShortChannelID field with the remote's alias. This is only used for // the ShortChannelID field with the remote's alias. This is only used for
@ -2578,6 +2588,9 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg,
cp := *nMsg.optionalMsgFields.channelPoint cp := *nMsg.optionalMsgFields.channelPoint
edge.ChannelPoint = cp edge.ChannelPoint = cp
} }
// Optional tapscript root for custom channels.
edge.TapscriptRoot = nMsg.optionalMsgFields.tapscriptRoot
} }
log.Debugf("Adding edge for short_chan_id: %v", scid.ToUint64()) log.Debugf("Adding edge for short_chan_id: %v", scid.ToUint64())

View file

@ -835,12 +835,6 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
} }
g.prevReplyChannelRange = msg g.prevReplyChannelRange = msg
if len(msg.Timestamps) != 0 &&
len(msg.Timestamps) != len(msg.ShortChanIDs) {
return fmt.Errorf("number of timestamps not equal to " +
"number of SCIDs")
}
for i, scid := range msg.ShortChanIDs { for i, scid := range msg.ShortChanIDs {
info := channeldb.NewChannelUpdateInfo( info := channeldb.NewChannelUpdateInfo(

View file

@ -1,4 +1,4 @@
FROM golang:1.22.5-alpine as builder FROM golang:1.22.6-alpine as builder
LABEL maintainer="Olaoluwa Osuntokun <laolu@lightning.engineering>" LABEL maintainer="Olaoluwa Osuntokun <laolu@lightning.engineering>"

View file

@ -100,12 +100,12 @@ the following commands for your OS:
<summary>Linux (x86-64)</summary> <summary>Linux (x86-64)</summary>
``` ```
wget https://dl.google.com/go/go1.22.5.linux-amd64.tar.gz wget https://dl.google.com/go/go1.22.6.linux-amd64.tar.gz
sha256sum go1.22.5.linux-amd64.tar.gz | awk -F " " '{ print $1 }' sha256sum go1.22.6.linux-amd64.tar.gz | awk -F " " '{ print $1 }'
``` ```
The final output of the command above should be The final output of the command above should be
`904b924d435eaea086515bc63235b192ea441bd8c9b198c507e85009e6e4c7f0`. If it `999805bed7d9039ec3da1a53bfbcafc13e367da52aa823cb60b68ba22d44c616`. If it
isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install
this version of Go. If it matches, then proceed to install Go: this version of Go. If it matches, then proceed to install Go:
``` ```
@ -123,7 +123,7 @@ the following commands for your OS:
``` ```
The final output of the command above should be The final output of the command above should be
`8c4587cf3e63c9aefbcafa92818c4d9d51683af93ea687bf6c7508d6fa36f85e`. If it `b566484fe89a54c525dd1a4cbfec903c1f6e8f0b7b3dbaf94c79bc9145391083`. If it
isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install
this version of Go. If it matches, then proceed to install Go: this version of Go. If it matches, then proceed to install Go:
``` ```

View file

@ -0,0 +1,124 @@
# Release Notes
- [Bug Fixes](#bug-fixes)
- [New Features](#new-features)
- [Functional Enhancements](#functional-enhancements)
- [RPC Additions](#rpc-additions)
- [lncli Additions](#lncli-additions)
- [Improvements](#improvements)
- [Functional Updates](#functional-updates)
- [RPC Updates](#rpc-updates)
- [lncli Updates](#lncli-updates)
- [Breaking Changes](#breaking-changes)
- [Performance Improvements](#performance-improvements)
- [Technical and Architectural Updates](#technical-and-architectural-updates)
- [BOLT Spec Updates](#bolt-spec-updates)
- [Testing](#testing)
- [Database](#database)
- [Code Health](#code-health)
- [Tooling and Documentation](#tooling-and-documentation)
# Bug Fixes
* [Fix a bug](https://github.com/lightningnetwork/lnd/pull/9134) that would
cause a nil pointer dereference during the probing of a payment request that
does not contain a payment address.
* [Make the contract resolutions for the channel arbitrator optional](
https://github.com/lightningnetwork/lnd/pull/9253).
# New Features
The main channel state machine and database now allow for processing and storing
custom Taproot script leaves, allowing the implementation of custom channel
types in a series of changes:
* https://github.com/lightningnetwork/lnd/pull/9025
* https://github.com/lightningnetwork/lnd/pull/9030
* https://github.com/lightningnetwork/lnd/pull/9049
* https://github.com/lightningnetwork/lnd/pull/9072
* https://github.com/lightningnetwork/lnd/pull/9095
* https://github.com/lightningnetwork/lnd/pull/8960
* https://github.com/lightningnetwork/lnd/pull/9194
* https://github.com/lightningnetwork/lnd/pull/9288
## Functional Enhancements
* A new `protocol.simple-taproot-overlay-chans` configuration item/CLI flag was
added [to turn on custom channel
functionality](https://github.com/lightningnetwork/lnd/pull/8960).
* Compatibility with [`bitcoind
v28.0`](https://github.com/lightningnetwork/lnd/pull/9059) was ensured by
updating the version the CI pipeline is running against.
## RPC Additions
* Some new experimental [RPCs for managing SCID
aliases](https://github.com/lightningnetwork/lnd/pull/8960) were added under
the `routerrpc` package. These methods allow manually adding and deleting local
SCID aliases on your node.
> NOTE: these new RPC methods are marked as experimental
(`XAddLocalChanAliases` & `XDeleteLocalChanAliases`) and upon calling
them the aliases will not be communicated to the channel peer.
* The responses for the `ListChannels`, `PendingChannels` and `ChannelBalance`
RPCs now include [a new `custom_channel_data` field that is only set for
custom channels](https://github.com/lightningnetwork/lnd/pull/8960).
* The `routerrpc.SendPaymentV2` RPC has a new field [`first_hop_custom_records`
that allows the user to send custom p2p wire message TLV types to the first
hop of a payment](https://github.com/lightningnetwork/lnd/pull/8960); a short
usage sketch follows this list.
That new field is also exposed in the `routerrpc.HtlcInterceptor`, so it can
be read and interpreted by external software.
* The `routerrpc.HtlcInterceptor` now [allows some values of the HTLC to be
modified before they're validated by the state
machine](https://github.com/lightningnetwork/lnd/pull/8960). The fields that
can be modified are `outgoing_amount_msat` (useful if the overlaid value
transported in the HTLC doesn't match the actual BTC amount being transferred)
and `outgoing_htlc_wire_custom_records` (which allows adding custom TLV values
to the p2p wire message of the forwarded HTLC).
* A new [`invoicesrpc.HtlcModifier` RPC now allows incoming HTLCs that attempt
to satisfy an invoice to be modified before they're
validated](https://github.com/lightningnetwork/lnd/pull/8960). This allows
custom channels to determine what the actual (overlaid) value of an HTLC is,
even if that value doesn't match the actual BTC amount being transferred by
the HTLC.
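To make the new payment-level TLV plumbing concrete, here is a minimal Go sketch of calling `routerrpc.SendPaymentV2` with first-hop custom records. It assumes the generated Go field is named `FirstHopCustomRecords` (mirroring the proto field above) and uses a purely illustrative TLV type number; it is a sketch, not a drop-in client.
```go
package main

import (
	"context"
	"errors"
	"io"

	"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
	"google.golang.org/grpc"
)

// sendWithFirstHopRecords pays an invoice while attaching custom p2p wire
// TLV records for the first hop. The TLV type number below is illustrative
// only; real deployments use types agreed upon with the peer.
func sendWithFirstHopRecords(conn *grpc.ClientConn, invoice string) error {
	client := routerrpc.NewRouterClient(conn)

	stream, err := client.SendPaymentV2(context.Background(),
		&routerrpc.SendPaymentRequest{
			PaymentRequest: invoice,
			TimeoutSeconds: 60,
			// Assumed generated name for first_hop_custom_records.
			FirstHopCustomRecords: map[uint64][]byte{
				65543: {0x01, 0x02},
			},
		},
	)
	if err != nil {
		return err
	}

	// Drain payment status updates until the stream closes.
	for {
		if _, err := stream.Recv(); err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}

			return err
		}
	}
}
```
On the receiving side, the same records surface in the `routerrpc.HtlcInterceptor` stream, so an external process can inspect them before deciding how to resolve or modify the HTLC.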
## lncli Additions
# Improvements
## Functional Updates
## RPC Updates
## lncli Updates
## Code Health
## Breaking Changes
## Performance Improvements
* [A new method](https://github.com/lightningnetwork/lnd/pull/9195)
`AssertTxnsNotInMempool` has been added to the `lntest` package to allow
checking that a whole batch of transactions is absent from the mempool in
itests.
# Technical and Architectural Updates
## BOLT Spec Updates
## Testing
## Database
## Code Health
## Tooling and Documentation
# Contributors (Alphabetical Order)
* Elle Mouton
* ffranr
* George Tsagkarelis
* Olaoluwa Osuntokun
* Oliver Gugger

View file

@ -92,7 +92,8 @@ var defaultSetDesc = setDesc{
SetInit: {}, // I SetInit: {}, // I
SetNodeAnn: {}, // N SetNodeAnn: {}, // N
}, },
lnwire.Bolt11BlindedPathsOptional: { lnwire.SimpleTaprootOverlayChansOptional: {
SetInvoice: {}, // I SetInit: {}, // I
SetNodeAnn: {}, // N
}, },
} }

View file

@ -79,6 +79,11 @@ var deps = depDesc{
lnwire.AnchorsZeroFeeHtlcTxOptional: {}, lnwire.AnchorsZeroFeeHtlcTxOptional: {},
lnwire.ExplicitChannelTypeOptional: {}, lnwire.ExplicitChannelTypeOptional: {},
}, },
lnwire.SimpleTaprootOverlayChansOptional: {
lnwire.SimpleTaprootChannelsOptionalStaging: {},
lnwire.TLVOnionPayloadOptional: {},
lnwire.ScidAliasOptional: {},
},
lnwire.RouteBlindingOptional: { lnwire.RouteBlindingOptional: {
lnwire.TLVOnionPayloadOptional: {}, lnwire.TLVOnionPayloadOptional: {},
}, },

View file

@ -63,6 +63,9 @@ type Config struct {
// NoRouteBlinding unsets route blinding feature bits. // NoRouteBlinding unsets route blinding feature bits.
NoRouteBlinding bool NoRouteBlinding bool
// NoTaprootOverlay unsets the taproot overlay channel feature bits.
NoTaprootOverlay bool
// CustomFeatures is a set of custom features to advertise in each // CustomFeatures is a set of custom features to advertise in each
// set. // set.
CustomFeatures map[Set][]lnwire.FeatureBit CustomFeatures map[Set][]lnwire.FeatureBit
@ -192,6 +195,10 @@ func newManager(cfg Config, desc setDesc) (*Manager, error) {
raw.Unset(lnwire.Bolt11BlindedPathsOptional) raw.Unset(lnwire.Bolt11BlindedPathsOptional)
raw.Unset(lnwire.Bolt11BlindedPathsRequired) raw.Unset(lnwire.Bolt11BlindedPathsRequired)
} }
if cfg.NoTaprootOverlay {
raw.Unset(lnwire.SimpleTaprootOverlayChansOptional)
raw.Unset(lnwire.SimpleTaprootOverlayChansRequired)
}
for _, custom := range cfg.CustomFeatures[set] { for _, custom := range cfg.CustomFeatures[set] {
if custom > set.Maximum() { if custom > set.Maximum() {
return nil, fmt.Errorf("feature bit: %v "+ return nil, fmt.Errorf("feature bit: %v "+

funding/aux_funding.go (new file, 51 lines)
View file

@ -0,0 +1,51 @@
package funding
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/msgmux"
)
// AuxFundingDescResult is a type alias for a function that returns an optional
// aux funding desc.
type AuxFundingDescResult = fn.Result[fn.Option[lnwallet.AuxFundingDesc]]
// AuxTapscriptResult is a type alias for a function that returns an optional
// tapscript root.
type AuxTapscriptResult = fn.Result[fn.Option[chainhash.Hash]]
// AuxFundingController permits the implementation of the funding of custom
// channels types. The controller serves as a MsgEndpoint which allows it to
// intercept custom messages, or even the regular funding messages. The
// controller might also pass along an aux funding desc based on an existing
// pending channel ID.
type AuxFundingController interface {
// Endpoint is the embedded interface that signals that the funding
// controller is also a message endpoint. This'll allow it to handle
// custom messages specific to the funding type.
msgmux.Endpoint
// DescFromPendingChanID takes a pending channel ID that may already be
// known due to prior custom channel messages, and maybe returns an aux
// funding desc which can be used to modify how a channel is funded.
DescFromPendingChanID(pid PendingChanID, openChan lnwallet.AuxChanState,
keyRing lntypes.Dual[lnwallet.CommitmentKeyRing],
initiator bool) AuxFundingDescResult
// DeriveTapscriptRoot takes a pending channel ID and maybe returns a
// tapscript root that should be used when creating any MuSig2 sessions
// for a channel.
DeriveTapscriptRoot(PendingChanID) AuxTapscriptResult
// ChannelReady is called when a channel has been fully opened (multiple
// confirmations) and is ready to be used. This can be used to perform
// any final setup or cleanup.
ChannelReady(openChan lnwallet.AuxChanState) error
// ChannelFinalized is called when a channel has been fully finalized.
// In this state, we've received the commitment sig from the remote
// party, so we are safe to broadcast the funding transaction.
ChannelFinalized(PendingChanID) error
}
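For readers implementing this interface, a minimal no-op controller might look like the following sketch. It assumes the `fn` package's `Ok`/`None` constructors, and it leaves the embedded `msgmux.Endpoint` unimplemented, so it is illustrative only rather than a usable controller.
```go
package main

import (
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/lightningnetwork/lnd/fn"
	"github.com/lightningnetwork/lnd/funding"
	"github.com/lightningnetwork/lnd/lntypes"
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/msgmux"
)

// noopAuxFunder satisfies funding.AuxFundingController but never alters the
// funding flow: it reports no aux funding descriptor and no tapscript root.
type noopAuxFunder struct {
	// Embedding the interface satisfies the msgmux.Endpoint methods at
	// compile time; a real controller must supply an actual endpoint.
	msgmux.Endpoint
}

var _ funding.AuxFundingController = (*noopAuxFunder)(nil)

func (n *noopAuxFunder) DescFromPendingChanID(_ funding.PendingChanID,
	_ lnwallet.AuxChanState, _ lntypes.Dual[lnwallet.CommitmentKeyRing],
	_ bool) funding.AuxFundingDescResult {

	return fn.Ok(fn.None[lnwallet.AuxFundingDesc]())
}

func (n *noopAuxFunder) DeriveTapscriptRoot(
	_ funding.PendingChanID) funding.AuxTapscriptResult {

	return fn.Ok(fn.None[chainhash.Hash]())
}

func (n *noopAuxFunder) ChannelReady(_ lnwallet.AuxChanState) error {
	return nil
}

func (n *noopAuxFunder) ChannelFinalized(_ funding.PendingChanID) error {
	return nil
}
```
Such a controller would then be handed to the funding manager via the new optional `AuxFundingController` config field introduced further down in this diff, e.g. `fn.Some[funding.AuxFundingController](&noopAuxFunder{})`.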

View file

@ -307,6 +307,74 @@ func explicitNegotiateCommitmentType(channelType lnwire.ChannelType, local,
return lnwallet.CommitmentTypeSimpleTaproot, nil return lnwallet.CommitmentTypeSimpleTaproot, nil
// Simple taproot channels overlay only.
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
):
if !hasFeatures(
local, remote,
lnwire.SimpleTaprootOverlayChansOptional,
) {
return 0, errUnsupportedChannelType
}
return lnwallet.CommitmentTypeSimpleTaprootOverlay, nil
// Simple taproot overlay channels with scid only.
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ScidAliasRequired,
):
if !hasFeatures(
local, remote,
lnwire.SimpleTaprootOverlayChansOptional,
lnwire.ScidAliasOptional,
) {
return 0, errUnsupportedChannelType
}
return lnwallet.CommitmentTypeSimpleTaprootOverlay, nil
// Simple taproot overlay channels with zero conf only.
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ZeroConfRequired,
):
if !hasFeatures(
local, remote,
lnwire.SimpleTaprootOverlayChansOptional,
lnwire.ZeroConfOptional,
) {
return 0, errUnsupportedChannelType
}
return lnwallet.CommitmentTypeSimpleTaprootOverlay, nil
// Simple taproot overlay channels with scid and zero conf.
case channelFeatures.OnlyContains(
lnwire.SimpleTaprootOverlayChansRequired,
lnwire.ZeroConfRequired,
lnwire.ScidAliasRequired,
):
if !hasFeatures(
local, remote,
lnwire.SimpleTaprootOverlayChansOptional,
lnwire.ZeroConfOptional,
lnwire.ScidAliasOptional,
) {
return 0, errUnsupportedChannelType
}
return lnwallet.CommitmentTypeSimpleTaprootOverlay, nil
// No features, use legacy commitment type. // No features, use legacy commitment type.
case channelFeatures.IsEmpty(): case channelFeatures.IsEmpty():
return lnwallet.CommitmentTypeLegacy, nil return lnwallet.CommitmentTypeLegacy, nil
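As a small illustration of how a funder reaches the new overlay cases, the sketch below builds the explicit `channel_type` feature vector for an overlay, zero-conf channel. It assumes `lnwire.ChannelType` wraps a `RawFeatureVector`, as it does elsewhere in the codebase.
```go
package main

import (
	"fmt"

	"github.com/lightningnetwork/lnd/lnwire"
)

// overlayZeroConfChanType returns the explicit channel_type a funder would
// set to negotiate CommitmentTypeSimpleTaprootOverlay with zero-conf, i.e.
// the "overlay channels with zero conf only" case handled above.
func overlayZeroConfChanType() *lnwire.ChannelType {
	ct := lnwire.ChannelType(*lnwire.NewRawFeatureVector(
		lnwire.SimpleTaprootOverlayChansRequired,
		lnwire.ZeroConfRequired,
	))

	return &ct
}

func main() {
	fmt.Printf("%v\n", overlayZeroConfChanType())
}
```
Both peers must additionally advertise the corresponding optional overlay and zero-conf bits (plus scid-alias for the combined case), otherwise negotiation falls through to `errUnsupportedChannelType`.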

View file

@ -36,7 +36,8 @@ type aliasHandler interface {
GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID, error) GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID, error)
// AddLocalAlias persists an alias to an underlying alias store. // AddLocalAlias persists an alias to an underlying alias store.
AddLocalAlias(lnwire.ShortChannelID, lnwire.ShortChannelID, bool) error AddLocalAlias(lnwire.ShortChannelID, lnwire.ShortChannelID, bool,
bool) error
// GetAliases returns the set of aliases given the main SCID of a // GetAliases returns the set of aliases given the main SCID of a
// channel. This SCID will be an alias for zero-conf channels and will // channel. This SCID will be an alias for zero-conf channels and will

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/blockchain"
@ -23,6 +24,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
@ -97,7 +99,6 @@ const (
// you and limitless channel size (apart from 21 million cap). // you and limitless channel size (apart from 21 million cap).
MaxBtcFundingAmountWumbo = btcutil.Amount(1000000000) MaxBtcFundingAmountWumbo = btcutil.Amount(1000000000)
// TODO(roasbeef): tune.
msgBufferSize = 50 msgBufferSize = 50
// MaxWaitNumBlocksFundingConf is the maximum number of blocks to wait // MaxWaitNumBlocksFundingConf is the maximum number of blocks to wait
@ -287,7 +288,7 @@ type InitFundingMsg struct {
// PendingChanID is not all zeroes (the default value), then this will // PendingChanID is not all zeroes (the default value), then this will
// be the pending channel ID used for the funding flow within the wire // be the pending channel ID used for the funding flow within the wire
// protocol. // protocol.
PendingChanID [32]byte PendingChanID PendingChanID
// ChannelType allows the caller to use an explicit channel type for the // ChannelType allows the caller to use an explicit channel type for the
// funding negotiation. This type will only be observed if BOTH sides // funding negotiation. This type will only be observed if BOTH sides
@ -317,7 +318,7 @@ type fundingMsg struct {
// pendingChannels is a map instantiated per-peer which tracks all active // pendingChannels is a map instantiated per-peer which tracks all active
// pending single funded channels indexed by their pending channel identifier, // pending single funded channels indexed by their pending channel identifier,
// which is a set of 32-bytes generated via a CSPRNG. // which is a set of 32-bytes generated via a CSPRNG.
type pendingChannels map[[32]byte]*reservationWithCtx type pendingChannels map[PendingChanID]*reservationWithCtx
// serializedPubKey is used within the FundingManager's activeReservations list // serializedPubKey is used within the FundingManager's activeReservations list
// to identify the nodes with which the FundingManager is actively working to // to identify the nodes with which the FundingManager is actively working to
@ -543,6 +544,24 @@ type Config struct {
// backed funding flow to not use utxos still being swept by the sweeper // backed funding flow to not use utxos still being swept by the sweeper
// subsystem. // subsystem.
IsSweeperOutpoint func(wire.OutPoint) bool IsSweeperOutpoint func(wire.OutPoint) bool
// AuxLeafStore is an optional store that can be used to store auxiliary
// leaves for certain custom channel types.
AuxLeafStore fn.Option[lnwallet.AuxLeafStore]
// AuxFundingController is an optional controller that can be used to
// modify the way we handle certain custom channel types. It's also
// able to automatically handle new custom protocol messages related to
// the funding process.
AuxFundingController fn.Option[AuxFundingController]
// AuxSigner is an optional signer that can be used to sign auxiliary
// leaves for certain custom channel types.
AuxSigner fn.Option[lnwallet.AuxSigner]
// AuxResolver is an optional interface that can be used to modify the
// way contracts are resolved.
AuxResolver fn.Option[lnwallet.AuxContractResolver]
} }
// Manager acts as an orchestrator/bridge between the wallet's // Manager acts as an orchestrator/bridge between the wallet's
@ -568,8 +587,10 @@ type Manager struct {
// chanIDNonce is a nonce that's incremented for each new funding // chanIDNonce is a nonce that's incremented for each new funding
// reservation created. // reservation created.
nonceMtx sync.RWMutex chanIDNonce atomic.Uint64
chanIDNonce uint64
// nonceMtx is a mutex that guards the pendingMusigNonces.
nonceMtx sync.RWMutex
// pendingMusigNonces is used to store the musig2 nonce we generate to // pendingMusigNonces is used to store the musig2 nonce we generate to
// send funding locked until we receive a funding locked message from // send funding locked until we receive a funding locked message from
@ -591,7 +612,7 @@ type Manager struct {
// required as mid funding flow, we switch to referencing the channel // required as mid funding flow, we switch to referencing the channel
// by its full channel ID once the commitment transactions have been // by its full channel ID once the commitment transactions have been
// signed by both parties. // signed by both parties.
signedReservations map[lnwire.ChannelID][32]byte signedReservations map[lnwire.ChannelID]PendingChanID
// resMtx guards both of the maps above to ensure that all access is // resMtx guards both of the maps above to ensure that all access is
// goroutine safe. // goroutine safe.
@ -798,24 +819,28 @@ func (f *Manager) rebroadcastFundingTx(c *channeldb.OpenChannel) {
} }
} }
// PendingChanID is a type that represents a pending channel ID. This might be
// selected by the caller, but if not, one will be generated automatically.
type PendingChanID = [32]byte
// nextPendingChanID returns the next free pending channel ID to be used to // nextPendingChanID returns the next free pending channel ID to be used to
// identify a particular future channel funding workflow. // identify a particular future channel funding workflow.
func (f *Manager) nextPendingChanID() [32]byte { func (f *Manager) nextPendingChanID() PendingChanID {
// Obtain a fresh nonce. We do this by encoding the current nonce // Obtain a fresh nonce. We do this by encoding the incremented nonce.
// counter, then incrementing it by one. nextNonce := f.chanIDNonce.Add(1)
f.nonceMtx.Lock()
var nonce [8]byte var nonceBytes [8]byte
binary.LittleEndian.PutUint64(nonce[:], f.chanIDNonce) binary.LittleEndian.PutUint64(nonceBytes[:], nextNonce)
f.chanIDNonce++
f.nonceMtx.Unlock()
// We'll generate the next pending channelID by "encrypting" 32-bytes // We'll generate the next pending channelID by "encrypting" 32-bytes
// of zeroes which'll extract 32 random bytes from our stream cipher. // of zeroes which'll extract 32 random bytes from our stream cipher.
var ( var (
nextChanID [32]byte nextChanID PendingChanID
zeroes [32]byte zeroes [32]byte
) )
salsa20.XORKeyStream(nextChanID[:], zeroes[:], nonce[:], &f.chanIDKey) salsa20.XORKeyStream(
nextChanID[:], zeroes[:], nonceBytes[:], &f.chanIDKey,
)
return nextChanID return nextChanID
} }
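The refactored `nextPendingChanID` keeps the same derivation and only swaps the mutex-guarded counter for an atomic one: the counter acts as a Salsa20 nonce, so every call extracts a fresh 32-byte keystream block as the ID. Below is a standalone sketch of the same pattern, with hypothetical names and a zero key in place of lnd's CSPRNG-derived `chanIDKey`.
```go
package main

import (
	"encoding/binary"
	"fmt"
	"sync/atomic"

	"golang.org/x/crypto/salsa20"
)

// idGen mirrors the pattern above: an atomic counter serves as the stream
// cipher nonce, so concurrent callers get unique IDs without a mutex.
type idGen struct {
	counter atomic.Uint64
	key     [32]byte // lnd derives this from a CSPRNG at startup.
}

func (g *idGen) next() [32]byte {
	var nonce [8]byte
	binary.LittleEndian.PutUint64(nonce[:], g.counter.Add(1))

	// "Encrypt" 32 zero bytes to pull 32 pseudo-random bytes out of the
	// keystream for this (key, nonce) pair.
	var id, zeroes [32]byte
	salsa20.XORKeyStream(id[:], zeroes[:], nonce[:], &g.key)

	return id
}

func main() {
	g := &idGen{}
	fmt.Printf("%x\n%x\n", g.next(), g.next())
}
```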
@ -1045,7 +1070,8 @@ func (f *Manager) reservationCoordinator() {
// //
// NOTE: This MUST be run as a goroutine. // NOTE: This MUST be run as a goroutine.
func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel,
pendingChanID [32]byte, updateChan chan<- *lnrpc.OpenStatusUpdate) { pendingChanID PendingChanID,
updateChan chan<- *lnrpc.OpenStatusUpdate) {
defer f.wg.Done() defer f.wg.Done()
@ -1061,9 +1087,20 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel,
} }
} }
var chanOpts []lnwallet.ChannelOpt
f.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) {
chanOpts = append(chanOpts, lnwallet.WithLeafStore(s))
})
f.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) {
chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s))
})
f.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})
// We create the state-machine object which wraps the database state. // We create the state-machine object which wraps the database state.
lnChannel, err := lnwallet.NewLightningChannel( lnChannel, err := lnwallet.NewLightningChannel(
nil, channel, nil, nil, channel, nil, chanOpts...,
) )
if err != nil { if err != nil {
log.Errorf("Unable to create LightningChannel(%v): %v", log.Errorf("Unable to create LightningChannel(%v): %v",
@ -1115,7 +1152,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel,
// updateChan can be set non-nil to get OpenStatusUpdates. // updateChan can be set non-nil to get OpenStatusUpdates.
func (f *Manager) stateStep(channel *channeldb.OpenChannel, func (f *Manager) stateStep(channel *channeldb.OpenChannel,
lnChannel *lnwallet.LightningChannel, lnChannel *lnwallet.LightningChannel,
shortChanID *lnwire.ShortChannelID, pendingChanID [32]byte, shortChanID *lnwire.ShortChannelID, pendingChanID PendingChanID,
channelState channelOpeningState, channelState channelOpeningState,
updateChan chan<- *lnrpc.OpenStatusUpdate) error { updateChan chan<- *lnrpc.OpenStatusUpdate) error {
@ -1238,14 +1275,14 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel,
// advancePendingChannelState waits for a pending channel's funding tx to // advancePendingChannelState waits for a pending channel's funding tx to
// confirm, and marks it open in the database when that happens. // confirm, and marks it open in the database when that happens.
func (f *Manager) advancePendingChannelState( func (f *Manager) advancePendingChannelState(channel *channeldb.OpenChannel,
channel *channeldb.OpenChannel, pendingChanID [32]byte) error { pendingChanID PendingChanID) error {
if channel.IsZeroConf() { if channel.IsZeroConf() {
// Persist the alias to the alias database. // Persist the alias to the alias database.
baseScid := channel.ShortChannelID baseScid := channel.ShortChannelID
err := f.cfg.AliasManager.AddLocalAlias( err := f.cfg.AliasManager.AddLocalAlias(
baseScid, baseScid, true, baseScid, baseScid, true, false,
) )
if err != nil { if err != nil {
return fmt.Errorf("error adding local alias to "+ return fmt.Errorf("error adding local alias to "+
@ -1608,6 +1645,23 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
return return
} }
// At this point, if we have an AuxFundingController active, we'll
// check to see if we have a special tapscript root to use in our
// MuSig funding output.
tapscriptRoot, err := fn.MapOptionZ(
f.cfg.AuxFundingController,
func(c AuxFundingController) AuxTapscriptResult {
return c.DeriveTapscriptRoot(msg.PendingChannelID)
},
).Unpack()
if err != nil {
err = fmt.Errorf("error deriving tapscript root: %w", err)
log.Error(err)
f.failFundingFlow(peer, cid, err)
return
}
req := &lnwallet.InitFundingReserveMsg{ req := &lnwallet.InitFundingReserveMsg{
ChainHash: &msg.ChainHash, ChainHash: &msg.ChainHash,
PendingChanID: msg.PendingChannelID, PendingChanID: msg.PendingChannelID,
@ -1624,6 +1678,7 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
ZeroConf: zeroConf, ZeroConf: zeroConf,
OptionScidAlias: scid, OptionScidAlias: scid,
ScidAliasFeature: scidFeatureVal, ScidAliasFeature: scidFeatureVal,
TapscriptRoot: tapscriptRoot,
} }
reservation, err := f.cfg.Wallet.InitChannelReservation(req) reservation, err := f.cfg.Wallet.InitChannelReservation(req)
@ -1880,6 +1935,8 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
log.Debugf("Remote party accepted commitment rendering params: %v", log.Debugf("Remote party accepted commitment rendering params: %v",
lnutils.SpewLogClosure(params)) lnutils.SpewLogClosure(params))
reservation.SetState(lnwallet.SentAcceptChannel)
// With the initiator's contribution recorded, respond with our // With the initiator's contribution recorded, respond with our
// contribution in the next message of the workflow. // contribution in the next message of the workflow.
fundingAccept := lnwire.AcceptChannel{ fundingAccept := lnwire.AcceptChannel{
@ -1940,6 +1997,10 @@ func (f *Manager) funderProcessAcceptChannel(peer lnpeer.Peer,
// Update the timestamp once the fundingAcceptMsg has been handled. // Update the timestamp once the fundingAcceptMsg has been handled.
defer resCtx.updateTimestamp() defer resCtx.updateTimestamp()
if resCtx.reservation.State() != lnwallet.SentOpenChannel {
return
}
log.Infof("Recv'd fundingResponse for pending_id(%x)", log.Infof("Recv'd fundingResponse for pending_id(%x)",
pendingChanID[:]) pendingChanID[:])
@ -2243,10 +2304,34 @@ func (f *Manager) waitForPsbt(intent *chanfunding.PsbtIntent,
return return
} }
// At this point, we'll see if there's an AuxFundingDesc we
// need to deliver so the funding process can continue
// properly.
auxFundingDesc, err := fn.MapOptionZ(
f.cfg.AuxFundingController,
func(c AuxFundingController) AuxFundingDescResult {
return c.DescFromPendingChanID(
cid.tempChanID,
lnwallet.NewAuxChanState(
resCtx.reservation.ChanState(),
),
resCtx.reservation.CommitmentKeyRings(),
true,
)
},
).Unpack()
if err != nil {
failFlow("error continuing PSBT flow", err)
return
}
// A non-nil error means we can continue the funding flow. // A non-nil error means we can continue the funding flow.
// Notify the wallet so it can prepare everything we need to // Notify the wallet so it can prepare everything we need to
// continue. // continue.
err = resCtx.reservation.ProcessPsbt() //
// We'll also pass along the aux funding controller as well,
// which may be used to help process the finalized PSBT.
err = resCtx.reservation.ProcessPsbt(auxFundingDesc)
if err != nil { if err != nil {
failFlow("error continuing PSBT flow", err) failFlow("error continuing PSBT flow", err)
return return
@ -2341,6 +2426,8 @@ func (f *Manager) continueFundingAccept(resCtx *reservationWithCtx,
} }
} }
resCtx.reservation.SetState(lnwallet.SentFundingCreated)
if err := resCtx.peer.SendMessage(true, fundingCreated); err != nil { if err := resCtx.peer.SendMessage(true, fundingCreated); err != nil {
log.Errorf("Unable to send funding complete message: %v", err) log.Errorf("Unable to send funding complete message: %v", err)
f.failFundingFlow(resCtx.peer, cid, err) f.failFundingFlow(resCtx.peer, cid, err)
@ -2372,11 +2459,14 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer,
// final funding transaction, as well as a signature for our version of // final funding transaction, as well as a signature for our version of
// the commitment transaction. So at this point, we can validate the // the commitment transaction. So at this point, we can validate the
// initiator's commitment transaction, then send our own if it's valid. // initiator's commitment transaction, then send our own if it's valid.
// TODO(roasbeef): make case (p vs P) consistent throughout
fundingOut := msg.FundingPoint fundingOut := msg.FundingPoint
log.Infof("completing pending_id(%x) with ChannelPoint(%v)", log.Infof("completing pending_id(%x) with ChannelPoint(%v)",
pendingChanID[:], fundingOut) pendingChanID[:], fundingOut)
if resCtx.reservation.State() != lnwallet.SentAcceptChannel {
return
}
// Create the channel identifier without setting the active channel ID. // Create the channel identifier without setting the active channel ID.
cid := newChanIdentifier(pendingChanID) cid := newChanIdentifier(pendingChanID)
@ -2404,16 +2494,38 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer,
} }
} }
// At this point, we'll see if there's an AuxFundingDesc we need to
// deliver so the funding process can continue properly.
auxFundingDesc, err := fn.MapOptionZ(
f.cfg.AuxFundingController,
func(c AuxFundingController) AuxFundingDescResult {
return c.DescFromPendingChanID(
cid.tempChanID, lnwallet.NewAuxChanState(
resCtx.reservation.ChanState(),
), resCtx.reservation.CommitmentKeyRings(),
true,
)
},
).Unpack()
if err != nil {
log.Errorf("error continuing PSBT flow: %v", err)
f.failFundingFlow(peer, cid, err)
return
}
// With all the necessary data available, attempt to advance the // With all the necessary data available, attempt to advance the
// funding workflow to the next stage. If this succeeds then the // funding workflow to the next stage. If this succeeds then the
// funding transaction will broadcast after our next message. // funding transaction will broadcast after our next message.
// CompleteReservationSingle will also mark the channel as 'IsPending' // CompleteReservationSingle will also mark the channel as 'IsPending'
// in the database. // in the database.
//
// We'll also directly pass in the AuxFunding controller as well,
// which may be used by the reservation system to finalize funding our
// side.
completeChan, err := resCtx.reservation.CompleteReservationSingle( completeChan, err := resCtx.reservation.CompleteReservationSingle(
&fundingOut, commitSig, &fundingOut, commitSig, auxFundingDesc,
) )
if err != nil { if err != nil {
// TODO(roasbeef): better error logging: peerID, channelID, etc.
log.Errorf("unable to complete single reservation: %v", err) log.Errorf("unable to complete single reservation: %v", err)
f.failFundingFlow(peer, cid, err) f.failFundingFlow(peer, cid, err)
return return
@ -2614,6 +2726,14 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer,
return return
} }
if resCtx.reservation.State() != lnwallet.SentFundingCreated {
err := fmt.Errorf("unable to find reservation for chan_id=%x",
msg.ChanID)
f.failFundingFlow(peer, cid, err)
return
}
// Create an entry in the local discovery map so we can ensure that we // Create an entry in the local discovery map so we can ensure that we
// process the channel confirmation fully before we receive a // process the channel confirmation fully before we receive a
// channel_ready message. // channel_ready message.
@ -2709,6 +2829,21 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer,
} }
} }
// Before we proceed, if we have a funding hook that wants a
// notification that it's safe to broadcast the funding transaction,
// then we'll send that now.
err = fn.MapOptionZ(
f.cfg.AuxFundingController,
func(controller AuxFundingController) error {
return controller.ChannelFinalized(cid.tempChanID)
},
)
if err != nil {
log.Errorf("Failed to inform aux funding controller about "+
"ChannelPoint(%v) being finalized: %v", fundingPoint,
err)
}
// Now that we have a finalized reservation for this funding flow, // Now that we have a finalized reservation for this funding flow,
// we'll send the to be active channel to the ChainArbitrator so it can // we'll send the to be active channel to the ChainArbitrator so it can
// watch for any on-chain actions before the channel has fully // watch for any on-chain actions before the channel has fully
@ -2724,9 +2859,6 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer,
// Send an update to the upstream client that the negotiation process // Send an update to the upstream client that the negotiation process
// is over. // is over.
//
// TODO(roasbeef): add abstraction over updates to accommodate
// long-polling, or SSE, etc.
upd := &lnrpc.OpenStatusUpdate{ upd := &lnrpc.OpenStatusUpdate{
Update: &lnrpc.OpenStatusUpdate_ChanPending{ Update: &lnrpc.OpenStatusUpdate_ChanPending{
ChanPending: &lnrpc.PendingUpdate{ ChanPending: &lnrpc.PendingUpdate{
@ -2770,7 +2902,7 @@ type confirmedChannel struct {
// channel as closed. The error is only returned for the responder of the // channel as closed. The error is only returned for the responder of the
// channel flow. // channel flow.
func (f *Manager) fundingTimeout(c *channeldb.OpenChannel, func (f *Manager) fundingTimeout(c *channeldb.OpenChannel,
pendingID [32]byte) error { pendingID PendingChanID) error {
// We'll get a timeout if the number of blocks mined since the channel // We'll get a timeout if the number of blocks mined since the channel
// was initiated reaches MaxWaitNumBlocksFundingConf and we are not the // was initiated reaches MaxWaitNumBlocksFundingConf and we are not the
@ -2891,6 +3023,7 @@ func makeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) {
if channel.ChanType.IsTaproot() { if channel.ChanType.IsTaproot() {
pkScript, _, err := input.GenTaprootFundingScript( pkScript, _, err := input.GenTaprootFundingScript(
localKey, remoteKey, int64(channel.Capacity), localKey, remoteKey, int64(channel.Capacity),
channel.TapscriptRoot,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -3130,7 +3263,7 @@ func (f *Manager) handleFundingConfirmation(
} }
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
aliasScid, confChannel.shortChanID, true, aliasScid, confChannel.shortChanID, true, false,
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to request alias: %w", err) return fmt.Errorf("unable to request alias: %w", err)
@ -3296,7 +3429,7 @@ func (f *Manager) sendChannelReady(completeChan *channeldb.OpenChannel,
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
alias, completeChan.ShortChannelID, alias, completeChan.ShortChannelID,
false, false, false,
) )
if err != nil { if err != nil {
return err return err
@ -3431,6 +3564,7 @@ func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel,
errChan := f.cfg.SendAnnouncement( errChan := f.cfg.SendAnnouncement(
ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity), ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity),
discovery.ChannelPoint(completeChan.FundingOutpoint), discovery.ChannelPoint(completeChan.FundingOutpoint),
discovery.TapscriptRoot(completeChan.TapscriptRoot),
) )
select { select {
case err := <-errChan: case err := <-errChan:
@ -3627,7 +3761,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel,
// waitForZeroConfChannel is called when the state is addedToGraph with // waitForZeroConfChannel is called when the state is addedToGraph with
// a zero-conf channel. This will wait for the real confirmation, add the // a zero-conf channel. This will wait for the real confirmation, add the
// confirmed SCID to the graph, and then announce after six confs. // confirmed SCID to the router graph, and then announce after six confs.
func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error {
// First we'll check whether the channel is confirmed on-chain. If it // First we'll check whether the channel is confirmed on-chain. If it
// is already confirmed, the chainntnfs subsystem will return with the // is already confirmed, the chainntnfs subsystem will return with the
@ -3873,7 +4007,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
} }
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
alias, channel.ShortChannelID, false, alias, channel.ShortChannelID, false, false,
) )
if err != nil { if err != nil {
log.Errorf("unable to add local alias: %v", log.Errorf("unable to add local alias: %v",
@ -3958,6 +4092,26 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
PubNonce: remoteNonce, PubNonce: remoteNonce,
}), }),
) )
// Inform the aux funding controller that the liquidity in the
// custom channel is now ready to be advertised. We potentially
// haven't sent our own channel ready message yet, but other
// than that the channel is ready to count toward available
// liquidity.
err = fn.MapOptionZ(
f.cfg.AuxFundingController,
func(controller AuxFundingController) error {
return controller.ChannelReady(
lnwallet.NewAuxChanState(channel),
)
},
)
if err != nil {
cid := newChanIdentifier(msg.ChanID)
f.sendWarning(peer, cid, err)
return
}
} }
// The channel_ready message contains the next commitment point we'll // The channel_ready message contains the next commitment point we'll
@ -3995,7 +4149,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
// channel is now active, thus we change its state to `addedToGraph` to // channel is now active, thus we change its state to `addedToGraph` to
// let the channel start handling routing. // let the channel start handling routing.
func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel,
scid *lnwire.ShortChannelID, pendingChanID [32]byte, scid *lnwire.ShortChannelID, pendingChanID PendingChanID,
updateChan chan<- *lnrpc.OpenStatusUpdate) error { updateChan chan<- *lnrpc.OpenStatusUpdate) error {
chanID := lnwire.NewChanIDFromOutPoint(channel.FundingOutpoint) chanID := lnwire.NewChanIDFromOutPoint(channel.FundingOutpoint)
@ -4044,6 +4198,19 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel,
log.Debugf("Channel(%v) with ShortChanID %v: successfully "+ log.Debugf("Channel(%v) with ShortChanID %v: successfully "+
"added to graph", chanID, scid) "added to graph", chanID, scid)
err = fn.MapOptionZ(
f.cfg.AuxFundingController,
func(controller AuxFundingController) error {
return controller.ChannelReady(
lnwallet.NewAuxChanState(channel),
)
},
)
if err != nil {
return fmt.Errorf("failed notifying aux funding controller "+
"about channel ready: %w", err)
}
// Give the caller a final update notifying them that the channel is // Give the caller a final update notifying them that the channel is
fundingPoint := channel.FundingOutpoint fundingPoint := channel.FundingOutpoint
cp := &lnrpc.ChannelPoint{ cp := &lnrpc.ChannelPoint{
@ -4357,9 +4524,9 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey,
// //
// We can pass in zeroes for the min and max htlc policy, because we // We can pass in zeroes for the min and max htlc policy, because we
// only use the channel announcement message from the returned struct. // only use the channel announcement message from the returned struct.
ann, err := f.newChanAnnouncement(localIDKey, remoteIDKey, ann, err := f.newChanAnnouncement(
localFundingKey, remoteFundingKey, shortChanID, chanID, localIDKey, remoteIDKey, localFundingKey, remoteFundingKey,
0, 0, nil, chanType, shortChanID, chanID, 0, 0, nil, chanType,
) )
if err != nil { if err != nil {
log.Errorf("can't generate channel announcement: %v", err) log.Errorf("can't generate channel announcement: %v", err)
@ -4425,7 +4592,6 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey,
// InitFundingWorkflow sends a message to the funding manager instructing it // InitFundingWorkflow sends a message to the funding manager instructing it
// to initiate a single funder workflow with the source peer. // to initiate a single funder workflow with the source peer.
// TODO(roasbeef): re-visit blocking nature..
func (f *Manager) InitFundingWorkflow(msg *InitFundingMsg) { func (f *Manager) InitFundingWorkflow(msg *InitFundingMsg) {
f.fundingRequests <- msg f.fundingRequests <- msg
} }
@ -4519,7 +4685,7 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
// If the caller specified their own channel ID, then we'll use that. // If the caller specified their own channel ID, then we'll use that.
// Otherwise we'll generate a fresh one as normal. This will be used // Otherwise we'll generate a fresh one as normal. This will be used
// to track this reservation throughout its lifetime. // to track this reservation throughout its lifetime.
var chanID [32]byte var chanID PendingChanID
if msg.PendingChanID == zeroID { if msg.PendingChanID == zeroID {
chanID = f.nextPendingChanID() chanID = f.nextPendingChanID()
} else { } else {
@ -4615,6 +4781,23 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
scidFeatureVal = true scidFeatureVal = true
} }
// At this point, if we have an AuxFundingController active, we'll check
// to see if we have a special tapscript root to use in our MuSig2
// funding output.
tapscriptRoot, err := fn.MapOptionZ(
f.cfg.AuxFundingController,
func(c AuxFundingController) AuxTapscriptResult {
return c.DeriveTapscriptRoot(chanID)
},
).Unpack()
if err != nil {
err = fmt.Errorf("error deriving tapscript root: %w", err)
log.Error(err)
msg.Err <- err
return
}
req := &lnwallet.InitFundingReserveMsg{ req := &lnwallet.InitFundingReserveMsg{
ChainHash: &msg.ChainHash, ChainHash: &msg.ChainHash,
PendingChanID: chanID, PendingChanID: chanID,
@ -4654,6 +4837,7 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
OptionScidAlias: scid, OptionScidAlias: scid,
ScidAliasFeature: scidFeatureVal, ScidAliasFeature: scidFeatureVal,
Memo: msg.Memo, Memo: msg.Memo,
TapscriptRoot: tapscriptRoot,
} }
reservation, err := f.cfg.Wallet.InitChannelReservation(req) reservation, err := f.cfg.Wallet.InitChannelReservation(req)
@ -4805,6 +4989,8 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
log.Infof("Starting funding workflow with %v for pending_id(%x), "+ log.Infof("Starting funding workflow with %v for pending_id(%x), "+
"committype=%v", msg.Peer.Address(), chanID, commitType) "committype=%v", msg.Peer.Address(), chanID, commitType)
reservation.SetState(lnwallet.SentOpenChannel)
fundingOpen := lnwire.OpenChannel{ fundingOpen := lnwire.OpenChannel{
ChainHash: *f.cfg.Wallet.Cfg.NetParams.GenesisHash, ChainHash: *f.cfg.Wallet.Cfg.NetParams.GenesisHash,
PendingChannelID: chanID, PendingChannelID: chanID,
@ -4942,7 +5128,8 @@ func (f *Manager) pruneZombieReservations() {
// cancelReservationCtx does all needed work in order to securely cancel the // cancelReservationCtx does all needed work in order to securely cancel the
// reservation. // reservation.
func (f *Manager) cancelReservationCtx(peerKey *btcec.PublicKey, func (f *Manager) cancelReservationCtx(peerKey *btcec.PublicKey,
pendingChanID [32]byte, byRemote bool) (*reservationWithCtx, error) { pendingChanID PendingChanID,
byRemote bool) (*reservationWithCtx, error) {
log.Infof("Cancelling funding reservation for node_key=%x, "+ log.Infof("Cancelling funding reservation for node_key=%x, "+
"chan_id=%x", peerKey.SerializeCompressed(), pendingChanID[:]) "chan_id=%x", peerKey.SerializeCompressed(), pendingChanID[:])
@ -4990,7 +5177,7 @@ func (f *Manager) cancelReservationCtx(peerKey *btcec.PublicKey,
// deleteReservationCtx deletes the reservation uniquely identified by the // deleteReservationCtx deletes the reservation uniquely identified by the
// target public key of the peer, and the specified pending channel ID. // target public key of the peer, and the specified pending channel ID.
func (f *Manager) deleteReservationCtx(peerKey *btcec.PublicKey, func (f *Manager) deleteReservationCtx(peerKey *btcec.PublicKey,
pendingChanID [32]byte) { pendingChanID PendingChanID) {
peerIDKey := newSerializedKey(peerKey) peerIDKey := newSerializedKey(peerKey)
f.resMtx.Lock() f.resMtx.Lock()
@ -5013,7 +5200,7 @@ func (f *Manager) deleteReservationCtx(peerKey *btcec.PublicKey,
// getReservationCtx returns the reservation context for a particular pending // getReservationCtx returns the reservation context for a particular pending
// channel ID for a target peer. // channel ID for a target peer.
func (f *Manager) getReservationCtx(peerKey *btcec.PublicKey, func (f *Manager) getReservationCtx(peerKey *btcec.PublicKey,
pendingChanID [32]byte) (*reservationWithCtx, error) { pendingChanID PendingChanID) (*reservationWithCtx, error) {
peerIDKey := newSerializedKey(peerKey) peerIDKey := newSerializedKey(peerKey)
f.resMtx.RLock() f.resMtx.RLock()
@ -5033,7 +5220,7 @@ func (f *Manager) getReservationCtx(peerKey *btcec.PublicKey,
// of being funded. After the funding transaction has been confirmed, the // of being funded. After the funding transaction has been confirmed, the
// channel will receive a new, permanent channel ID, and will no longer be // channel will receive a new, permanent channel ID, and will no longer be
// considered pending. // considered pending.
func (f *Manager) IsPendingChannel(pendingChanID [32]byte, func (f *Manager) IsPendingChannel(pendingChanID PendingChanID,
peer lnpeer.Peer) bool { peer lnpeer.Peer) bool {
peerIDKey := newSerializedKey(peer.IdentityKey()) peerIDKey := newSerializedKey(peer.IdentityKey())

View file

@ -28,6 +28,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lncfg"
@ -161,7 +162,7 @@ func (m *mockAliasMgr) GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID,
} }
func (m *mockAliasMgr) AddLocalAlias(lnwire.ShortChannelID, func (m *mockAliasMgr) AddLocalAlias(lnwire.ShortChannelID,
lnwire.ShortChannelID, bool) error { lnwire.ShortChannelID, bool, bool) error {
return nil return nil
} }
@ -563,6 +564,12 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey,
IsSweeperOutpoint: func(wire.OutPoint) bool { IsSweeperOutpoint: func(wire.OutPoint) bool {
return false return false
}, },
AuxLeafStore: fn.Some[lnwallet.AuxLeafStore](
&lnwallet.MockAuxLeafStore{},
),
AuxSigner: fn.Some[lnwallet.AuxSigner](
lnwallet.NewAuxSignerMock(lnwallet.EmptyMockJobHandler),
),
} }
for _, op := range options { for _, op := range options {
@ -672,6 +679,8 @@ func recreateAliceFundingManager(t *testing.T, alice *testNode) {
OpenChannelPredicate: chainedAcceptor, OpenChannelPredicate: chainedAcceptor,
DeleteAliasEdge: oldCfg.DeleteAliasEdge, DeleteAliasEdge: oldCfg.DeleteAliasEdge,
AliasManager: oldCfg.AliasManager, AliasManager: oldCfg.AliasManager,
AuxLeafStore: oldCfg.AuxLeafStore,
AuxSigner: oldCfg.AuxSigner,
}) })
require.NoError(t, err, "failed recreating aliceFundingManager") require.NoError(t, err, "failed recreating aliceFundingManager")
@ -4644,8 +4653,8 @@ func testZeroConf(t *testing.T, chanType *lnwire.ChannelType) {
// opening behavior with a specified fundmax flag. To give a hypothetical // opening behavior with a specified fundmax flag. To give a hypothetical
// example, if ANCHOR types had been introduced after the fundmax flag had been // example, if ANCHOR types had been introduced after the fundmax flag had been
// activated, the developer would have had to code for the anchor reserve in the // activated, the developer would have had to code for the anchor reserve in the
// funding manager in the context of public and private channels. Otherwise // funding manager in the context of public and private channels. Otherwise,
// inconsistent bahvior would have resulted when specifying fundmax for // inconsistent behavior would have resulted when specifying fundmax for
// different types of channel openings. // different types of channel openings.
// To ensure consistency this test compares a map of locally defined channel // To ensure consistency this test compares a map of locally defined channel
// commitment types to the list of channel types that are defined in the proto // commitment types to the list of channel types that are defined in the proto
@ -4661,6 +4670,7 @@ func TestCommitmentTypeFundmaxSanityCheck(t *testing.T) {
"ANCHORS": 3, "ANCHORS": 3,
"SCRIPT_ENFORCED_LEASE": 4, "SCRIPT_ENFORCED_LEASE": 4,
"SIMPLE_TAPROOT": 5, "SIMPLE_TAPROOT": 5,
"SIMPLE_TAPROOT_OVERLAY": 6,
} }
for commitmentType := range lnrpc.CommitmentType_value { for commitmentType := range lnrpc.CommitmentType_value {

Some files were not shown because too many files have changed in this diff.