diff options
| author | Adnan Maolood <[email protected]> | 2020-10-28 13:40:25 -0400 |
|---|---|---|
| committer | Adnan Maolood <[email protected]> | 2020-10-28 13:41:24 -0400 |
| commit | fbd97a62dec02ad22b7cf520cfc6ab519ea0e990 (patch) | |
| tree | 8a19117713cddce2d3ed2d31c24bec59fe616a48 /client.go | |
| parent | Add ErrInputRequired and ErrCertificateRequired (diff) | |
| download | go-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.tar.xz go-gemini-fbd97a62dec02ad22b7cf520cfc6ab519ea0e990.zip | |
Refactor client certificates
Diffstat (limited to 'client.go')
| -rw-r--r-- | client.go | 153 |
1 files changed, 83 insertions, 70 deletions
@@ -6,37 +6,38 @@ import ( "crypto/x509" "net" "net/url" + "strings" ) -// Client represents a Gemini client. +// Client is a Gemini client. type Client struct { - // KnownHosts is a list of known hosts that the client trusts. + // KnownHosts is a list of known hosts. KnownHosts KnownHosts - // CertificateStore maps hostnames to certificates. - // It is used to determine which certificate to use when the server requests - // a certificate. - CertificateStore ClientCertificateStore + // Certificates stores client-side certificates. + Certificates CertificateStore - // CheckRedirect, if not nil, will be called to determine whether - // to follow a redirect. + // GetInput is called to retrieve input when the server requests it. + // If GetInput is nil or returns false, no input will be sent and + // the response will be returned. + GetInput func(prompt string, sensitive bool) (input string, ok bool) + + // CheckRedirect determines whether to follow a redirect. // If CheckRedirect is nil, a default policy of no more than 5 consecutive // redirects will be enforced. CheckRedirect func(req *Request, via []*Request) error - // GetInput, if not nil, will be called to retrieve input when the server - // requests it. - GetInput func(prompt string, sensitive bool) (string, bool) - - // GetCertificate, if not nil, will be called when a server requests a certificate. - // The returned certificate will be used when sending the request again. - // If the certificate is nil, the request will not be sent again and - // the response will be returned. - GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate - - // TrustCertificate, if not nil, will be called to determine whether the - // client should trust the given certificate. - // If error is not nil, the connection will be aborted. + // CreateCertificate is called to generate a certificate upon + // the request of a server. + // If CreateCertificate is nil or the returned error is not nil, + // the request will not be sent again and the response will be returned. + CreateCertificate func(hostname, path string) (tls.Certificate, error) + + // TrustCertificate determines whether the client should trust + // the provided certificate. + // If the returned error is not nil, the connection will be aborted. + // If TrustCertificate is nil, the client will check KnownHosts + // for the certificate. TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error } @@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { config := &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, - GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Request certificates take precedence over client certificates - if req.Certificate != nil { - return req.Certificate, nil - } - // If we have already stored the certificate, return it - if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil { - return cert, nil - } - return &tls.Certificate{}, nil + GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return c.getClientCertificate(req) }, VerifyConnection: func(cs tls.ConnectionState) error { - cert := cs.PeerCertificates[0] - // Verify the hostname - if err := verifyHostname(cert, hostname(req.Host)); err != nil { - return err - } - // Check that the client trusts the certificate - if c.TrustCertificate == nil { - if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil { - return err - } - } else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil { - return err - } - return nil + return c.verifyConnection(req, cs) }, } conn, err := tls.Dial("tcp", req.Host, config) if err != nil { return nil, err } + // TODO: Set connection deadline // Write the request w := bufio.NewWriter(conn) @@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { if err := resp.read(conn); err != nil { return nil, err } - // Store connection information + // Store connection state resp.TLS = conn.ConnectionState() - // Resend the request with a certificate if the server responded - // with CertificateRequired - if resp.Status == StatusCertificateRequired { + switch { + case resp.Status == StatusCertificateRequired: // Check to see if a certificate was already provided to prevent an infinite loop if req.Certificate != nil { return resp, nil } - if c.GetCertificate != nil { - if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil { - req.Certificate = cert - return c.Do(req) + + hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") + if c.CreateCertificate != nil { + cert, err := c.CreateCertificate(hostname, path) + if err != nil { + return resp, err } + c.Certificates.Add(hostname+path, cert) + return c.do(req, via) } return resp, ErrCertificateRequired - } else if resp.Status.Class() == StatusClassRedirect { + + case resp.Status.Class() == StatusClassInput: + if c.GetInput != nil { + input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) + if ok { + req.URL.ForceQuery = true + req.URL.RawQuery = url.QueryEscape(input) + return c.do(req, via) + } + } + return resp, ErrInputRequired + + case resp.Status.Class() == StatusClassRedirect: if via == nil { via = []*Request{} } @@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { return resp, ErrTooManyRedirects } return c.do(redirect, via) - } else if resp.Status.Class() == StatusClassInput { - if c.GetInput != nil { - input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) - if ok { - req.URL.ForceQuery = true - req.URL.RawQuery = url.QueryEscape(input) - return c.do(req, via) - } - } - return resp, ErrInputRequired } resp.Request = req return resp, nil } -// hostname returns the host without the port. -func hostname(host string) string { - hostname, _, err := net.SplitHostPort(host) - if err != nil { - return host +func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) { + // Request certificates have the highest precedence + if req.Certificate != nil { + return req.Certificate, nil + } + hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") + if cert, err := c.Certificates.lookup(hostname + path); err == nil { + // Remember the certificate used + req.Certificate = cert + return cert, nil + } + return &tls.Certificate{}, nil +} + +func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { + // Verify the hostname + var hostname string + if host, _, err := net.SplitHostPort(req.Host); err == nil { + hostname = host + } else { + hostname = req.Host + } + cert := cs.PeerCertificates[0] + if err := verifyHostname(cert, hostname); err != nil { + return err + } + // Check that the client trusts the certificate + var err error + if c.TrustCertificate != nil { + return c.TrustCertificate(hostname, cert, &c.KnownHosts) + } else { + err = c.KnownHosts.Lookup(hostname, cert) } - return hostname + return err } |