diff options
| -rw-r--r-- | client.go | 34 | ||||
| -rw-r--r-- | server.go | 2 | ||||
| -rw-r--r-- | tofu.go | 13 |
3 files changed, 29 insertions, 20 deletions
@@ -213,27 +213,27 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { if c.InsecureSkipTrust { return nil } + // Check the known hosts - // No need to check if it is expired as tls already does that knownHost, ok := c.KnownHosts.Lookup(hostname) - if ok { - fingerprint := NewFingerprint(cert) - if knownHost.Hex != fingerprint.Hex { - return errors.New("gemini: fingerprint does not match") + if !ok || time.Now().Unix() >= knownHost.Expires { + // See if the client trusts the certificate + if c.TrustCertificate != nil { + switch c.TrustCertificate(hostname, cert) { + case TrustOnce: + c.KnownHosts.AddTemporary(hostname, cert) + return nil + case TrustAlways: + c.KnownHosts.Add(hostname, cert) + return nil + } } - return nil + return errors.New("gemini: certificate not trusted") } - // See if the client trusts the certificate - if c.TrustCertificate != nil { - switch c.TrustCertificate(hostname, cert) { - case TrustOnce: - c.KnownHosts.AddTemporary(hostname, cert) - return nil - case TrustAlways: - c.KnownHosts.Add(hostname, cert) - return nil - } + fingerprint := NewFingerprint(cert) + if knownHost.Hex == fingerprint.Hex { + return nil } - return errors.New("gemini: certificate not trusted") + return errors.New("gemini: fingerprint does not match") } @@ -156,7 +156,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { // Generate a new certificate if it is missing or expired cert, ok := s.Certificates.Lookup(hostname) - if !ok || cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) { + if !ok || cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) { if s.CreateCertificate != nil { cert, err := s.CreateCertificate(hostname) if err == nil { @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" ) @@ -105,7 +106,7 @@ func (k *KnownHosts) Parse(r io.Reader) { for scanner.Scan() { text := scanner.Text() parts := strings.Split(text, " ") - if len(parts) < 3 { + if len(parts) < 4 { continue } @@ -116,9 +117,15 @@ func (k *KnownHosts) Parse(r io.Reader) { } fingerprint := parts[2] + expires, err := strconv.ParseInt(parts[3], 10, 0) + if err != nil { + continue + } + k.hosts[hostname] = Fingerprint{ Algorithm: algorithm, Hex: fingerprint, + Expires: expires, } } } @@ -131,13 +138,14 @@ func (k *KnownHosts) Write(w io.Writer) { } func appendKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { - return fmt.Fprintf(w, "%s %s %s\n", hostname, f.Algorithm, f.Hex) + return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires) } // Fingerprint represents a fingerprint using a certain algorithm. type Fingerprint struct { Algorithm string // fingerprint algorithm e.g. SHA-512 Hex string // fingerprint in hexadecimal, with ':' between each octet + Expires int64 // unix time of the fingerprint expiration date } // NewFingerprint returns the SHA-512 fingerprint of the provided certificate. @@ -153,6 +161,7 @@ func NewFingerprint(cert *x509.Certificate) Fingerprint { return Fingerprint{ Algorithm: "SHA-512", Hex: b.String(), + Expires: cert.NotAfter.Unix(), } } |