aboutsummaryrefslogtreecommitdiff
path: root/tofu/tofu.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-01-25 12:02:09 -0500
committerAdnan Maolood <[email protected]>2021-01-25 12:11:59 -0500
commit62960266acb31027f15881d0818430ae293cd0db (patch)
tree95566c0ac17c8b2b9e2a082cbd257b2a268c3795 /tofu/tofu.go
parentUpdate examples (diff)
downloadgo-gemini-62960266acb31027f15881d0818430ae293cd0db.tar.xz
go-gemini-62960266acb31027f15881d0818430ae293cd0db.zip
tofu: Implement PersistentHosts
Diffstat (limited to 'tofu/tofu.go')
-rw-r--r--tofu/tofu.go95
1 files changed, 84 insertions, 11 deletions
diff --git a/tofu/tofu.go b/tofu/tofu.go
index 2ea8ac8..a928be6 100644
--- a/tofu/tofu.go
+++ b/tofu/tofu.go
@@ -27,7 +27,7 @@ type KnownHosts struct {
}
// Add adds a host to the list of known hosts.
-func (k *KnownHosts) Add(h Host) error {
+func (k *KnownHosts) Add(h Host) {
k.mu.Lock()
defer k.mu.Unlock()
if k.hosts == nil {
@@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error {
}
k.hosts[h.Hostname] = h
- return nil
}
// Lookup returns the known host entry corresponding to the given hostname.
@@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error {
// TOFU implements basic trust on first use.
//
// If the host is not on file, it is added to the list.
-// If the host on file is expired, it is replaced with the provided host.
+// If the host on file is expired, a new entry is added to the list.
// If the fingerprint does not match the one on file, an error is returned.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
@@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter {
}
}
-// NewHostsFile returns a new host writer that appends to the file at the given path.
+// OpenHostsFile returns a new host writer that appends to the file at the given path.
// The file is created if it does not exist.
-func NewHostsFile(path string) (*HostWriter, error) {
+func OpenHostsFile(path string) (*HostWriter, error) {
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, err
@@ -212,6 +211,83 @@ func (h *HostWriter) Close() error {
return h.cl.Close()
}
+// PersistentHosts represents a persistent set of known hosts.
+type PersistentHosts struct {
+ hosts *KnownHosts
+ writer *HostWriter
+}
+
+// NewPersistentHosts returns a new persistent set of known hosts.
+func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts {
+ return &PersistentHosts{
+ hosts,
+ writer,
+ }
+}
+
+// LoadPersistentHosts loads persistent hosts from the file at the given path.
+func LoadPersistentHosts(path string) (*PersistentHosts, error) {
+ hosts := &KnownHosts{}
+ if err := hosts.Load(path); err != nil {
+ return nil, err
+ }
+ writer, err := OpenHostsFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return &PersistentHosts{
+ hosts,
+ writer,
+ }, nil
+}
+
+// Add adds a host to the list of known hosts.
+// It returns an error if the host could not be persisted.
+func (p *PersistentHosts) Add(h Host) error {
+ err := p.writer.WriteHost(h)
+ if err != nil {
+ return fmt.Errorf("failed to persist host: %w", err)
+ }
+ p.hosts.Add(h)
+ return nil
+}
+
+// Lookup returns the known host entry corresponding to the given hostname.
+func (p *PersistentHosts) Lookup(hostname string) (Host, bool) {
+ return p.hosts.Lookup(hostname)
+}
+
+// Entries returns the known host entries sorted by hostname.
+func (p *PersistentHosts) Entries() []Host {
+ return p.hosts.Entries()
+}
+
+// TOFU implements trust on first use with a persistent set of known hosts.
+//
+// If the host is not on file, it is added to the list.
+// If the host on file is expired, a new entry is added to the list.
+// If the fingerprint does not match the one on file, an error is returned.
+func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error {
+ host := NewHost(hostname, cert.Raw, cert.NotAfter)
+
+ knownHost, ok := p.Lookup(hostname)
+ if !ok || time.Now().After(knownHost.Expires) {
+ return p.Add(host)
+ }
+
+ // Check fingerprint
+ if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
+ return fmt.Errorf("fingerprint for %q does not match", hostname)
+ }
+
+ return nil
+}
+
+// Close closes the underlying HostWriter.
+func (p *PersistentHosts) Close() error {
+ return p.writer.Close()
+}
+
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
Hostname string // hostname
@@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error {
parts := bytes.Split(text, []byte(" "))
if len(parts) != 4 {
- return fmt.Errorf(
- "expected the format %q", format)
+ return fmt.Errorf("expected the format %q", format)
}
if len(parts[0]) == 0 {
@@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error {
algorithm := string(parts[1])
if algorithm != "SHA-512" {
- return fmt.Errorf(
- "unsupported algorithm %q", algorithm)
+ return fmt.Errorf("unsupported algorithm %q", algorithm)
}
h.Algorithm = algorithm
@@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error {
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
if err != nil {
- return fmt.Errorf(
- "invalid unix timestamp: %w", err)
+ return fmt.Errorf("invalid unix timestamp: %w", err)
}
h.Expires = time.Unix(unix, 0)