aboutsummaryrefslogtreecommitdiff
path: root/cert.go
blob: 6fcf77d6543a8a5bd70323d13329170b7b7db2fe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package gmi

import (
	"bytes"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"math/big"
	"net"
	"path/filepath"
	"strings"
	"time"
)

// CertificateStore maps hostnames to certificates.
// The zero value of CertificateStore is an empty store ready to use.
type CertificateStore struct {
	// store holds the certificates keyed by hostname.
	// It stays nil until the first call to Add, which allocates it.
	store map[string]tls.Certificate
}

// Add adds a certificate for the given hostname to the store,
// replacing any certificate previously stored under that hostname.
//
// If cert.Leaf is nil, Add tries to parse the first certificate of the chain
// and stores the result in the Leaf field of the stored copy. Parsing is
// best-effort: when the chain is empty or its first entry is not a valid
// X.509 certificate, the certificate is stored with a nil Leaf.
func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
	if c.store == nil {
		c.store = map[string]tls.Certificate{}
	}
	// Parse the leaf if it is not already parsed. The length check guards
	// against an empty chain, for which indexing Certificate[0] would panic.
	if cert.Leaf == nil && len(cert.Certificate) > 0 {
		// A parse failure is deliberately ignored; Leaf simply stays nil.
		if parsed, err := x509.ParseCertificate(cert.Certificate[0]); err == nil {
			cert.Leaf = parsed
		}
	}
	c.store[hostname] = cert
}

// Lookup returns the certificate stored for the given hostname.
//
// It returns a nil certificate and ErrCertificateUnknown when no certificate
// is stored under hostname. When the stored certificate has a parsed leaf
// whose NotAfter time has passed, the certificate is still returned, together
// with ErrCertificateExpired. Certificates without a parsed leaf are not
// checked for expiry. The returned pointer refers to a copy of the stored
// certificate.
func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
	entry, found := c.store[hostname]
	if !found {
		return nil, ErrCertificateUnknown
	}
	// Only a parsed leaf carries the validity period needed for the check.
	if leaf := entry.Leaf; leaf != nil && time.Now().After(leaf.NotAfter) {
		return &entry, ErrCertificateExpired
	}
	return &entry, nil
}

// Load populates the store with the certificates found in the directory path.
//
// The directory is expected to hold one certificate/private-key pair per
// hostname, named hostname.crt and hostname.key. For instance, the hostname
// "localhost" is served by localhost.crt (certificate) and localhost.key
// (private key).
//
// Every *.crt file whose key pair loads successfully is added under the file
// name with the ".crt" suffix removed; pairs that fail to load are skipped
// silently. The only error returned is the one reported by filepath.Glob.
func (c *CertificateStore) Load(path string) error {
	crtFiles, err := filepath.Glob(filepath.Join(path, "*.crt"))
	if err != nil {
		return err
	}
	for _, crtFile := range crtFiles {
		keyFile := strings.TrimSuffix(crtFile, ".crt") + ".key"
		pair, loadErr := tls.LoadX509KeyPair(crtFile, keyFile)
		if loadErr == nil {
			// The hostname is the certificate's file name without ".crt".
			name := strings.TrimSuffix(filepath.Base(crtFile), ".crt")
			c.Add(name, pair)
		}
	}
	return nil
}

// NewCertificate generates a fresh certificate for host, valid for duration,
// and returns it as a parsed tls.Certificate.
//
// The PEM-encoded certificate and key come from NewRawCertificate and are
// parsed with tls.X509KeyPair. On failure of either step the error is
// returned; a failure of NewRawCertificate yields the zero tls.Certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
	certPEM, keyPEM, err := NewRawCertificate(host, duration)
	if err != nil {
		var none tls.Certificate
		return none, err
	}
	pair, err := tls.X509KeyPair(certPEM, keyPEM)
	return pair, err
}

// NewRawCertificate creates and returns a raw certificate for the given host.
// It generates a self-signed TLS certificate and an Ed25519 private key.
//
// host is a comma-separated list of DNS names and IP addresses. Entries that
// parse as an IP address become IP subject alternative names, all others
// become DNS names. Whitespace around entries is ignored and empty entries
// are skipped. The certificate is valid from the time of the call until
// duration later.
//
// crt is the PEM-encoded certificate ("CERTIFICATE" block) and key is the
// PEM-encoded PKCS #8 private key ("PRIVATE KEY" block). On failure both are
// nil and err is non-nil.
func NewRawCertificate(host string, duration time.Duration) (crt, key []byte, err error) {
	// Generate an Ed25519 key pair
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}

	notBefore := time.Now()
	notAfter := notBefore.Add(duration)

	// Generate a random serial number below 2^128
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, nil, err
	}

	template := x509.Certificate{
		SerialNumber: serialNumber,
		NotBefore:    notBefore,
		NotAfter:     notAfter,
		// Ed25519 keys should have the DigitalSignature KeyUsage bits set
		// in the x509.Certificate template
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	for _, h := range strings.Split(host, ",") {
		// Tolerate lists such as "a.example, b.example": stray whitespace
		// would otherwise end up inside the name, and an empty entry would
		// produce an empty DNS name.
		h = strings.TrimSpace(h)
		if h == "" {
			continue
		}
		if ip := net.ParseIP(h); ip != nil {
			template.IPAddresses = append(template.IPAddresses, ip)
		} else {
			template.DNSNames = append(template.DNSNames, h)
		}
	}

	// Create the certificate; it is self-signed, so the template is also
	// passed as the parent
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
	if err != nil {
		return nil, nil, err
	}

	// Encode the certificate
	var crtBuf bytes.Buffer
	if err := pem.Encode(&crtBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
		return nil, nil, err
	}

	// Encode the key
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, nil, err
	}
	var keyBuf bytes.Buffer
	if err := pem.Encode(&keyBuf, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		return nil, nil, err
	}

	return crtBuf.Bytes(), keyBuf.Bytes(), nil
}