aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2020-10-27 23:34:06 -0400
committerAdnan Maolood <[email protected]>2020-10-27 23:34:06 -0400
commitd1dcf070fff5e1215d2869ac8a1fafacf5a75e1e (patch)
treed1be50e8ab08be7299e04bb891097f51fea7441d
parentclient: Follow redirects (diff)
downloadgo-gemini-d1dcf070fff5e1215d2869ac8a1fafacf5a75e1e.tar.xz
go-gemini-d1dcf070fff5e1215d2869ac8a1fafacf5a75e1e.zip
Restrict client certificates to certain paths
-rw-r--r--cert.go36
-rw-r--r--client.go8
-rw-r--r--examples/client.go16
-rw-r--r--gemini.go8
4 files changed, 48 insertions, 20 deletions
diff --git a/cert.go b/cert.go
index a19a2bf..4f11fe6 100644
--- a/cert.go
+++ b/cert.go
@@ -8,6 +8,7 @@ import (
"crypto/x509"
"math/big"
"net"
+ "path"
"path/filepath"
"strings"
"time"
@@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error {
return nil
}
+type ClientCertificateStore struct {
+ CertificateStore
+}
+
+func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
+ urlPath = path.Clean(urlPath)
+ for {
+ cert, err := c.CertificateStore.Lookup(hostname + urlPath)
+ switch err {
+ case ErrCertificateExpired, nil:
+ return cert, err
+ }
+ slash := urlPath == "/"
+ urlPath = path.Dir(urlPath)
+ if slash && urlPath == "/" {
+ break
+ }
+ }
+ return nil, ErrCertificateUnknown
+}
+
// NewCertificate creates and returns a new parsed certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(host, duration)
@@ -114,12 +136,14 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
BasicConstraintsValid: true,
}
- hosts := strings.Split(host, ",")
- for _, h := range hosts {
- if ip := net.ParseIP(h); ip != nil {
- template.IPAddresses = append(template.IPAddresses, ip)
- } else {
- template.DNSNames = append(template.DNSNames, h)
+ if host != "" {
+ hosts := strings.Split(host, ",")
+ for _, h := range hosts {
+ if ip := net.ParseIP(h); ip != nil {
+ template.IPAddresses = append(template.IPAddresses, ip)
+ } else {
+ template.DNSNames = append(template.DNSNames, h)
+ }
}
}
diff --git a/client.go b/client.go
index d8d88ae..91402b1 100644
--- a/client.go
+++ b/client.go
@@ -16,7 +16,7 @@ type Client struct {
// CertificateStore maps hostnames to certificates.
// It is used to determine which certificate to use when the server requests
// a certificate.
- CertificateStore CertificateStore
+ CertificateStore ClientCertificateStore
// CheckRedirect, if not nil, will be called to determine whether
// to follow a redirect.
@@ -28,7 +28,7 @@ type Client struct {
// The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and
// the response will be returned.
- GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate
+ GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
// TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate.
@@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return req.Certificate, nil
}
// If we have already stored the certificate, return it
- if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil {
+ if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil
}
return &tls.Certificate{}, nil
@@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, nil
}
if c.GetCertificate != nil {
- if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil {
+ if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
req.Certificate = cert
return c.Do(req)
}
diff --git a/examples/client.go b/examples/client.go
index 5180f5c..e2e1669 100644
--- a/examples/client.go
+++ b/examples/client.go
@@ -8,7 +8,6 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
- "net/url"
"os"
"time"
@@ -47,22 +46,27 @@ func init() {
}
return err
}
- client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate {
+ client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
// If the certificate is in the store, return it
- if cert, err := store.Lookup(hostname); err == nil {
+ if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
// Otherwise, generate a certificate
- fmt.Println("Generating client certificate for", hostname)
+ fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
duration := time.Hour
- cert, err := gemini.NewCertificate(hostname, duration)
+ cert, err := gemini.NewCertificate("", duration)
if err != nil {
return nil
}
// Store and return the certificate
- store.Add(hostname, cert)
+ store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
}
+ client.GetInput = func(prompt string, sensitive bool) (string, bool) {
+ fmt.Printf("%s: ", prompt)
+ scanner.Scan()
+ return scanner.Text(), true
+ }
}
// sendRequest sends a request to the given URL.
diff --git a/gemini.go b/gemini.go
index f3e0c81..635622c 100644
--- a/gemini.go
+++ b/gemini.go
@@ -49,16 +49,16 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert)
}
- DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate {
- if cert, err := store.Lookup(hostname); err == nil {
+ DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
+ if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
duration := time.Hour
- cert, err := NewCertificate(hostname, duration)
+ cert, err := NewCertificate("", duration)
if err != nil {
return nil
}
- store.Add(hostname, cert)
+ store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
}
}