aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2020-10-28 13:40:25 -0400
committerAdnan Maolood <[email protected]>2020-10-28 13:41:24 -0400
commitfbd97a62dec02ad22b7cf520cfc6ab519ea0e990 (patch)
tree8a19117713cddce2d3ed2d31c24bec59fe616a48 /client.go
parentAdd ErrInputRequired and ErrCertificateRequired (diff)
downloadgo-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.tar.xz
go-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.zip
Refactor client certificates
Diffstat (limited to 'client.go')
-rw-r--r--client.go153
1 files changed, 83 insertions, 70 deletions
diff --git a/client.go b/client.go
index c4ffb59..5ab3ded 100644
--- a/client.go
+++ b/client.go
@@ -6,37 +6,38 @@ import (
"crypto/x509"
"net"
"net/url"
+ "strings"
)
-// Client represents a Gemini client.
+// Client is a Gemini client.
type Client struct {
- // KnownHosts is a list of known hosts that the client trusts.
+ // KnownHosts is a list of known hosts.
KnownHosts KnownHosts
- // CertificateStore maps hostnames to certificates.
- // It is used to determine which certificate to use when the server requests
- // a certificate.
- CertificateStore ClientCertificateStore
+ // Certificates stores client-side certificates.
+ Certificates CertificateStore
- // CheckRedirect, if not nil, will be called to determine whether
- // to follow a redirect.
+ // GetInput is called to retrieve input when the server requests it.
+ // If GetInput is nil or returns false, no input will be sent and
+ // the response will be returned.
+ GetInput func(prompt string, sensitive bool) (input string, ok bool)
+
+ // CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
// redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error
- // GetInput, if not nil, will be called to retrieve input when the server
- // requests it.
- GetInput func(prompt string, sensitive bool) (string, bool)
-
- // GetCertificate, if not nil, will be called when a server requests a certificate.
- // The returned certificate will be used when sending the request again.
- // If the certificate is nil, the request will not be sent again and
- // the response will be returned.
- GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
-
- // TrustCertificate, if not nil, will be called to determine whether the
- // client should trust the given certificate.
- // If error is not nil, the connection will be aborted.
+ // CreateCertificate is called to generate a certificate upon
+ // the request of a server.
+ // If CreateCertificate is nil or the returned error is not nil,
+ // the request will not be sent again and the response will be returned.
+ CreateCertificate func(hostname, path string) (tls.Certificate, error)
+
+ // TrustCertificate determines whether the client should trust
+ // the provided certificate.
+ // If the returned error is not nil, the connection will be aborted.
+ // If TrustCertificate is nil, the client will check KnownHosts
+ // for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
}
@@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
config := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
- GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- // Request certificates take precedence over client certificates
- if req.Certificate != nil {
- return req.Certificate, nil
- }
- // If we have already stored the certificate, return it
- if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
- return cert, nil
- }
- return &tls.Certificate{}, nil
+ GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ return c.getClientCertificate(req)
},
VerifyConnection: func(cs tls.ConnectionState) error {
- cert := cs.PeerCertificates[0]
- // Verify the hostname
- if err := verifyHostname(cert, hostname(req.Host)); err != nil {
- return err
- }
- // Check that the client trusts the certificate
- if c.TrustCertificate == nil {
- if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
- return err
- }
- } else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
- return err
- }
- return nil
+ return c.verifyConnection(req, cs)
},
}
conn, err := tls.Dial("tcp", req.Host, config)
if err != nil {
return nil, err
}
+ // TODO: Set connection deadline
// Write the request
w := bufio.NewWriter(conn)
@@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
if err := resp.read(conn); err != nil {
return nil, err
}
- // Store connection information
+ // Store connection state
resp.TLS = conn.ConnectionState()
- // Resend the request with a certificate if the server responded
- // with CertificateRequired
- if resp.Status == StatusCertificateRequired {
+ switch {
+ case resp.Status == StatusCertificateRequired:
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
return resp, nil
}
- if c.GetCertificate != nil {
- if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
- req.Certificate = cert
- return c.Do(req)
+
+ hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
+ if c.CreateCertificate != nil {
+ cert, err := c.CreateCertificate(hostname, path)
+ if err != nil {
+ return resp, err
}
+ c.Certificates.Add(hostname+path, cert)
+ return c.do(req, via)
}
return resp, ErrCertificateRequired
- } else if resp.Status.Class() == StatusClassRedirect {
+
+ case resp.Status.Class() == StatusClassInput:
+ if c.GetInput != nil {
+ input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
+ if ok {
+ req.URL.ForceQuery = true
+ req.URL.RawQuery = url.QueryEscape(input)
+ return c.do(req, via)
+ }
+ }
+ return resp, ErrInputRequired
+
+ case resp.Status.Class() == StatusClassRedirect:
if via == nil {
via = []*Request{}
}
@@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects
}
return c.do(redirect, via)
- } else if resp.Status.Class() == StatusClassInput {
- if c.GetInput != nil {
- input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
- if ok {
- req.URL.ForceQuery = true
- req.URL.RawQuery = url.QueryEscape(input)
- return c.do(req, via)
- }
- }
- return resp, ErrInputRequired
}
resp.Request = req
return resp, nil
}
-// hostname returns the host without the port.
-func hostname(host string) string {
- hostname, _, err := net.SplitHostPort(host)
- if err != nil {
- return host
+func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
+ // Request certificates have the highest precedence
+ if req.Certificate != nil {
+ return req.Certificate, nil
+ }
+ hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
+ if cert, err := c.Certificates.lookup(hostname + path); err == nil {
+ // Remember the certificate used
+ req.Certificate = cert
+ return cert, nil
+ }
+ return &tls.Certificate{}, nil
+}
+
+func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
+ // Verify the hostname
+ var hostname string
+ if host, _, err := net.SplitHostPort(req.Host); err == nil {
+ hostname = host
+ } else {
+ hostname = req.Host
+ }
+ cert := cs.PeerCertificates[0]
+ if err := verifyHostname(cert, hostname); err != nil {
+ return err
+ }
+ // Check that the client trusts the certificate
+ var err error
+ if c.TrustCertificate != nil {
+ return c.TrustCertificate(hostname, cert, &c.KnownHosts)
+ } else {
+ err = c.KnownHosts.Lookup(hostname, cert)
}
- return hostname
+ return err
}