cert: add TLS reloader and return bytes from GenCert

Co-authored-by: gkrizek <graham@krizek.io>
This commit is contained in:
Orbital 2022-05-24 13:26:32 -05:00
parent 84401f6f6c
commit 2f35b9aa7f
No known key found for this signature in database
GPG key ID: E557F37C985848F7
5 changed files with 239 additions and 40 deletions

View file

@ -2,4 +2,7 @@ module github.com/lightningnetwork/lnd/cert
go 1.16
require github.com/stretchr/testify v1.5.1
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/stretchr/testify v1.5.1
)

View file

@ -1,5 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View file

@ -27,7 +27,9 @@ var (
// ipAddresses returns the parserd IP addresses to use when creating the TLS
// certificate. If tlsDisableAutofill is true, we don't include interface
// addresses to protect users privacy.
func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error) {
func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP,
error) {
// Collect the host's IP addresses, including loopback, in a slice.
ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
@ -71,7 +73,9 @@ func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error
// dnsNames returns the host and DNS names to use when creating the TLS
// ceftificate.
func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string, []string) {
func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string,
[]string) {
// Collect the host's names into a slice.
host, err := os.Hostname()
@ -187,9 +191,10 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
return false, nil
}
// GenCertPair generates a key/cert pair to the paths provided. The
// auto-generated certificates should *not* be used in production for public
// access as they're self-signed and don't necessarily contain all of the
// GenCertPair generates a key/cert pair and returns the pair in byte form.
//
// The auto-generated certificates should *not* be used in production for
// public access as they're self-signed and don't necessarily contain all of the
// desired hostnames for the service. For production/public use, consider a
// real PKI.
//
@ -197,7 +202,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
// https://github.com/btcsuite/btcd/btcutil
func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
tlsExtraDomains []string, tlsDisableAutofill bool,
certValidity time.Duration) error {
certValidity time.Duration) ([]byte, []byte, error) {
now := time.Now()
validUntil := now.Add(certValidity)
@ -210,7 +215,8 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
// Generate a serial number that's below the serialNumberLimit.
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return fmt.Errorf("failed to generate serial number: %s", err)
return nil, nil, fmt.Errorf("failed to generate serial "+
"number: %s", err)
}
// Get all DNS names and IP addresses to use when creating the
@ -218,13 +224,13 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill)
ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill)
if err != nil {
return err
return nil, nil, err
}
// Generate a private key for the certificate.
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
return nil, nil, err
}
// Construct the certificate template.
@ -238,8 +244,11 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
NotAfter: validUntil,
KeyUsage: x509.KeyUsageKeyEncipherment |
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
x509.KeyUsageDigitalSignature |
x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
IsCA: true, // so can sign self.
BasicConstraintsValid: true,
@ -247,37 +256,57 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
IPAddresses: ipAddresses,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template,
&template, &priv.PublicKey, priv)
derBytes, err := x509.CreateCertificate(
rand.Reader, &template,
&template, &priv.PublicKey, priv,
)
if err != nil {
return fmt.Errorf("failed to create certificate: %v", err)
return nil, nil, fmt.Errorf("failed to create certificate: %v",
err)
}
certBuf := &bytes.Buffer{}
err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE",
Bytes: derBytes})
err = pem.Encode(
certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes},
)
if err != nil {
return fmt.Errorf("failed to encode certificate: %v", err)
return nil, nil, fmt.Errorf("failed to encode certificate: %v",
err)
}
keybytes, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return fmt.Errorf("unable to encode privkey: %v", err)
return nil, nil, fmt.Errorf("unable to encode privkey: %v",
err)
}
keyBuf := &bytes.Buffer{}
err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY",
Bytes: keybytes})
err = pem.Encode(
keyBuf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keybytes},
)
if err != nil {
return fmt.Errorf("failed to encode private key: %v", err)
return nil, nil, fmt.Errorf("failed to encode private key: %v",
err)
}
return certBuf.Bytes(), keyBuf.Bytes(), nil
}
// WriteCertPair writes certificate and key data to disk if a path is provided.
func WriteCertPair(certFile, keyFile string, certBytes, keyBytes []byte) error {
// Write cert and key files.
if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil {
return err
if certFile != "" {
err := ioutil.WriteFile(certFile, certBytes, 0644)
if err != nil {
return err
}
}
if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil {
os.Remove(certFile)
return err
if keyFile != "" {
err := ioutil.WriteFile(keyFile, keyBytes, 0600)
if err != nil {
os.Remove(certFile)
return err
}
}
return nil

View file

@ -1,6 +1,8 @@
package cert_test
import (
"io/ioutil"
"path/filepath"
"testing"
"time"
@ -26,20 +28,28 @@ func TestIsOutdatedCert(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], false, testTLSCertDuration,
)
if err != nil {
t.Fatal(err)
}
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
// We'll attempt to check up-to-date status for all variants of 1-3
// number of IPs and domains.
for numIPs := 1; numIPs <= len(extraIPs); numIPs++ {
for numDomains := 1; numDomains <= len(extraDomains); numDomains++ {
_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
certBytes, err := ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err := ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
if err != nil {
t.Fatal(err)
@ -78,17 +88,24 @@ func TestIsOutdatedPermutation(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files from the IPs and domains.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:],
extraDomains[:], false, testTLSCertDuration,
)
if err != nil {
t.Fatal(err)
}
_, parsedCert, err := cert.LoadCert(certPath, keyPath)
if err != nil {
t.Fatal(err)
}
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(certBytes, keyBytes)
require.NoError(t, err)
// If we have duplicate IPs or DNS names listed, that shouldn't matter.
dupIPs := make([]string, len(extraIPs)*2)
@ -142,7 +159,7 @@ func TestTLSDisableAutofill(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains and no interface IPs.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], true, testTLSCertDuration,
)
@ -150,9 +167,19 @@ func TestTLSDisableAutofill(t *testing.T) {
t, err,
"unable to generate tls certificate pair",
)
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
// Read certs from disk.
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
require.NoError(
t, err,
@ -160,7 +187,7 @@ func TestTLSDisableAutofill(t *testing.T) {
)
// Check if the TLS cert is outdated while still preventing
// interface IPs from being used. Should not be outdated
// interface IPs from being used. Should not be outdated.
shouldNotBeOutdated, err := cert.IsOutdated(
parsedCert, extraIPs[:2],
extraDomains[:2], true,
@ -185,3 +212,51 @@ func TestTLSDisableAutofill(t *testing.T) {
"TLS Certificate was not marked as outdated when it should be",
)
}
// TestTLSConfig tests to ensure we can generate a TLS Config from
// a tls cert and tls key.
func TestTLSConfig(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "/tls.cert")
keyPath := filepath.Join(tempDir, "/tls.key")
// Generate TLS files with an extra IP and domain.
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath,
[]string{extraIPs[0]}, []string{extraDomains[0]}, false,
testTLSCertDuration,
)
require.NoError(t, err)
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
certData, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
require.NoError(t, err)
// Check to make sure the IP and domain are in the cert.
var foundIp bool
require.Contains(t, parsedCert.DNSNames, extraDomains[0])
for _, ip := range parsedCert.IPAddresses {
if ip.String() == extraIPs[0] {
foundIp = true
break
}
}
require.Equal(t, true, foundIp, "Did not find required ip inside of "+
"TLS Certificate.")
// Create TLS Config.
tlsCfg := cert.TLSConfFromCert(certData)
require.Equal(t, 1, len(tlsCfg.Certificates))
}

View file

@ -3,6 +3,8 @@ package cert
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"sync"
)
var (
@ -24,6 +26,22 @@ var (
}
)
// GetCertBytesFromPath reads the TLS certificate and key files at the given
// certPath and keyPath and returns the file bytes.
func GetCertBytesFromPath(certPath, keyPath string) (certBytes,
keyBytes []byte, err error) {
certBytes, err = ioutil.ReadFile(certPath)
if err != nil {
return nil, nil, err
}
keyBytes, err = ioutil.ReadFile(keyPath)
if err != nil {
return nil, nil, err
}
return certBytes, keyBytes, nil
}
// LoadCert loads a certificate and its corresponding private key from the PEM
// files indicated and returns the certificate in the two formats it is most
// commonly used.
@ -49,6 +67,31 @@ func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate,
return certData, x509Cert, nil
}
// LoadCertFromBytes loads a certificate and its corresponding private key from
// the PEM bytes indicated and returns the certificate in the two formats it is
// most commonly used.
func LoadCertFromBytes(certBytes, keyBytes []byte) (tls.Certificate,
*x509.Certificate, error) {
// The certData returned here is just a wrapper around the PEM blocks
// loaded from the file. The PEM is not yet fully parsed but a basic
// check is performed that the certificate and private key actually
// belong together.
certData, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return tls.Certificate{}, nil, err
}
// Now parse the the PEM block of the certificate into its x509 data
// structure so it can be examined in more detail.
x509Cert, err := x509.ParseCertificate(certData.Certificate[0])
if err != nil {
return tls.Certificate{}, nil, err
}
return certData, x509Cert, nil
}
// TLSConfFromCert returns the default TLS configuration used for a server,
// using the given certificate as identity.
func TLSConfFromCert(certData tls.Certificate) *tls.Config {
@ -58,3 +101,51 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config {
MinVersion: tls.VersionTLS12,
}
}
// TLSReloader updates the TLS certificate without restarting the server.
type TLSReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
}
// NewTLSReloader is used to create a new TLS Reloader that will be used
// to update the TLS certificate without restarting the server.
func NewTLSReloader(certBytes, keyBytes []byte) (*TLSReloader, error) {
cert, _, err := LoadCertFromBytes(certBytes, keyBytes)
if err != nil {
return nil, err
}
return &TLSReloader{
cert: &cert,
}, nil
}
// AttemptReload will make an attempt to update the TLS certificate
// and key used by the server.
func (t *TLSReloader) AttemptReload(certBytes, keyBytes []byte) error {
newCert, _, err := LoadCertFromBytes(certBytes, keyBytes)
if err != nil {
return err
}
t.certMu.Lock()
t.cert = &newCert
t.certMu.Unlock()
return nil
}
// GetCertificateFunc is used in the server's TLS configuration to
// determine the correct TLS certificate to server on a request.
func (t *TLSReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (
*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate,
error) {
t.certMu.RLock()
defer t.certMu.RUnlock()
return t.cert, nil
}
}