Merge pull request #6573 from voltagecloud/tls-reloader

cert: add TLS reloader
This commit is contained in:
Oliver Gugger 2022-11-02 09:10:15 +01:00 committed by GitHub
commit 926fdf486c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 242 additions and 40 deletions

View file

@ -2,4 +2,7 @@ module github.com/lightningnetwork/lnd/cert
go 1.16 go 1.16
require github.com/stretchr/testify v1.5.1 require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/stretchr/testify v1.5.1
)

View file

@ -1,5 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View file

@ -27,7 +27,9 @@ var (
// ipAddresses returns the parserd IP addresses to use when creating the TLS // ipAddresses returns the parserd IP addresses to use when creating the TLS
// certificate. If tlsDisableAutofill is true, we don't include interface // certificate. If tlsDisableAutofill is true, we don't include interface
// addresses to protect users privacy. // addresses to protect users privacy.
func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error) { func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP,
error) {
// Collect the host's IP addresses, including loopback, in a slice. // Collect the host's IP addresses, including loopback, in a slice.
ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
@ -71,7 +73,9 @@ func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error
// dnsNames returns the host and DNS names to use when creating the TLS // dnsNames returns the host and DNS names to use when creating the TLS
// ceftificate. // ceftificate.
func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string, []string) { func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string,
[]string) {
// Collect the host's names into a slice. // Collect the host's names into a slice.
host, err := os.Hostname() host, err := os.Hostname()
@ -187,9 +191,10 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
return false, nil return false, nil
} }
// GenCertPair generates a key/cert pair to the paths provided. The // GenCertPair generates a key/cert pair and returns the pair in byte form.
// auto-generated certificates should *not* be used in production for public //
// access as they're self-signed and don't necessarily contain all of the // The auto-generated certificates should *not* be used in production for
// public access as they're self-signed and don't necessarily contain all of the
// desired hostnames for the service. For production/public use, consider a // desired hostnames for the service. For production/public use, consider a
// real PKI. // real PKI.
// //
@ -197,7 +202,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
// https://github.com/btcsuite/btcd/btcutil // https://github.com/btcsuite/btcd/btcutil
func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
tlsExtraDomains []string, tlsDisableAutofill bool, tlsExtraDomains []string, tlsDisableAutofill bool,
certValidity time.Duration) error { certValidity time.Duration) ([]byte, []byte, error) {
now := time.Now() now := time.Now()
validUntil := now.Add(certValidity) validUntil := now.Add(certValidity)
@ -210,7 +215,8 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
// Generate a serial number that's below the serialNumberLimit. // Generate a serial number that's below the serialNumberLimit.
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil { if err != nil {
return fmt.Errorf("failed to generate serial number: %s", err) return nil, nil, fmt.Errorf("failed to generate serial "+
"number: %s", err)
} }
// Get all DNS names and IP addresses to use when creating the // Get all DNS names and IP addresses to use when creating the
@ -218,13 +224,13 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill) host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill)
ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill) ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill)
if err != nil { if err != nil {
return err return nil, nil, err
} }
// Generate a private key for the certificate. // Generate a private key for the certificate.
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return err return nil, nil, err
} }
// Construct the certificate template. // Construct the certificate template.
@ -238,8 +244,11 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
NotAfter: validUntil, NotAfter: validUntil,
KeyUsage: x509.KeyUsageKeyEncipherment | KeyUsage: x509.KeyUsageKeyEncipherment |
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, x509.KeyUsageDigitalSignature |
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
IsCA: true, // so can sign self. IsCA: true, // so can sign self.
BasicConstraintsValid: true, BasicConstraintsValid: true,
@ -247,37 +256,57 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
IPAddresses: ipAddresses, IPAddresses: ipAddresses,
} }
derBytes, err := x509.CreateCertificate(rand.Reader, &template, derBytes, err := x509.CreateCertificate(
&template, &priv.PublicKey, priv) rand.Reader, &template,
&template, &priv.PublicKey, priv,
)
if err != nil { if err != nil {
return fmt.Errorf("failed to create certificate: %v", err) return nil, nil, fmt.Errorf("failed to create certificate: %v",
err)
} }
certBuf := &bytes.Buffer{} certBuf := &bytes.Buffer{}
err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE", err = pem.Encode(
Bytes: derBytes}) certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes},
)
if err != nil { if err != nil {
return fmt.Errorf("failed to encode certificate: %v", err) return nil, nil, fmt.Errorf("failed to encode certificate: %v",
err)
} }
keybytes, err := x509.MarshalECPrivateKey(priv) keybytes, err := x509.MarshalECPrivateKey(priv)
if err != nil { if err != nil {
return fmt.Errorf("unable to encode privkey: %v", err) return nil, nil, fmt.Errorf("unable to encode privkey: %v",
err)
} }
keyBuf := &bytes.Buffer{} keyBuf := &bytes.Buffer{}
err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY", err = pem.Encode(
Bytes: keybytes}) keyBuf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keybytes},
)
if err != nil { if err != nil {
return fmt.Errorf("failed to encode private key: %v", err) return nil, nil, fmt.Errorf("failed to encode private key: %v",
err)
} }
return certBuf.Bytes(), keyBuf.Bytes(), nil
}
// WriteCertPair writes certificate and key data to disk if a path is provided.
func WriteCertPair(certFile, keyFile string, certBytes, keyBytes []byte) error {
// Write cert and key files. // Write cert and key files.
if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil { if certFile != "" {
return err err := ioutil.WriteFile(certFile, certBytes, 0644)
if err != nil {
return err
}
} }
if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil {
os.Remove(certFile) if keyFile != "" {
return err err := ioutil.WriteFile(keyFile, keyBytes, 0600)
if err != nil {
os.Remove(certFile)
return err
}
} }
return nil return nil

View file

@ -1,6 +1,8 @@
package cert_test package cert_test
import ( import (
"io/ioutil"
"path/filepath"
"testing" "testing"
"time" "time"
@ -26,20 +28,28 @@ func TestIsOutdatedCert(t *testing.T) {
keyPath := tempDir + "/tls.key" keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains. // Generate TLS files with two extra IPs and domains.
err := cert.GenCertPair( certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2], "lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], false, testTLSCertDuration, extraDomains[:2], false, testTLSCertDuration,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
// We'll attempt to check up-to-date status for all variants of 1-3 // We'll attempt to check up-to-date status for all variants of 1-3
// number of IPs and domains. // number of IPs and domains.
for numIPs := 1; numIPs <= len(extraIPs); numIPs++ { for numIPs := 1; numIPs <= len(extraIPs); numIPs++ {
for numDomains := 1; numDomains <= len(extraDomains); numDomains++ { for numDomains := 1; numDomains <= len(extraDomains); numDomains++ {
_, parsedCert, err := cert.LoadCert( certBytes, err := ioutil.ReadFile(certPath)
certPath, keyPath, require.NoError(t, err)
keyBytes, err := ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -78,17 +88,24 @@ func TestIsOutdatedPermutation(t *testing.T) {
keyPath := tempDir + "/tls.key" keyPath := tempDir + "/tls.key"
// Generate TLS files from the IPs and domains. // Generate TLS files from the IPs and domains.
err := cert.GenCertPair( certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:], "lnd autogenerated cert", certPath, keyPath, extraIPs[:],
extraDomains[:], false, testTLSCertDuration, extraDomains[:], false, testTLSCertDuration,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, parsedCert, err := cert.LoadCert(certPath, keyPath) err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
if err != nil { require.NoError(t, err)
t.Fatal(err)
} certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(certBytes, keyBytes)
require.NoError(t, err)
// If we have duplicate IPs or DNS names listed, that shouldn't matter. // If we have duplicate IPs or DNS names listed, that shouldn't matter.
dupIPs := make([]string, len(extraIPs)*2) dupIPs := make([]string, len(extraIPs)*2)
@ -142,7 +159,7 @@ func TestTLSDisableAutofill(t *testing.T) {
keyPath := tempDir + "/tls.key" keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains and no interface IPs. // Generate TLS files with two extra IPs and domains and no interface IPs.
err := cert.GenCertPair( certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2], "lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], true, testTLSCertDuration, extraDomains[:2], true, testTLSCertDuration,
) )
@ -150,9 +167,19 @@ func TestTLSDisableAutofill(t *testing.T) {
t, err, t, err,
"unable to generate tls certificate pair", "unable to generate tls certificate pair",
) )
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCert( // Read certs from disk.
certPath, keyPath, certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
) )
require.NoError( require.NoError(
t, err, t, err,
@ -160,7 +187,7 @@ func TestTLSDisableAutofill(t *testing.T) {
) )
// Check if the TLS cert is outdated while still preventing // Check if the TLS cert is outdated while still preventing
// interface IPs from being used. Should not be outdated // interface IPs from being used. Should not be outdated.
shouldNotBeOutdated, err := cert.IsOutdated( shouldNotBeOutdated, err := cert.IsOutdated(
parsedCert, extraIPs[:2], parsedCert, extraIPs[:2],
extraDomains[:2], true, extraDomains[:2], true,
@ -185,3 +212,51 @@ func TestTLSDisableAutofill(t *testing.T) {
"TLS Certificate was not marked as outdated when it should be", "TLS Certificate was not marked as outdated when it should be",
) )
} }
// TestTLSConfig tests to ensure we can generate a TLS Config from
// a tls cert and tls key.
func TestTLSConfig(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "/tls.cert")
keyPath := filepath.Join(tempDir, "/tls.key")
// Generate TLS files with an extra IP and domain.
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath,
[]string{extraIPs[0]}, []string{extraDomains[0]}, false,
testTLSCertDuration,
)
require.NoError(t, err)
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
certData, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
require.NoError(t, err)
// Check to make sure the IP and domain are in the cert.
var foundIp bool
require.Contains(t, parsedCert.DNSNames, extraDomains[0])
for _, ip := range parsedCert.IPAddresses {
if ip.String() == extraIPs[0] {
foundIp = true
break
}
}
require.Equal(t, true, foundIp, "Did not find required ip inside of "+
"TLS Certificate.")
// Create TLS Config.
tlsCfg := cert.TLSConfFromCert(certData)
require.Equal(t, 1, len(tlsCfg.Certificates))
}

View file

@ -3,6 +3,8 @@ package cert
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil"
"sync"
) )
var ( var (
@ -24,6 +26,22 @@ var (
} }
) )
// GetCertBytesFromPath reads the TLS certificate and key files at the given
// certPath and keyPath and returns the file bytes.
func GetCertBytesFromPath(certPath, keyPath string) (certBytes,
keyBytes []byte, err error) {
certBytes, err = ioutil.ReadFile(certPath)
if err != nil {
return nil, nil, err
}
keyBytes, err = ioutil.ReadFile(keyPath)
if err != nil {
return nil, nil, err
}
return certBytes, keyBytes, nil
}
// LoadCert loads a certificate and its corresponding private key from the PEM // LoadCert loads a certificate and its corresponding private key from the PEM
// files indicated and returns the certificate in the two formats it is most // files indicated and returns the certificate in the two formats it is most
// commonly used. // commonly used.
@ -49,6 +67,31 @@ func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate,
return certData, x509Cert, nil return certData, x509Cert, nil
} }
// LoadCertFromBytes loads a certificate and its corresponding private key from
// the PEM bytes indicated and returns the certificate in the two formats it is
// most commonly used.
func LoadCertFromBytes(certBytes, keyBytes []byte) (tls.Certificate,
*x509.Certificate, error) {
// The certData returned here is just a wrapper around the PEM blocks
// loaded from the file. The PEM is not yet fully parsed but a basic
// check is performed that the certificate and private key actually
// belong together.
certData, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return tls.Certificate{}, nil, err
}
// Now parse the the PEM block of the certificate into its x509 data
// structure so it can be examined in more detail.
x509Cert, err := x509.ParseCertificate(certData.Certificate[0])
if err != nil {
return tls.Certificate{}, nil, err
}
return certData, x509Cert, nil
}
// TLSConfFromCert returns the default TLS configuration used for a server, // TLSConfFromCert returns the default TLS configuration used for a server,
// using the given certificate as identity. // using the given certificate as identity.
func TLSConfFromCert(certData tls.Certificate) *tls.Config { func TLSConfFromCert(certData tls.Certificate) *tls.Config {
@ -58,3 +101,51 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
} }
// TLSReloader updates the TLS certificate without restarting the server.
type TLSReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
}
// NewTLSReloader is used to create a new TLS Reloader that will be used
// to update the TLS certificate without restarting the server.
func NewTLSReloader(certBytes, keyBytes []byte) (*TLSReloader, error) {
cert, _, err := LoadCertFromBytes(certBytes, keyBytes)
if err != nil {
return nil, err
}
return &TLSReloader{
cert: &cert,
}, nil
}
// AttemptReload will make an attempt to update the TLS certificate
// and key used by the server.
func (t *TLSReloader) AttemptReload(certBytes, keyBytes []byte) error {
newCert, _, err := LoadCertFromBytes(certBytes, keyBytes)
if err != nil {
return err
}
t.certMu.Lock()
t.cert = &newCert
t.certMu.Unlock()
return nil
}
// GetCertificateFunc is used in the server's TLS configuration to
// determine the correct TLS certificate to server on a request.
func (t *TLSReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (
*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate,
error) {
t.certMu.RLock()
defer t.certMu.RUnlock()
return t.cert, nil
}
}

View file

@ -125,6 +125,9 @@ certain large transactions](https://github.com/lightningnetwork/lnd/pull/7100).
* [Stop sending a synchronizing error on the wire when out of * [Stop sending a synchronizing error on the wire when out of
sync](https://github.com/lightningnetwork/lnd/pull/7039). sync](https://github.com/lightningnetwork/lnd/pull/7039).
* [Update cert module](https://github.com/lightningnetwork/lnd/pull/6573) to
allow a way to update the tls certificate without restarting lnd.
## `lncli` ## `lncli`
* [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice * [Add an `insecure` flag to skip tls auth as well as a `metadata` string slice
flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the flag](https://github.com/lightningnetwork/lnd/pull/6818) that allows the