mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-03 17:26:57 +01:00
cert: add TLS reloader and return bytes from GenCert
Co-authored-by: gkrizek <graham@krizek.io>
This commit is contained in:
parent
84401f6f6c
commit
2f35b9aa7f
5 changed files with 239 additions and 40 deletions
|
@ -2,4 +2,7 @@ module github.com/lightningnetwork/lnd/cert
|
|||
|
||||
go 1.16
|
||||
|
||||
require github.com/stretchr/testify v1.5.1
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/stretchr/testify v1.5.1
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
|
|
@ -27,7 +27,9 @@ var (
|
|||
// ipAddresses returns the parserd IP addresses to use when creating the TLS
|
||||
// certificate. If tlsDisableAutofill is true, we don't include interface
|
||||
// addresses to protect users privacy.
|
||||
func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error) {
|
||||
func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP,
|
||||
error) {
|
||||
|
||||
// Collect the host's IP addresses, including loopback, in a slice.
|
||||
ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
|
||||
|
||||
|
@ -71,7 +73,9 @@ func ipAddresses(tlsExtraIPs []string, tlsDisableAutofill bool) ([]net.IP, error
|
|||
|
||||
// dnsNames returns the host and DNS names to use when creating the TLS
|
||||
// ceftificate.
|
||||
func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string, []string) {
|
||||
func dnsNames(tlsExtraDomains []string, tlsDisableAutofill bool) (string,
|
||||
[]string) {
|
||||
|
||||
// Collect the host's names into a slice.
|
||||
host, err := os.Hostname()
|
||||
|
||||
|
@ -187,9 +191,10 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
|
|||
return false, nil
|
||||
}
|
||||
|
||||
// GenCertPair generates a key/cert pair to the paths provided. The
|
||||
// auto-generated certificates should *not* be used in production for public
|
||||
// access as they're self-signed and don't necessarily contain all of the
|
||||
// GenCertPair generates a key/cert pair and returns the pair in byte form.
|
||||
//
|
||||
// The auto-generated certificates should *not* be used in production for
|
||||
// public access as they're self-signed and don't necessarily contain all of the
|
||||
// desired hostnames for the service. For production/public use, consider a
|
||||
// real PKI.
|
||||
//
|
||||
|
@ -197,7 +202,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
|
|||
// https://github.com/btcsuite/btcd/btcutil
|
||||
func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
|
||||
tlsExtraDomains []string, tlsDisableAutofill bool,
|
||||
certValidity time.Duration) error {
|
||||
certValidity time.Duration) ([]byte, []byte, error) {
|
||||
|
||||
now := time.Now()
|
||||
validUntil := now.Add(certValidity)
|
||||
|
@ -210,7 +215,8 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
|
|||
// Generate a serial number that's below the serialNumberLimit.
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate serial number: %s", err)
|
||||
return nil, nil, fmt.Errorf("failed to generate serial "+
|
||||
"number: %s", err)
|
||||
}
|
||||
|
||||
// Get all DNS names and IP addresses to use when creating the
|
||||
|
@ -218,13 +224,13 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
|
|||
host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill)
|
||||
ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Generate a private key for the certificate.
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Construct the certificate template.
|
||||
|
@ -238,8 +244,11 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
|
|||
NotAfter: validUntil,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment |
|
||||
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
x509.KeyUsageDigitalSignature |
|
||||
x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
IsCA: true, // so can sign self.
|
||||
BasicConstraintsValid: true,
|
||||
|
||||
|
@ -247,37 +256,57 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
|
|||
IPAddresses: ipAddresses,
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template,
|
||||
&template, &priv.PublicKey, priv)
|
||||
derBytes, err := x509.CreateCertificate(
|
||||
rand.Reader, &template,
|
||||
&template, &priv.PublicKey, priv,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to create certificate: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
certBuf := &bytes.Buffer{}
|
||||
err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE",
|
||||
Bytes: derBytes})
|
||||
err = pem.Encode(
|
||||
certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode certificate: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to encode certificate: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
keybytes, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to encode privkey: %v", err)
|
||||
return nil, nil, fmt.Errorf("unable to encode privkey: %v",
|
||||
err)
|
||||
}
|
||||
keyBuf := &bytes.Buffer{}
|
||||
err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY",
|
||||
Bytes: keybytes})
|
||||
err = pem.Encode(
|
||||
keyBuf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keybytes},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode private key: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to encode private key: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
return certBuf.Bytes(), keyBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
// WriteCertPair writes certificate and key data to disk if a path is provided.
|
||||
func WriteCertPair(certFile, keyFile string, certBytes, keyBytes []byte) error {
|
||||
// Write cert and key files.
|
||||
if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil {
|
||||
return err
|
||||
if certFile != "" {
|
||||
err := ioutil.WriteFile(certFile, certBytes, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil {
|
||||
os.Remove(certFile)
|
||||
return err
|
||||
|
||||
if keyFile != "" {
|
||||
err := ioutil.WriteFile(keyFile, keyBytes, 0600)
|
||||
if err != nil {
|
||||
os.Remove(certFile)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package cert_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -26,20 +28,28 @@ func TestIsOutdatedCert(t *testing.T) {
|
|||
keyPath := tempDir + "/tls.key"
|
||||
|
||||
// Generate TLS files with two extra IPs and domains.
|
||||
err := cert.GenCertPair(
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
|
||||
extraDomains[:2], false, testTLSCertDuration,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We'll attempt to check up-to-date status for all variants of 1-3
|
||||
// number of IPs and domains.
|
||||
for numIPs := 1; numIPs <= len(extraIPs); numIPs++ {
|
||||
for numDomains := 1; numDomains <= len(extraDomains); numDomains++ {
|
||||
_, parsedCert, err := cert.LoadCert(
|
||||
certPath, keyPath,
|
||||
certBytes, err := ioutil.ReadFile(certPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyBytes, err := ioutil.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, parsedCert, err := cert.LoadCertFromBytes(
|
||||
certBytes, keyBytes,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -78,17 +88,24 @@ func TestIsOutdatedPermutation(t *testing.T) {
|
|||
keyPath := tempDir + "/tls.key"
|
||||
|
||||
// Generate TLS files from the IPs and domains.
|
||||
err := cert.GenCertPair(
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", certPath, keyPath, extraIPs[:],
|
||||
extraDomains[:], false, testTLSCertDuration,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, parsedCert, err := cert.LoadCert(certPath, keyPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
certBytes, err = ioutil.ReadFile(certPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyBytes, err = ioutil.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, parsedCert, err := cert.LoadCertFromBytes(certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// If we have duplicate IPs or DNS names listed, that shouldn't matter.
|
||||
dupIPs := make([]string, len(extraIPs)*2)
|
||||
|
@ -142,7 +159,7 @@ func TestTLSDisableAutofill(t *testing.T) {
|
|||
keyPath := tempDir + "/tls.key"
|
||||
|
||||
// Generate TLS files with two extra IPs and domains and no interface IPs.
|
||||
err := cert.GenCertPair(
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
|
||||
extraDomains[:2], true, testTLSCertDuration,
|
||||
)
|
||||
|
@ -150,9 +167,19 @@ func TestTLSDisableAutofill(t *testing.T) {
|
|||
t, err,
|
||||
"unable to generate tls certificate pair",
|
||||
)
|
||||
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, parsedCert, err := cert.LoadCert(
|
||||
certPath, keyPath,
|
||||
// Read certs from disk.
|
||||
certBytes, err = ioutil.ReadFile(certPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyBytes, err = ioutil.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the certificate.
|
||||
_, parsedCert, err := cert.LoadCertFromBytes(
|
||||
certBytes, keyBytes,
|
||||
)
|
||||
require.NoError(
|
||||
t, err,
|
||||
|
@ -160,7 +187,7 @@ func TestTLSDisableAutofill(t *testing.T) {
|
|||
)
|
||||
|
||||
// Check if the TLS cert is outdated while still preventing
|
||||
// interface IPs from being used. Should not be outdated
|
||||
// interface IPs from being used. Should not be outdated.
|
||||
shouldNotBeOutdated, err := cert.IsOutdated(
|
||||
parsedCert, extraIPs[:2],
|
||||
extraDomains[:2], true,
|
||||
|
@ -185,3 +212,51 @@ func TestTLSDisableAutofill(t *testing.T) {
|
|||
"TLS Certificate was not marked as outdated when it should be",
|
||||
)
|
||||
}
|
||||
|
||||
// TestTLSConfig tests to ensure we can generate a TLS Config from
|
||||
// a tls cert and tls key.
|
||||
func TestTLSConfig(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPath := filepath.Join(tempDir, "/tls.cert")
|
||||
keyPath := filepath.Join(tempDir, "/tls.key")
|
||||
|
||||
// Generate TLS files with an extra IP and domain.
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", certPath, keyPath,
|
||||
[]string{extraIPs[0]}, []string{extraDomains[0]}, false,
|
||||
testTLSCertDuration,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
certBytes, err = ioutil.ReadFile(certPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyBytes, err = ioutil.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the certificate.
|
||||
certData, parsedCert, err := cert.LoadCertFromBytes(
|
||||
certBytes, keyBytes,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check to make sure the IP and domain are in the cert.
|
||||
var foundIp bool
|
||||
require.Contains(t, parsedCert.DNSNames, extraDomains[0])
|
||||
for _, ip := range parsedCert.IPAddresses {
|
||||
if ip.String() == extraIPs[0] {
|
||||
foundIp = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, true, foundIp, "Did not find required ip inside of "+
|
||||
"TLS Certificate.")
|
||||
|
||||
// Create TLS Config.
|
||||
tlsCfg := cert.TLSConfFromCert(certData)
|
||||
|
||||
require.Equal(t, 1, len(tlsCfg.Certificates))
|
||||
}
|
||||
|
|
91
cert/tls.go
91
cert/tls.go
|
@ -3,6 +3,8 @@ package cert
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -24,6 +26,22 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// GetCertBytesFromPath reads the TLS certificate and key files at the given
|
||||
// certPath and keyPath and returns the file bytes.
|
||||
func GetCertBytesFromPath(certPath, keyPath string) (certBytes,
|
||||
keyBytes []byte, err error) {
|
||||
|
||||
certBytes, err = ioutil.ReadFile(certPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
keyBytes, err = ioutil.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return certBytes, keyBytes, nil
|
||||
}
|
||||
|
||||
// LoadCert loads a certificate and its corresponding private key from the PEM
|
||||
// files indicated and returns the certificate in the two formats it is most
|
||||
// commonly used.
|
||||
|
@ -49,6 +67,31 @@ func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate,
|
|||
return certData, x509Cert, nil
|
||||
}
|
||||
|
||||
// LoadCertFromBytes loads a certificate and its corresponding private key from
|
||||
// the PEM bytes indicated and returns the certificate in the two formats it is
|
||||
// most commonly used.
|
||||
func LoadCertFromBytes(certBytes, keyBytes []byte) (tls.Certificate,
|
||||
*x509.Certificate, error) {
|
||||
|
||||
// The certData returned here is just a wrapper around the PEM blocks
|
||||
// loaded from the file. The PEM is not yet fully parsed but a basic
|
||||
// check is performed that the certificate and private key actually
|
||||
// belong together.
|
||||
certData, err := tls.X509KeyPair(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, nil, err
|
||||
}
|
||||
|
||||
// Now parse the the PEM block of the certificate into its x509 data
|
||||
// structure so it can be examined in more detail.
|
||||
x509Cert, err := x509.ParseCertificate(certData.Certificate[0])
|
||||
if err != nil {
|
||||
return tls.Certificate{}, nil, err
|
||||
}
|
||||
|
||||
return certData, x509Cert, nil
|
||||
}
|
||||
|
||||
// TLSConfFromCert returns the default TLS configuration used for a server,
|
||||
// using the given certificate as identity.
|
||||
func TLSConfFromCert(certData tls.Certificate) *tls.Config {
|
||||
|
@ -58,3 +101,51 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
// TLSReloader updates the TLS certificate without restarting the server.
|
||||
type TLSReloader struct {
|
||||
certMu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
// NewTLSReloader is used to create a new TLS Reloader that will be used
|
||||
// to update the TLS certificate without restarting the server.
|
||||
func NewTLSReloader(certBytes, keyBytes []byte) (*TLSReloader, error) {
|
||||
cert, _, err := LoadCertFromBytes(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TLSReloader{
|
||||
cert: &cert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AttemptReload will make an attempt to update the TLS certificate
|
||||
// and key used by the server.
|
||||
func (t *TLSReloader) AttemptReload(certBytes, keyBytes []byte) error {
|
||||
newCert, _, err := LoadCertFromBytes(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.certMu.Lock()
|
||||
t.cert = &newCert
|
||||
t.certMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificateFunc is used in the server's TLS configuration to
|
||||
// determine the correct TLS certificate to server on a request.
|
||||
func (t *TLSReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (
|
||||
*tls.Certificate, error) {
|
||||
|
||||
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate,
|
||||
error) {
|
||||
|
||||
t.certMu.RLock()
|
||||
defer t.certMu.RUnlock()
|
||||
|
||||
return t.cert, nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue