diff --git a/cert/go.mod b/cert/go.mod index 8cd6cc1a4..6ee7f3f5b 100644 --- a/cert/go.mod +++ b/cert/go.mod @@ -2,4 +2,7 @@ module github.com/lightningnetwork/lnd/cert go 1.16 -require github.com/stretchr/testify v1.5.1 +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/stretchr/testify v1.5.1 +) diff --git a/cert/go.sum b/cert/go.sum index 331fa6982..e15f5f98d 100644 --- a/cert/go.sum +++ b/cert/go.sum @@ -1,5 +1,6 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/cert/selfsigned.go b/cert/selfsigned.go index dd87fb0b7..b8e01521d 100644 --- a/cert/selfsigned.go +++ b/cert/selfsigned.go @@ -27,7 +27,9 @@ var ( // ipAddresses returns the parserd IP addresses to use when creating the TLS // certificate. If tlsDisableAutofill is true, we don't include interface // addresses to protect users privacy. -func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error) { +func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, + error) { + // Collect the host's IP addresses, including loopback, in a slice. ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} @@ -71,7 +73,9 @@ func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error // dnsNames returns the host and DNS names to use when creating the TLS // ceftificate. -func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string, []string) { +func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string, + []string) { + // Collect the host's names into a slice. host, err := os.Hostname() @@ -187,9 +191,10 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs, return false, nil } -// GenCertPair generates a key/cert pair to the paths provided. The -// auto-generated certificates should *not* be used in production for public -// access as they're self-signed and don't necessarily contain all of the +// GenCertPair generates a key/cert pair and returns the pair in byte form. +// +// The auto-generated certificates should *not* be used in production for +// public access as they're self-signed and don't necessarily contain all of the // desired hostnames for the service. For production/public use, consider a // real PKI. // @@ -197,7 +202,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs, // https://github.com/btcsuite/btcd/btcutil func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, tlsExtraDomains []string, tlsDisableAutofill bool, - certValidity time.Duration) error { + certValidity time.Duration) ([]byte, []byte, error) { now := time.Now() validUntil := now.Add(certValidity) @@ -210,7 +215,8 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, // Generate a serial number that's below the serialNumberLimit. serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return fmt.Errorf("failed to generate serial number: %s", err) + return nil, nil, fmt.Errorf("failed to generate serial "+ + "number: %s", err) } // Get all DNS names and IP addresses to use when creating the @@ -218,13 +224,13 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill) ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill) if err != nil { - return err + return nil, nil, err } // Generate a private key for the certificate. priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - return err + return nil, nil, err } // Construct the certificate template. @@ -238,8 +244,11 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, NotAfter: validUntil, KeyUsage: x509.KeyUsageKeyEncipherment | - x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + x509.KeyUsageDigitalSignature | + x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, IsCA: true, // so can sign self. BasicConstraintsValid: true, @@ -247,37 +256,57 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, IPAddresses: ipAddresses, } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, - &template, &priv.PublicKey, priv) + derBytes, err := x509.CreateCertificate( + rand.Reader, &template, + &template, &priv.PublicKey, priv, + ) if err != nil { - return fmt.Errorf("failed to create certificate: %v", err) + return nil, nil, fmt.Errorf("failed to create certificate: %v", + err) } certBuf := &bytes.Buffer{} - err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE", - Bytes: derBytes}) + err = pem.Encode( + certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}, + ) if err != nil { - return fmt.Errorf("failed to encode certificate: %v", err) + return nil, nil, fmt.Errorf("failed to encode certificate: %v", + err) } keybytes, err := x509.MarshalECPrivateKey(priv) if err != nil { - return fmt.Errorf("unable to encode privkey: %v", err) + return nil, nil, fmt.Errorf("unable to encode privkey: %v", + err) } keyBuf := &bytes.Buffer{} - err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY", - Bytes: keybytes}) + err = pem.Encode( + keyBuf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keybytes}, + ) if err != nil { - return fmt.Errorf("failed to encode private key: %v", err) + return nil, nil, fmt.Errorf("failed to encode private key: %v", + err) } + return certBuf.Bytes(), keyBuf.Bytes(), nil +} + +// WriteCertPair writes certificate and key data to disk if a path is provided. +func WriteCertPair(certFile, keyFile string, certBytes, keyBytes []byte) error { // Write cert and key files. - if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil { - return err + if certFile != "" { + err := ioutil.WriteFile(certFile, certBytes, 0644) + if err != nil { + return err + } } - if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil { - os.Remove(certFile) - return err + + if keyFile != "" { + err := ioutil.WriteFile(keyFile, keyBytes, 0600) + if err != nil { + os.Remove(certFile) + return err + } } return nil diff --git a/cert/selfsigned_test.go b/cert/selfsigned_test.go index 3f5e694b2..402334d99 100644 --- a/cert/selfsigned_test.go +++ b/cert/selfsigned_test.go @@ -1,6 +1,8 @@ package cert_test import ( + "io/ioutil" + "path/filepath" "testing" "time" @@ -26,20 +28,28 @@ func TestIsOutdatedCert(t *testing.T) { keyPath := tempDir + "/tls.key" // Generate TLS files with two extra IPs and domains. - err := cert.GenCertPair( + certBytes, keyBytes, err := cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:2], extraDomains[:2], false, testTLSCertDuration, ) if err != nil { t.Fatal(err) } + err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes) + require.NoError(t, err) // We'll attempt to check up-to-date status for all variants of 1-3 // number of IPs and domains. for numIPs := 1; numIPs <= len(extraIPs); numIPs++ { for numDomains := 1; numDomains <= len(extraDomains); numDomains++ { - _, parsedCert, err := cert.LoadCert( - certPath, keyPath, + certBytes, err := ioutil.ReadFile(certPath) + require.NoError(t, err) + + keyBytes, err := ioutil.ReadFile(keyPath) + require.NoError(t, err) + + _, parsedCert, err := cert.LoadCertFromBytes( + certBytes, keyBytes, ) if err != nil { t.Fatal(err) @@ -78,17 +88,24 @@ func TestIsOutdatedPermutation(t *testing.T) { keyPath := tempDir + "/tls.key" // Generate TLS files from the IPs and domains. - err := cert.GenCertPair( + certBytes, keyBytes, err := cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:], extraDomains[:], false, testTLSCertDuration, ) if err != nil { t.Fatal(err) } - _, parsedCert, err := cert.LoadCert(certPath, keyPath) - if err != nil { - t.Fatal(err) - } + err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes) + require.NoError(t, err) + + certBytes, err = ioutil.ReadFile(certPath) + require.NoError(t, err) + + keyBytes, err = ioutil.ReadFile(keyPath) + require.NoError(t, err) + + _, parsedCert, err := cert.LoadCertFromBytes(certBytes, keyBytes) + require.NoError(t, err) // If we have duplicate IPs or DNS names listed, that shouldn't matter. dupIPs := make([]string, len(extraIPs)*2) @@ -142,7 +159,7 @@ func TestTLSDisableAutofill(t *testing.T) { keyPath := tempDir + "/tls.key" // Generate TLS files with two extra IPs and domains and no interface IPs. - err := cert.GenCertPair( + certBytes, keyBytes, err := cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:2], extraDomains[:2], true, testTLSCertDuration, ) @@ -150,9 +167,19 @@ func TestTLSDisableAutofill(t *testing.T) { t, err, "unable to generate tls certificate pair", ) + err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes) + require.NoError(t, err) - _, parsedCert, err := cert.LoadCert( - certPath, keyPath, + // Read certs from disk. + certBytes, err = ioutil.ReadFile(certPath) + require.NoError(t, err) + + keyBytes, err = ioutil.ReadFile(keyPath) + require.NoError(t, err) + + // Load the certificate. + _, parsedCert, err := cert.LoadCertFromBytes( + certBytes, keyBytes, ) require.NoError( t, err, @@ -160,7 +187,7 @@ func TestTLSDisableAutofill(t *testing.T) { ) // Check if the TLS cert is outdated while still preventing - // interface IPs from being used. Should not be outdated + // interface IPs from being used. Should not be outdated. shouldNotBeOutdated, err := cert.IsOutdated( parsedCert, extraIPs[:2], extraDomains[:2], true, @@ -185,3 +212,51 @@ func TestTLSDisableAutofill(t *testing.T) { "TLS Certificate was not marked as outdated when it should be", ) } + +// TestTLSConfig tests to ensure we can generate a TLS Config from +// a tls cert and tls key. +func TestTLSConfig(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "/tls.cert") + keyPath := filepath.Join(tempDir, "/tls.key") + + // Generate TLS files with an extra IP and domain. + certBytes, keyBytes, err := cert.GenCertPair( + "lnd autogenerated cert", certPath, keyPath, + []string{extraIPs[0]}, []string{extraDomains[0]}, false, + testTLSCertDuration, + ) + require.NoError(t, err) + + err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes) + require.NoError(t, err) + + certBytes, err = ioutil.ReadFile(certPath) + require.NoError(t, err) + + keyBytes, err = ioutil.ReadFile(keyPath) + require.NoError(t, err) + + // Load the certificate. + certData, parsedCert, err := cert.LoadCertFromBytes( + certBytes, keyBytes, + ) + require.NoError(t, err) + + // Check to make sure the IP and domain are in the cert. + var foundIp bool + require.Contains(t, parsedCert.DNSNames, extraDomains[0]) + for _, ip := range parsedCert.IPAddresses { + if ip.String() == extraIPs[0] { + foundIp = true + break + } + } + require.Equal(t, true, foundIp, "Did not find required ip inside of "+ + "TLS Certificate.") + + // Create TLS Config. + tlsCfg := cert.TLSConfFromCert(certData) + + require.Equal(t, 1, len(tlsCfg.Certificates)) +} diff --git a/cert/tls.go b/cert/tls.go index a8783158e..755b8005f 100644 --- a/cert/tls.go +++ b/cert/tls.go @@ -3,6 +3,8 @@ package cert import ( "crypto/tls" "crypto/x509" + "io/ioutil" + "sync" ) var ( @@ -24,6 +26,22 @@ var ( } ) +// GetCertBytesFromPath reads the TLS certificate and key files at the given +// certPath and keyPath and returns the file bytes. +func GetCertBytesFromPath(certPath, keyPath string) (certBytes, + keyBytes []byte, err error) { + + certBytes, err = ioutil.ReadFile(certPath) + if err != nil { + return nil, nil, err + } + keyBytes, err = ioutil.ReadFile(keyPath) + if err != nil { + return nil, nil, err + } + return certBytes, keyBytes, nil +} + // LoadCert loads a certificate and its corresponding private key from the PEM // files indicated and returns the certificate in the two formats it is most // commonly used. @@ -49,6 +67,31 @@ func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate, return certData, x509Cert, nil } +// LoadCertFromBytes loads a certificate and its corresponding private key from +// the PEM bytes indicated and returns the certificate in the two formats it is +// most commonly used. +func LoadCertFromBytes(certBytes, keyBytes []byte) (tls.Certificate, + *x509.Certificate, error) { + + // The certData returned here is just a wrapper around the PEM blocks + // loaded from the file. The PEM is not yet fully parsed but a basic + // check is performed that the certificate and private key actually + // belong together. + certData, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return tls.Certificate{}, nil, err + } + + // Now parse the the PEM block of the certificate into its x509 data + // structure so it can be examined in more detail. + x509Cert, err := x509.ParseCertificate(certData.Certificate[0]) + if err != nil { + return tls.Certificate{}, nil, err + } + + return certData, x509Cert, nil +} + // TLSConfFromCert returns the default TLS configuration used for a server, // using the given certificate as identity. func TLSConfFromCert(certData tls.Certificate) *tls.Config { @@ -58,3 +101,51 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config { MinVersion: tls.VersionTLS12, } } + +// TLSReloader updates the TLS certificate without restarting the server. +type TLSReloader struct { + certMu sync.RWMutex + cert *tls.Certificate +} + +// NewTLSReloader is used to create a new TLS Reloader that will be used +// to update the TLS certificate without restarting the server. +func NewTLSReloader(certBytes, keyBytes []byte) (*TLSReloader, error) { + cert, _, err := LoadCertFromBytes(certBytes, keyBytes) + if err != nil { + return nil, err + } + return &TLSReloader{ + cert: &cert, + }, nil +} + +// AttemptReload will make an attempt to update the TLS certificate +// and key used by the server. +func (t *TLSReloader) AttemptReload(certBytes, keyBytes []byte) error { + newCert, _, err := LoadCertFromBytes(certBytes, keyBytes) + if err != nil { + return err + } + + t.certMu.Lock() + t.cert = &newCert + t.certMu.Unlock() + + return nil +} + +// GetCertificateFunc is used in the server's TLS configuration to +// determine the correct TLS certificate to server on a request. +func (t *TLSReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) ( + *tls.Certificate, error) { + + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, + error) { + + t.certMu.RLock() + defer t.certMu.RUnlock() + + return t.cert, nil + } +} diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index e40234067..3f1363534 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -125,6 +125,9 @@ certain large transactions](https://github.com/lightningnetwork/lnd/pull/7100). * [Stop sending a synchronizing error on the wire when out of sync](https://github.com/lightningnetwork/lnd/pull/7039). +* [Update cert module](https://github.com/lightningnetwork/lnd/pull/6573) to + allow a way to update the tls certificate without restarting lnd. + ## `lncli` * [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the