aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authoradnano <[email protected]>2020-09-25 23:06:54 -0400
committeradnano <[email protected]>2020-09-25 23:18:14 -0400
commit927dfd29c598f2ec79fec711877bc582ffd18749 (patch)
treea6cd29fa36b890c067fc6e84562ff5ac05056c11 /client.go
parentImplement basic TOFU (diff)
downloadgo-gemini-927dfd29c598f2ec79fec711877bc582ffd18749.tar.xz
go-gemini-927dfd29c598f2ec79fec711877bc582ffd18749.zip
Refactor TOFU
Diffstat (limited to 'client.go')
-rw-r--r--client.go44
1 files changed, 36 insertions, 8 deletions
diff --git a/client.go b/client.go
index bd4cbd4..06f9a67 100644
--- a/client.go
+++ b/client.go
@@ -10,12 +10,15 @@ import (
"net"
"net/url"
"strconv"
+ "strings"
)
// Errors.
var (
- ErrProtocol = errors.New("gemini: protocol error")
- ErrInvalidURL = errors.New("gemini: requested URL is invalid")
+ ErrProtocol = errors.New("gemini: protocol error")
+ ErrInvalidURL = errors.New("gemini: requested URL is invalid")
+ ErrCertificateNotValid = errors.New("gemini: certificate is invalid")
+ ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
)
// Request represents a Gemini request.
@@ -163,24 +166,40 @@ func (resp *Response) read(r *bufio.Reader) error {
}
// Client represents a Gemini client.
-type Client interface {
- // VerifyCertificate will be called to verify the server certificate.
- // If error is not nil, the connection will be aborted.
- VerifyCertificate(cert *x509.Certificate, req *Request) error
+type Client struct {
+ // KnownHosts is a list of known hosts that the client trusts.
+ KnownHosts *KnownHosts
+
+ // TrustCertificate, if not nil, will be called to determine whether the
+ // client should trust the given certificate.
+ TrustCertificate func(cert *x509.Certificate, knownHosts *KnownHosts) bool
}
// Send sends a Gemini request and returns a Gemini response.
-func Send(c Client, req *Request) (*Response, error) {
+func (c *Client) Send(req *Request) (*Response, error) {
// Connect to the host
config := &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{req.Certificate},
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
+ // Parse the certificate
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return err
}
- return c.VerifyCertificate(cert, req)
+ // Check that the certificate is valid for the hostname
+ if cert.Subject.CommonName != hostname(req.Host) {
+ return ErrCertificateNotValid
+ }
+ // Check that the client trusts the certificate
+ if c.TrustCertificate == nil {
+ if c.KnownHosts == nil || !c.KnownHosts.Has(cert) {
+ return ErrCertificateNotTrusted
+ }
+ } else if !c.TrustCertificate(cert, c.KnownHosts) {
+ return ErrCertificateNotTrusted
+ }
+ return nil
},
}
conn, err := tls.Dial("tcp", req.Host, config)
@@ -206,3 +225,12 @@ func Send(c Client, req *Request) (*Response, error) {
}
return resp, nil
}
+
+// hostname extracts the host name from a valid host or host:port
+func hostname(host string) string {
+ i := strings.LastIndexByte(host, ':')
+ if i != -1 {
+ return host[:i]
+ }
+ return host
+}