diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 668214444..03ee9d5fe 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2,17 +2,15 @@ package htlcswitch import ( "bytes" + "crypto/sha256" "fmt" prand "math/rand" "sync" "sync/atomic" "time" - "crypto/sha256" - "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hodl" @@ -213,13 +211,6 @@ type ChannelLinkConfig struct { // transaction to ensure timely confirmation. FeeEstimator lnwallet.FeeEstimator - // BlockEpochs is an active block epoch event stream backed by an - // active ChainNotifier instance. The ChannelLink will use new block - // notifications sent over this channel to decide when a _new_ HTLC is - // too close to expiry, and also when any active HTLC's have expired - // (or are close to expiry). - BlockEpochs *chainntnfs.BlockEpochEvent - // DebugHTLC should be turned on if you want all HTLCs sent to a node // with the debug htlc R-Hash are immediately settled in the next // available state transition. @@ -290,10 +281,6 @@ type channelLink struct { // method in state machine. batchCounter uint32 - // bestHeight is the best known height of the main chain. The link will - // use this information to govern decisions based on HTLC timeouts. - bestHeight uint32 - // keystoneBatch represents a volatile list of keystones that must be // written before attempting to sign the next commitment txn. These // represent all the HTLC's forwarded to the link from the switch. Once @@ -371,8 +358,8 @@ type channelLink struct { // NewChannelLink creates a new instance of a ChannelLink given a configuration // and active channel that will be used to verify/apply updates to. -func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel, - currentHeight uint32) ChannelLink { +func NewChannelLink(cfg ChannelLinkConfig, + channel *lnwallet.LightningChannel) ChannelLink { return &channelLink{ cfg: cfg, @@ -381,7 +368,6 @@ func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel, // TODO(roasbeef): just do reserve here? logCommitTimer: time.NewTimer(300 * time.Millisecond), overflowQueue: newPacketQueue(lnwallet.MaxHTLCNumber / 2), - bestHeight: currentHeight, htlcUpdates: make(chan []channeldb.HTLC), quit: make(chan struct{}), } @@ -804,7 +790,6 @@ func (l *channelLink) fwdPkgGarbager() { func (l *channelLink) htlcManager() { defer func() { l.wg.Done() - l.cfg.BlockEpochs.Cancel() log.Infof("ChannelLink(%v) has exited", l) }() @@ -2095,7 +2080,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, continue } - heightNow := l.bestHeight + heightNow := l.cfg.Switch.BestHeight() fwdInfo := chanIterator.ForwardingInstructions() switch fwdInfo.NextHop { diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index f145e33b4..d76090125 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "fmt" "io" + "math" "reflect" "runtime" "strings" @@ -13,12 +14,9 @@ import ( "testing" "time" - "math" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hodl" @@ -1057,7 +1055,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, n.firstBobChannelLink, n.carolChannelLink) - daveServer, err := newMockServer(t, "dave", nil) + daveServer, err := newMockServer(t, "dave", testStartingHeight, nil) if err != nil { t.Fatalf("unable to init dave's server: %v", err) } @@ -1443,11 +1441,6 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } var ( - globalEpoch = &chainntnfs.BlockEpochEvent{ - Epochs: make(chan *chainntnfs.BlockEpoch), - Cancel: func() { - }, - } invoiceRegistry = newMockRegistry() decoder = newMockIteratorDecoder() obfuscator = NewMockObfuscator() @@ -1468,7 +1461,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } aliceDb := aliceChannel.State().Db - aliceSwitch, err := initSwitchWithDB(aliceDb) + aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, nil, nil, nil, err } @@ -1495,7 +1488,6 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( }, Registry: invoiceRegistry, ChainEvents: &contractcourt.ChainEventSubscription{}, - BlockEpochs: globalEpoch, BatchTicker: ticker, FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)), // Make the BatchSize and Min/MaxFeeUpdateTimeout large enough @@ -1506,7 +1498,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } const startingHeight = 100 - aliceLink := NewChannelLink(aliceCfg, aliceChannel, startingHeight) + aliceLink := NewChannelLink(aliceCfg, aliceChannel) start := func() error { return aliceSwitch.AddLink(aliceLink) } @@ -3825,11 +3817,6 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) { var ( - globalEpoch = &chainntnfs.BlockEpochEvent{ - Epochs: make(chan *chainntnfs.BlockEpoch), - Cancel: func() { - }, - } invoiceRegistry = newMockRegistry() decoder = newMockIteratorDecoder() obfuscator = NewMockObfuscator() @@ -3854,7 +3841,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, if aliceSwitch == nil { var err error - aliceSwitch, err = initSwitchWithDB(aliceDb) + aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, err } @@ -3880,7 +3867,6 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, }, Registry: invoiceRegistry, ChainEvents: &contractcourt.ChainEventSubscription{}, - BlockEpochs: globalEpoch, BatchTicker: ticker, FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)), // Make the BatchSize and Min/MaxFeeUpdateTimeout large enough @@ -3894,7 +3880,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, } const startingHeight = 100 - aliceLink := NewChannelLink(aliceCfg, aliceChannel, startingHeight) + aliceLink := NewChannelLink(aliceCfg, aliceChannel) if err := aliceSwitch.AddLink(aliceLink); err != nil { return nil, nil, nil, err } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index db1dcc655..dc47730f5 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -1,19 +1,17 @@ package htlcswitch import ( + "bytes" "crypto/sha256" "encoding/binary" "fmt" + "io" "io/ioutil" "sync" + "sync/atomic" "testing" "time" - "io" - "sync/atomic" - - "bytes" - "github.com/btcsuite/fastsha256" "github.com/go-errors/errors" "github.com/lightningnetwork/lightning-onion" @@ -122,7 +120,7 @@ type mockServer struct { var _ lnpeer.Peer = (*mockServer)(nil) -func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { +func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { if db == nil { tempPath, err := ioutil.TempDir("", "switchdb") if err != nil { @@ -135,7 +133,7 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { } } - return New(Config{ + cfg := Config{ DB: db, SwitchPackager: channeldb.NewSwitchPackager(), FwdingLog: &mockForwardingLog{ @@ -144,15 +142,20 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { return nil, nil }, - }) + Notifier: &mockNotifier{}, + } + + return New(cfg, startingHeight) } -func newMockServer(t testing.TB, name string, db *channeldb.DB) (*mockServer, error) { +func newMockServer(t testing.TB, name string, startingHeight uint32, + db *channeldb.DB) (*mockServer, error) { + var id [33]byte h := sha256.Sum256([]byte(name)) copy(id[:], h[:]) - htlcSwitch, err := initSwitchWithDB(db) + htlcSwitch, err := initSwitchWithDB(startingHeight, db) if err != nil { return nil, err } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 91d6b0e11..f7d3b5c88 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,23 +2,22 @@ package htlcswitch import ( "bytes" + "crypto/sha256" "fmt" "sync" "sync/atomic" "time" - "crypto/sha256" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" - "github.com/roasbeef/btcd/btcec" - "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcd/btcec" "github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcutil" ) @@ -142,6 +141,10 @@ type Config struct { // provide payment senders our latest policy when sending encrypted // error messages. FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + + // Notifier is an instance of a chain notifier that we'll use to signal + // the switch when a new block has arrived. + Notifier chainntnfs.ChainNotifier } // Switch is the central messaging bus for all incoming/outgoing HTLCs. @@ -155,8 +158,14 @@ type Config struct { type Switch struct { started int32 // To be used atomically. shutdown int32 // To be used atomically. - wg sync.WaitGroup - quit chan struct{} + + // bestHeight is the best known height of the main chain. The links will + // be used this information to govern decisions based on HTLC timeouts. + // This will be retrieved by the registered links atomically. + bestHeight uint32 + + wg sync.WaitGroup + quit chan struct{} // cfg is a copy of the configuration struct that the htlc switch // service was initialized with. @@ -229,10 +238,15 @@ type Switch struct { // to the forwarding log. fwdEventMtx sync.Mutex pendingFwdingEvents []channeldb.ForwardingEvent + + // blockEpochStream is an active block epoch event stream backed by an + // active ChainNotifier instance. This will be used to retrieve the + // lastest height of the chain. + blockEpochStream *chainntnfs.BlockEpochEvent } // New creates the new instance of htlc switch. -func New(cfg Config) (*Switch, error) { +func New(cfg Config, currentHeight uint32) (*Switch, error) { circuitMap, err := NewCircuitMap(&CircuitMapConfig{ DB: cfg.DB, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, @@ -247,6 +261,7 @@ func New(cfg Config) (*Switch, error) { } return &Switch{ + bestHeight: currentHeight, cfg: &cfg, circuits: circuitMap, paymentSequencer: sequencer, @@ -1339,8 +1354,10 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, closeType ChannelCloseType, func (s *Switch) htlcForwarder() { defer s.wg.Done() - // Remove all links once we've been signalled for shutdown. defer func() { + s.blockEpochStream.Cancel() + + // Remove all links once we've been signalled for shutdown. s.indexMtx.Lock() for _, link := range s.linkIndex { if err := s.removeLink(link.ChanID()); err != nil { @@ -1378,8 +1395,15 @@ func (s *Switch) htlcForwarder() { fwdEventTicker := time.NewTicker(15 * time.Second) defer fwdEventTicker.Stop() +out: for { select { + case blockEpoch, ok := <-s.blockEpochStream.Epochs: + if !ok { + break out + } + + atomic.StoreUint32(&s.bestHeight, uint32(blockEpoch.Height)) // A local close request has arrived, we'll forward this to the // relevant link (if it exists) so the channel can be // cooperatively closed (if possible). @@ -1549,6 +1573,12 @@ func (s *Switch) Start() error { log.Infof("Starting HTLC Switch") + blockEpochStream, err := s.cfg.Notifier.RegisterBlockEpochNtfn() + if err != nil { + return err + } + s.blockEpochStream = blockEpochStream + s.wg.Add(1) go s.htlcForwarder() @@ -2033,3 +2063,8 @@ func (s *Switch) FlushForwardingEvents() error { // forwarding log. return s.cfg.FwdingLog.AddForwardingEvents(events) } + +// BestHeight returns the best height known to the switch. +func (s *Switch) BestHeight() uint32 { + return atomic.LoadUint32(&s.bestHeight) +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 82932d3c0..ae0a2586f 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -30,12 +30,12 @@ func genPreimage() ([32]byte, error) { func TestSwitchSendPending(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -125,16 +125,16 @@ func TestSwitchSendPending(t *testing.T) { func TestSwitchForward(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -230,11 +230,11 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -249,7 +249,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -344,7 +344,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -421,11 +421,11 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -440,7 +440,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -535,7 +535,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -615,11 +615,11 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -634,7 +634,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -721,7 +721,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -778,11 +778,11 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -797,7 +797,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -879,7 +879,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -936,11 +936,11 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -955,7 +955,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1036,7 +1036,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -1129,7 +1129,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s3, err := initSwitchWithDB(cdb3) + s3, err := initSwitchWithDB(testStartingHeight, cdb3) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -1167,16 +1167,16 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { var packet *htlcPacket - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1237,12 +1237,12 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { // We'll create a single link for this test, marking it as being unable // to forward form the get go. - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1289,16 +1289,16 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { func TestSwitchCancel(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1402,16 +1402,16 @@ func TestSwitchAddSamePayment(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1561,12 +1561,12 @@ func TestSwitchAddSamePayment(t *testing.T) { func TestSwitchSendPayment(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(nil) + s, err := initSwitchWithDB(testStartingHeight, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1805,8 +1805,6 @@ func TestMultiHopPaymentForwardingEvents(t *testing.T) { } } - time.Sleep(time.Millisecond * 200) - // With all 10 payments sent. We'll now manually stop each of the // switches so we can examine their end state. n.stop() diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 06760bd13..505b87c7e 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -17,7 +17,6 @@ import ( "github.com/btcsuite/fastsha256" "github.com/coreos/bbolt" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/keychain" @@ -569,22 +568,13 @@ func generateRoute(hops ...ForwardingInfo) ([lnwire.OnionPacketSize]byte, error) type threeHopNetwork struct { aliceServer *mockServer aliceChannelLink *channelLink - aliceBlockEpoch chan *chainntnfs.BlockEpoch - aliceTicker *time.Ticker - - firstBobChannelLink *channelLink - bobFirstBlockEpoch chan *chainntnfs.BlockEpoch - firstBobTicker *time.Ticker bobServer *mockServer + firstBobChannelLink *channelLink secondBobChannelLink *channelLink - bobSecondBlockEpoch chan *chainntnfs.BlockEpoch - secondBobTicker *time.Ticker - carolChannelLink *channelLink carolServer *mockServer - carolBlockEpoch chan *chainntnfs.BlockEpoch - carolTicker *time.Ticker + carolChannelLink *channelLink feeEstimator *mockFeeEstimator @@ -762,11 +752,6 @@ func (n *threeHopNetwork) stop() { done <- struct{}{} }() - n.aliceTicker.Stop() - n.firstBobTicker.Stop() - n.secondBobTicker.Stop() - n.carolTicker.Stop() - for i := 0; i < 3; i++ { <-done } @@ -858,15 +843,15 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, carolDb := carolChannel.State().Db // Create three peers/servers. - aliceServer, err := newMockServer(t, "alice", aliceDb) + aliceServer, err := newMockServer(t, "alice", startingHeight, aliceDb) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobServer, err := newMockServer(t, "bob", bobDb) + bobServer, err := newMockServer(t, "bob", startingHeight, bobDb) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - carolServer, err := newMockServer(t, "carol", carolDb) + carolServer, err := newMockServer(t, "carol", startingHeight, carolDb) if err != nil { t.Fatalf("unable to create carol server: %v", err) } @@ -900,13 +885,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } obfuscator := NewMockObfuscator() - aliceEpochChan := make(chan *chainntnfs.BlockEpoch) - aliceEpoch := &chainntnfs.BlockEpochEvent{ - Epochs: aliceEpochChan, - Cancel: func() { - }, - } - aliceTicker := time.NewTicker(50 * time.Millisecond) aliceChannelLink := NewChannelLink( ChannelLinkConfig{ Switch: aliceServer.htlcSwitch, @@ -921,7 +899,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: aliceServer.registry, - BlockEpochs: aliceEpoch, FeeEstimator: feeEstimator, PreimageCache: pCache, UpdateContractSignals: func(*contractcourt.ContractSignals) error { @@ -937,7 +914,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, }, aliceChannel, - startingHeight, ) if err := aliceServer.htlcSwitch.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice channel link: %v", err) @@ -952,13 +928,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } }() - bobFirstEpochChan := make(chan *chainntnfs.BlockEpoch) - bobFirstEpoch := &chainntnfs.BlockEpochEvent{ - Epochs: bobFirstEpochChan, - Cancel: func() { - }, - } - firstBobTicker := time.NewTicker(50 * time.Millisecond) firstBobChannelLink := NewChannelLink( ChannelLinkConfig{ Switch: bobServer.htlcSwitch, @@ -973,7 +942,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: bobServer.registry, - BlockEpochs: bobFirstEpoch, FeeEstimator: feeEstimator, PreimageCache: pCache, UpdateContractSignals: func(*contractcourt.ContractSignals) error { @@ -989,7 +957,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, }, firstBobChannel, - startingHeight, ) if err := bobServer.htlcSwitch.AddLink(firstBobChannelLink); err != nil { t.Fatalf("unable to add first bob channel link: %v", err) @@ -1004,13 +971,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } }() - bobSecondEpochChan := make(chan *chainntnfs.BlockEpoch) - bobSecondEpoch := &chainntnfs.BlockEpochEvent{ - Epochs: bobSecondEpochChan, - Cancel: func() { - }, - } - secondBobTicker := time.NewTicker(50 * time.Millisecond) secondBobChannelLink := NewChannelLink( ChannelLinkConfig{ Switch: bobServer.htlcSwitch, @@ -1025,7 +985,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: bobServer.registry, - BlockEpochs: bobSecondEpoch, FeeEstimator: feeEstimator, PreimageCache: pCache, UpdateContractSignals: func(*contractcourt.ContractSignals) error { @@ -1041,7 +1000,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, }, secondBobChannel, - startingHeight, ) if err := bobServer.htlcSwitch.AddLink(secondBobChannelLink); err != nil { t.Fatalf("unable to add second bob channel link: %v", err) @@ -1056,13 +1014,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } }() - carolBlockEpoch := make(chan *chainntnfs.BlockEpoch) - carolEpoch := &chainntnfs.BlockEpochEvent{ - Epochs: bobSecondEpochChan, - Cancel: func() { - }, - } - carolTicker := time.NewTicker(50 * time.Millisecond) carolChannelLink := NewChannelLink( ChannelLinkConfig{ Switch: carolServer.htlcSwitch, @@ -1077,7 +1028,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: carolServer.registry, - BlockEpochs: carolEpoch, FeeEstimator: feeEstimator, PreimageCache: pCache, UpdateContractSignals: func(*contractcourt.ContractSignals) error { @@ -1093,7 +1043,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, }, carolChannel, - startingHeight, ) if err := carolServer.htlcSwitch.AddLink(carolChannelLink); err != nil { t.Fatalf("unable to add carol channel link: %v", err) @@ -1111,22 +1060,13 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, return &threeHopNetwork{ aliceServer: aliceServer, aliceChannelLink: aliceChannelLink.(*channelLink), - aliceBlockEpoch: aliceEpochChan, - aliceTicker: aliceTicker, - - firstBobChannelLink: firstBobChannelLink.(*channelLink), - bobFirstBlockEpoch: bobFirstEpochChan, - firstBobTicker: firstBobTicker, bobServer: bobServer, + firstBobChannelLink: firstBobChannelLink.(*channelLink), secondBobChannelLink: secondBobChannelLink.(*channelLink), - bobSecondBlockEpoch: bobSecondEpochChan, - secondBobTicker: secondBobTicker, - carolChannelLink: carolChannelLink.(*channelLink), carolServer: carolServer, - carolBlockEpoch: carolBlockEpoch, - carolTicker: carolTicker, + carolChannelLink: carolChannelLink.(*channelLink), feeEstimator: feeEstimator, globalPolicy: globalPolicy, diff --git a/peer.go b/peer.go index 26537cd29..91058d7db 100644 --- a/peer.go +++ b/peer.go @@ -535,7 +535,6 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, ForwardPackets: p.server.htlcSwitch.ForwardPackets, FwrdingPolicy: *forwardingPolicy, FeeEstimator: p.server.cc.feeEstimator, - BlockEpochs: blockEpoch, PreimageCache: p.server.witnessBeacon, ChainEvents: chainEvents, UpdateContractSignals: func(signals *contractcourt.ContractSignals) error { @@ -555,9 +554,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, MaxFeeUpdateTimeout: htlcswitch.DefaultMaxLinkFeeUpdateTimeout, } - link := htlcswitch.NewChannelLink( - linkCfg, lnChan, uint32(currentHeight), - ) + link := htlcswitch.NewChannelLink(linkCfg, lnChan) // With the channel link created, we'll now notify the htlc switch so // this channel can be used to dispatch local payments and also diff --git a/server.go b/server.go index 45ad1dbb2..743affb14 100644 --- a/server.go +++ b/server.go @@ -284,6 +284,11 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, debugPre[:], debugHash[:]) } + _, currentHeight, err := s.cc.chainIO.GetBestBlock() + if err != nil { + return nil, err + } + s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ DB: chanDB, SelfKey: s.identityPriv.PubKey(), @@ -313,7 +318,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, SwitchPackager: channeldb.NewSwitchPackager(), ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey), - }) + Notifier: s.cc.chainNotifier, + }, uint32(currentHeight)) if err != nil { return nil, err } diff --git a/test_utils.go b/test_utils.go index 0baf779d7..98531120e 100644 --- a/test_utils.go +++ b/test_utils.go @@ -341,10 +341,17 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, breachArbiter: breachArbiter, chainArb: chainArb, } + + _, currentHeight, err := s.cc.chainIO.GetBestBlock() + if err != nil { + return nil, nil, nil, nil, err + } + htlcSwitch, err := htlcswitch.New(htlcswitch.Config{ DB: dbAlice, SwitchPackager: channeldb.NewSwitchPackager(), - }) + Notifier: notifier, + }, uint32(currentHeight)) if err != nil { return nil, nil, nil, nil, err }