aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2020-10-28 13:40:25 -0400
committerAdnan Maolood <[email protected]>2020-10-28 13:41:24 -0400
commitfbd97a62dec02ad22b7cf520cfc6ab519ea0e990 (patch)
tree8a19117713cddce2d3ed2d31c24bec59fe616a48
parentAdd ErrInputRequired and ErrCertificateRequired (diff)
downloadgo-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.tar.xz
go-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.zip
Refactor client certificates
-rw-r--r--cert.go75
-rw-r--r--client.go153
-rw-r--r--examples/client.go25
-rw-r--r--examples/server.go5
-rw-r--r--gemini.go15
-rw-r--r--status.go24
6 files changed, 141 insertions, 156 deletions
diff --git a/cert.go b/cert.go
index da77478..6ecf119 100644
--- a/cert.go
+++ b/cert.go
@@ -20,9 +20,9 @@ type CertificateStore struct {
store map[string]tls.Certificate
}
-// Add adds a certificate for the given hostname to the store.
+// Add adds a certificate for the given scope to the store.
// It tries to parse the certificate if it is not already parsed.
-func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
+func (c *CertificateStore) Add(scope string, cert tls.Certificate) {
if c.store == nil {
c.store = map[string]tls.Certificate{}
}
@@ -33,7 +33,7 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
cert.Leaf = parsed
}
}
- c.store[hostname] = cert
+ c.store[scope] = cert
}
// Lookup returns the certificate for the given hostname.
@@ -49,6 +49,22 @@ func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
return &cert, nil
}
+// lookup returns the certificate for the given hostname + path.
+func (c *CertificateStore) lookup(scope string) (*tls.Certificate, error) {
+ for {
+ cert, err := c.Lookup(scope)
+ switch err {
+ case ErrCertificateExpired, nil:
+ return cert, err
+ }
+ scope = path.Dir(scope)
+ if scope == "." {
+ break
+ }
+ }
+ return nil, ErrCertificateUnknown
+}
+
// Load loads certificates from the given path.
// The path should lead to a directory containing certificates and private keys
// in the form hostname.crt and hostname.key.
@@ -71,36 +87,16 @@ func (c *CertificateStore) Load(path string) error {
return nil
}
-type ClientCertificateStore struct {
- CertificateStore
+// CertificateOptions configures how a certificate is created.
+type CertificateOptions struct {
+ IPAddresses []net.IP
+ DNSNames []string
+ Duration time.Duration
}
-func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
- urlPath = path.Clean(urlPath)
- if urlPath == "." {
- urlPath = "/"
- }
- if urlPath[0] != '/' {
- urlPath = "/" + urlPath
- }
- for {
- cert, err := c.CertificateStore.Lookup(hostname + urlPath)
- switch err {
- case ErrCertificateExpired, nil:
- return cert, err
- }
- slash := urlPath == "/"
- urlPath = path.Dir(urlPath)
- if slash && urlPath == "/" {
- break
- }
- }
- return nil, ErrCertificateUnknown
-}
-
-// NewCertificate creates and returns a new parsed certificate.
-func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
- crt, priv, err := newX509KeyPair(host, duration)
+// CreateCertificate creates a new TLS certificate.
+func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
+ crt, priv, err := newX509KeyPair(options)
if err != nil {
return tls.Certificate{}, err
}
@@ -112,7 +108,7 @@ func NewCertificate(host string, duration time.Duration) (tls.Certificate, error
}
// newX509KeyPair creates and returns a new certificate and private key.
-func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, crypto.PrivateKey, error) {
+func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
// Generate an ED25519 private key
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
@@ -131,7 +127,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
}
notBefore := time.Now()
- notAfter := notBefore.Add(duration)
+ notAfter := notBefore.Add(options.Duration)
template := x509.Certificate{
SerialNumber: serialNumber,
@@ -140,17 +136,8 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
- }
-
- if host != "" {
- hosts := strings.Split(host, ",")
- for _, h := range hosts {
- if ip := net.ParseIP(h); ip != nil {
- template.IPAddresses = append(template.IPAddresses, ip)
- } else {
- template.DNSNames = append(template.DNSNames, h)
- }
- }
+ IPAddresses: options.IPAddresses,
+ DNSNames: options.DNSNames,
}
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)
diff --git a/client.go b/client.go
index c4ffb59..5ab3ded 100644
--- a/client.go
+++ b/client.go
@@ -6,37 +6,38 @@ import (
"crypto/x509"
"net"
"net/url"
+ "strings"
)
-// Client represents a Gemini client.
+// Client is a Gemini client.
type Client struct {
- // KnownHosts is a list of known hosts that the client trusts.
+ // KnownHosts is a list of known hosts.
KnownHosts KnownHosts
- // CertificateStore maps hostnames to certificates.
- // It is used to determine which certificate to use when the server requests
- // a certificate.
- CertificateStore ClientCertificateStore
+ // Certificates stores client-side certificates.
+ Certificates CertificateStore
- // CheckRedirect, if not nil, will be called to determine whether
- // to follow a redirect.
+ // GetInput is called to retrieve input when the server requests it.
+ // If GetInput is nil or returns false, no input will be sent and
+ // the response will be returned.
+ GetInput func(prompt string, sensitive bool) (input string, ok bool)
+
+ // CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
// redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error
- // GetInput, if not nil, will be called to retrieve input when the server
- // requests it.
- GetInput func(prompt string, sensitive bool) (string, bool)
-
- // GetCertificate, if not nil, will be called when a server requests a certificate.
- // The returned certificate will be used when sending the request again.
- // If the certificate is nil, the request will not be sent again and
- // the response will be returned.
- GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
-
- // TrustCertificate, if not nil, will be called to determine whether the
- // client should trust the given certificate.
- // If error is not nil, the connection will be aborted.
+ // CreateCertificate is called to generate a certificate upon
+ // the request of a server.
+ // If CreateCertificate is nil or the returned error is not nil,
+ // the request will not be sent again and the response will be returned.
+ CreateCertificate func(hostname, path string) (tls.Certificate, error)
+
+ // TrustCertificate determines whether the client should trust
+ // the provided certificate.
+ // If the returned error is not nil, the connection will be aborted.
+ // If TrustCertificate is nil, the client will check KnownHosts
+ // for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
}
@@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
config := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
- GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- // Request certificates take precedence over client certificates
- if req.Certificate != nil {
- return req.Certificate, nil
- }
- // If we have already stored the certificate, return it
- if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
- return cert, nil
- }
- return &tls.Certificate{}, nil
+ GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ return c.getClientCertificate(req)
},
VerifyConnection: func(cs tls.ConnectionState) error {
- cert := cs.PeerCertificates[0]
- // Verify the hostname
- if err := verifyHostname(cert, hostname(req.Host)); err != nil {
- return err
- }
- // Check that the client trusts the certificate
- if c.TrustCertificate == nil {
- if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
- return err
- }
- } else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
- return err
- }
- return nil
+ return c.verifyConnection(req, cs)
},
}
conn, err := tls.Dial("tcp", req.Host, config)
if err != nil {
return nil, err
}
+ // TODO: Set connection deadline
// Write the request
w := bufio.NewWriter(conn)
@@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
if err := resp.read(conn); err != nil {
return nil, err
}
- // Store connection information
+ // Store connection state
resp.TLS = conn.ConnectionState()
- // Resend the request with a certificate if the server responded
- // with CertificateRequired
- if resp.Status == StatusCertificateRequired {
+ switch {
+ case resp.Status == StatusCertificateRequired:
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
return resp, nil
}
- if c.GetCertificate != nil {
- if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
- req.Certificate = cert
- return c.Do(req)
+
+ hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
+ if c.CreateCertificate != nil {
+ cert, err := c.CreateCertificate(hostname, path)
+ if err != nil {
+ return resp, err
}
+ c.Certificates.Add(hostname+path, cert)
+ return c.do(req, via)
}
return resp, ErrCertificateRequired
- } else if resp.Status.Class() == StatusClassRedirect {
+
+ case resp.Status.Class() == StatusClassInput:
+ if c.GetInput != nil {
+ input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
+ if ok {
+ req.URL.ForceQuery = true
+ req.URL.RawQuery = url.QueryEscape(input)
+ return c.do(req, via)
+ }
+ }
+ return resp, ErrInputRequired
+
+ case resp.Status.Class() == StatusClassRedirect:
if via == nil {
via = []*Request{}
}
@@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects
}
return c.do(redirect, via)
- } else if resp.Status.Class() == StatusClassInput {
- if c.GetInput != nil {
- input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
- if ok {
- req.URL.ForceQuery = true
- req.URL.RawQuery = url.QueryEscape(input)
- return c.do(req, via)
- }
- }
- return resp, ErrInputRequired
}
resp.Request = req
return resp, nil
}
-// hostname returns the host without the port.
-func hostname(host string) string {
- hostname, _, err := net.SplitHostPort(host)
- if err != nil {
- return host
+func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
+ // Request certificates have the highest precedence
+ if req.Certificate != nil {
+ return req.Certificate, nil
+ }
+ hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
+ if cert, err := c.Certificates.lookup(hostname + path); err == nil {
+ // Remember the certificate used
+ req.Certificate = cert
+ return cert, nil
+ }
+ return &tls.Certificate{}, nil
+}
+
+func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
+ // Verify the hostname
+ var hostname string
+ if host, _, err := net.SplitHostPort(req.Host); err == nil {
+ hostname = host
+ } else {
+ hostname = req.Host
+ }
+ cert := cs.PeerCertificates[0]
+ if err := verifyHostname(cert, hostname); err != nil {
+ return err
+ }
+ // Check that the client trusts the certificate
+ var err error
+ if c.TrustCertificate != nil {
+ return c.TrustCertificate(hostname, cert, &c.KnownHosts)
+ } else {
+ err = c.KnownHosts.Lookup(hostname, cert)
}
- return hostname
+ return err
}
diff --git a/examples/client.go b/examples/client.go
index 169f726..71e7915 100644
--- a/examples/client.go
+++ b/examples/client.go
@@ -46,21 +46,11 @@ func init() {
}
return err
}
- client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
- // If the certificate is in the store, return it
- if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
- return cert
- }
- // Otherwise, generate a certificate
- fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
- duration := time.Hour
- cert, err := gemini.NewCertificate("", duration)
- if err != nil {
- return nil
- }
- // Store and return the certificate
- store.Add(req.URL.Hostname()+req.URL.Path, cert)
- return &cert
+ client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
+ fmt.Println("Generating client certificate for", hostname, path)
+ return gemini.CreateCertificate(gemini.CertificateOptions{
+ Duration: time.Hour,
+ })
}
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt)
@@ -69,8 +59,7 @@ func init() {
}
}
-// sendRequest sends a request to the given URL.
-func sendRequest(req *gemini.Request) error {
+func doRequest(req *gemini.Request) error {
resp, err := client.Do(req)
if err != nil {
return err
@@ -131,7 +120,7 @@ func main() {
os.Exit(1)
}
- if err := sendRequest(req); err != nil {
+ if err := doRequest(req); err != nil {
fmt.Println(err)
os.Exit(1)
}
diff --git a/examples/server.go b/examples/server.go
index 9ef4b8d..11ea9ed 100644
--- a/examples/server.go
+++ b/examples/server.go
@@ -29,7 +29,10 @@ func main() {
fallthrough
case gmi.ErrCertificateUnknown:
// Generate a certificate if one does not exist.
- cert, err := gmi.NewCertificate(hostname, time.Minute)
+ cert, err := gmi.CreateCertificate(gmi.CertificateOptions{
+ DNSNames: []string{hostname},
+ Duration: time.Hour,
+ })
if err != nil {
// Failed to generate new certificate, abort
return nil
diff --git a/gemini.go b/gemini.go
index a381712..8a74422 100644
--- a/gemini.go
+++ b/gemini.go
@@ -51,16 +51,9 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert)
}
- DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
- if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
- return cert
- }
- duration := time.Hour
- cert, err := NewCertificate("", duration)
- if err != nil {
- return nil
- }
- store.Add(req.URL.Hostname()+req.URL.Path, cert)
- return &cert
+ DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
+ return CreateCertificate(CertificateOptions{
+ Duration: time.Hour,
+ })
}
}
diff --git a/status.go b/status.go
index 2a4d0d3..4c50f89 100644
--- a/status.go
+++ b/status.go
@@ -24,6 +24,18 @@ const (
StatusCertificateNotValid Status = 62
)
+// Status code categories.
+type StatusClass int
+
+const (
+ StatusClassInput StatusClass = 1
+ StatusClassSuccess StatusClass = 2
+ StatusClassRedirect StatusClass = 3
+ StatusClassTemporaryFailure StatusClass = 4
+ StatusClassPermanentFailure StatusClass = 5
+ StatusClassCertificateRequired StatusClass = 6
+)
+
// Class returns the status class for this status code.
func (s Status) Class() StatusClass {
return StatusClass(s / 10)
@@ -71,15 +83,3 @@ func (s Status) Message() string {
}
return ""
}
-
-// Status code categories.
-type StatusClass int
-
-const (
- StatusClassInput StatusClass = 1
- StatusClassSuccess StatusClass = 2
- StatusClassRedirect StatusClass = 3
- StatusClassTemporaryFailure StatusClass = 4
- StatusClassPermanentFailure StatusClass = 5
- StatusClassCertificateRequired StatusClass = 6
-)