diff --git a/brontide/listener.go b/brontide/listener.go index 1cc171f66..603a6b1a8 100644 --- a/brontide/listener.go +++ b/brontide/listener.go @@ -1,6 +1,7 @@ package brontide import ( + "errors" "io" "net" "time" @@ -8,6 +9,10 @@ import ( "github.com/roasbeef/btcd/btcec" ) +// defaultHandshakes is the maximum number of handshakes that can be done in +// parallel. +const defaultHandshakes = 1000 + // Listener is an implementation of a net.Conn which executes an authenticated // key exchange and message encryption protocol dubbed "Machine" after // initial connection acceptance. See the Machine struct for additional @@ -17,6 +22,10 @@ type Listener struct { localStatic *btcec.PrivateKey tcp *net.TCPListener + + handshakeSema chan struct{} + conns chan maybeConn + quit chan struct{} } // A compile-time assertion to ensure that Conn meets the net.Listener interface. @@ -36,23 +45,57 @@ func NewListener(localStatic *btcec.PrivateKey, listenAddr string) (*Listener, return nil, err } - return &Listener{ - localStatic: localStatic, - tcp: l, - }, nil + brontideListener := &Listener{ + localStatic: localStatic, + tcp: l, + handshakeSema: make(chan struct{}, defaultHandshakes), + conns: make(chan maybeConn), + quit: make(chan struct{}), + } + + for i := 0; i < defaultHandshakes; i++ { + brontideListener.handshakeSema <- struct{}{} + } + + go brontideListener.listen() + + return brontideListener, nil } -// Accept waits for and returns the next connection to the listener. All -// incoming connections are authenticated via the three act Brontide -// key-exchange scheme. This function will fail with a non-nil error in the -// case that either the handshake breaks down, or the remote peer doesn't know -// our static public key. +// listen accepts connection from the underlying tcp conn, then performs +// the brontinde handshake procedure asynchronously. A maximum of +// defaultHandshakes will be active at any given time. // -// Part of the net.Listener interface. -func (l *Listener) Accept() (net.Conn, error) { - conn, err := l.tcp.Accept() - if err != nil { - return nil, err +// NOTE: This method must be run as a goroutine. +func (l *Listener) listen() { + for { + select { + case <-l.handshakeSema: + case <-l.quit: + return + } + + conn, err := l.tcp.Accept() + if err != nil { + l.rejectConn(err) + l.handshakeSema <- struct{}{} + continue + } + + go l.doHandshake(conn) + } +} + +// doHandshake asynchronously performs the brontide handshake, so that it does +// not block the main accept loop. This prevents peers that delay writing to the +// connection from block other connection attempts. +func (l *Listener) doHandshake(conn net.Conn) { + defer func() { l.handshakeSema <- struct{}{} }() + + select { + case <-l.quit: + return + default: } brontideConn := &Conn{ @@ -71,11 +114,13 @@ func (l *Listener) Accept() (net.Conn, error) { var actOne [ActOneSize]byte if _, err := io.ReadFull(conn, actOne[:]); err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return } if err := brontideConn.noise.RecvActOne(actOne); err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return } // Next, progress the handshake processes by sending over our ephemeral @@ -83,11 +128,19 @@ func (l *Listener) Accept() (net.Conn, error) { actTwo, err := brontideConn.noise.GenActTwo() if err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return } if _, err := conn.Write(actTwo[:]); err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return + } + + select { + case <-l.quit: + return + default: } // We'll ensure that we get ActTwo from the remote peer in a timely @@ -101,18 +154,59 @@ func (l *Listener) Accept() (net.Conn, error) { var actThree [ActThreeSize]byte if _, err := io.ReadFull(conn, actThree[:]); err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return } if err := brontideConn.noise.RecvActThree(actThree); err != nil { brontideConn.conn.Close() - return nil, err + l.rejectConn(err) + return } // We'll reset the deadline as it's no longer critical beyond the // initial handshake. conn.SetReadDeadline(time.Time{}) - return brontideConn, nil + l.acceptConn(brontideConn) +} + +// maybeConn holds either a brontide connection or an error returned from the +// handshake. +type maybeConn struct { + conn *Conn + err error +} + +// acceptConn returns a connection that successfully performed a handshake. +func (l *Listener) acceptConn(conn *Conn) { + select { + case l.conns <- maybeConn{conn: conn}: + case <-l.quit: + } +} + +// rejectConn returns any errors encountered during connection or handshake. +func (l *Listener) rejectConn(err error) { + select { + case l.conns <- maybeConn{err: err}: + case <-l.quit: + } +} + +// Accept waits for and returns the next connection to the listener. All +// incoming connections are authenticated via the three act Brontide +// key-exchange scheme. This function will fail with a non-nil error in the +// case that either the handshake breaks down, or the remote peer doesn't know +// our static public key. +// +// Part of the net.Listener interface. +func (l *Listener) Accept() (net.Conn, error) { + select { + case result := <-l.conns: + return result.conn, result.err + case <-l.quit: + return nil, errors.New("brontide connection closed") + } } // Close closes the listener. Any blocked Accept operations will be unblocked @@ -120,6 +214,12 @@ func (l *Listener) Accept() (net.Conn, error) { // // Part of the net.Listener interface. func (l *Listener) Close() error { + select { + case <-l.quit: + default: + close(l.quit) + } + return l.tcp.Close() } diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 5b0562c21..415659cad 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -13,16 +13,16 @@ import ( "github.com/roasbeef/btcd/btcec" ) -func establishTestConnection() (net.Conn, net.Conn, func(), error) { - // First, generate the long-term private keys both ends of the - // connection within our test. +type maybeNetConn struct { + conn net.Conn + err error +} + +func makeListener() (*Listener, *lnwire.NetAddress, error) { + // First, generate the long-term private keys for the brontide listener. localPriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - return nil, nil, nil, err - } - remotePriv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - return nil, nil, nil, err + return nil, nil, err } // Having a port of ":0" means a random port, and interface will be @@ -32,56 +32,62 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) { // Our listener will be local, and the connection remote. listener, err := NewListener(localPriv, addr) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - defer listener.Close() netAddr := &lnwire.NetAddress{ IdentityKey: localPriv.PubKey(), Address: listener.Addr().(*net.TCPAddr), } + return listener, netAddr, nil +} + +func establishTestConnection() (net.Conn, net.Conn, func(), error) { + listener, netAddr, err := makeListener() + if err != nil { + return nil, nil, nil, err + } + defer listener.Close() + + // Nos, generate the long-term private keys remote end of the connection + // within our test. + remotePriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, nil, nil, err + } + // Initiate a connection with a separate goroutine, and listen with our // main one. If both errors are nil, then encryption+auth was // successful. - conErrChan := make(chan error, 1) - connChan := make(chan net.Conn, 1) + remoteConnChan := make(chan maybeNetConn, 1) go func() { - conn, err := Dial(remotePriv, netAddr, net.Dial) - - conErrChan <- err - connChan <- conn + remoteConn, err := Dial(remotePriv, netAddr, net.Dial) + remoteConnChan <- maybeNetConn{remoteConn, err} }() - lisErrChan := make(chan error, 1) - lisChan := make(chan net.Conn, 1) + localConnChan := make(chan maybeNetConn, 1) go func() { - localConn, listenErr := listener.Accept() - - lisErrChan <- listenErr - lisChan <- localConn + localConn, err := listener.Accept() + localConnChan <- maybeNetConn{localConn, err} }() - select { - case err := <-conErrChan: - if err != nil { - return nil, nil, nil, err - } - case err := <-lisErrChan: - if err != nil { - return nil, nil, nil, err - } + remote := <-remoteConnChan + if remote.err != nil { + return nil, nil, nil, err } - localConn := <-lisChan - remoteConn := <-connChan + local := <-localConnChan + if local.err != nil { + return nil, nil, nil, err + } cleanUp := func() { - localConn.Close() - remoteConn.Close() + local.conn.Close() + remote.conn.Close() } - return localConn, remoteConn, cleanUp, nil + return local.conn, remote.conn, cleanUp, nil } func TestConnectionCorrectness(t *testing.T) { @@ -134,14 +140,84 @@ func TestConnectionCorrectness(t *testing.T) { } } +// TestConecurrentHandshakes verifies the listener's ability to not be blocked +// by other pending handshakes. This is tested by opening multiple tcp +// connections with the listener, without completing any of the brontide acts. +// The test passes if real brontide dialer connects while the others are +// stalled. +func TestConcurrentHandshakes(t *testing.T) { + listener, netAddr, err := makeListener() + if err != nil { + t.Fatalf("unable to create listener connection: %v", err) + } + defer listener.Close() + + const nblocking = 5 + + // Open a handful of tcp connections, that do not complete any steps of + // the brontide handshake. + connChan := make(chan maybeNetConn) + for i := 0; i < nblocking; i++ { + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + connChan <- maybeNetConn{conn, err} + }() + } + + // Receive all connections/errors from our blocking tcp dials. We make a + // pass to gather all connections and errors to make sure we defer the + // calls to Close() on all successful connections. + tcpErrs := make([]error, 0, nblocking) + for i := 0; i < nblocking; i++ { + result := <-connChan + if result.conn != nil { + defer result.conn.Close() + } + if result.err != nil { + tcpErrs = append(tcpErrs, result.err) + } + } + for _, tcpErr := range tcpErrs { + if tcpErr != nil { + t.Fatalf("unable to tcp dial listener: %v", tcpErr) + } + } + + // Now, construct a new private key and use the brontide dialer to + // connect to the listener. + remotePriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate private key: %v", err) + } + + go func() { + remoteConn, err := Dial(remotePriv, netAddr, net.Dial) + connChan <- maybeNetConn{remoteConn, err} + }() + + // This connection should be accepted without error, as the brontide + // connection should bypass stalled tcp connections. + conn, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept dial: %v", err) + } + defer conn.Close() + + result := <-connChan + if result.err != nil { + t.Fatalf("unable to dial %v: %v", netAddr, result.err) + } + result.conn.Close() +} + func TestMaxPayloadLength(t *testing.T) { t.Parallel() b := Machine{} b.split() - // Create a payload that's only *slightly* above the maximum allotted payload - // length. + // Create a payload that's only *slightly* above the maximum allotted + // payload length. payloadToReject := make([]byte, math.MaxUint16+1) var buf bytes.Buffer