diff options
| author | Adnan Maolood <[email protected]> | 2020-12-18 00:12:32 -0500 |
|---|---|---|
| committer | Adnan Maolood <[email protected]> | 2020-12-18 00:12:32 -0500 |
| commit | e2c907a7f65a98a1737ffe50f55d6c7dfa9abb9e (patch) | |
| tree | 2bf87afb600fde3843ea09536965896955ecedb0 /examples | |
| parent | Update switch statement (diff) | |
| download | go-gemini-e2c907a7f65a98a1737ffe50f55d6c7dfa9abb9e.tar.xz go-gemini-e2c907a7f65a98a1737ffe50f55d6c7dfa9abb9e.zip | |
client: Remove GetInput and CheckRedirect callbacks
Diffstat (limited to 'examples')
| -rw-r--r-- | examples/client.go | 113 |
1 files changed, 80 insertions, 33 deletions
diff --git a/examples/client.go b/examples/client.go index 1c98bf5..4975158 100644 --- a/examples/client.go +++ b/examples/client.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "log" + "net/url" "os" "path/filepath" "time" @@ -17,6 +18,22 @@ import ( "git.sr.ht/~adnano/go-xdg" ) +var ( + hosts gemini.KnownHostsFile + scanner *bufio.Scanner +) + +func init() { + // Load known hosts file + path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts") + err := hosts.Load(path) + if err != nil { + log.Println(err) + } + + scanner = bufio.NewScanner(os.Stdin) +} + const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: %s @@ -26,47 +43,77 @@ Otherwise, this should be safe to trust. [t]rust always; trust [o]nce; [a]bort => ` -func main() { - if len(os.Args) < 2 { - fmt.Printf("usage: %s <url> [host]", os.Args[0]) - os.Exit(1) +func trustCertificate(hostname string, cert *x509.Certificate) error { + knownHost, ok := hosts.Lookup(hostname) + if ok && time.Now().Before(knownHost.Expires) { + // Certificate is in known hosts file and is not expired + return nil } - // Load known hosts file - var knownHosts gemini.KnownHostsFile - if err := knownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")); err != nil { - log.Println(err) + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) + fmt.Printf(trustPrompt, hostname, fingerprint.Hex) + scanner.Scan() + switch scanner.Text() { + case "t": + hosts.Add(hostname, fingerprint) + hosts.Write(hostname, fingerprint) + return nil + case "o": + hosts.Add(hostname, fingerprint) + return nil + default: + return errors.New("certificate not trusted") } +} - scanner := bufio.NewScanner(os.Stdin) +func getInput(prompt string, sensitive bool) (input string, ok bool) { + fmt.Printf("%s ", prompt) + scanner.Scan() + return scanner.Text(), true +} - var client gemini.Client - client.TrustCertificate = func(hostname string, cert *x509.Certificate) error { - knownHost, ok := knownHosts.Lookup(hostname) - if ok && time.Now().Before(knownHost.Expires) { - // Certificate is in known hosts file and is not expired - return nil +func do(req *gemini.Request, via []*gemini.Request) (*gemini.Response, error) { + client := gemini.Client{ + TrustCertificate: trustCertificate, + } + resp, err := client.Do(req) + if err != nil { + return resp, err + } + + switch resp.Status.Class() { + case gemini.StatusClassInput: + input, ok := getInput(resp.Meta, resp.Status == gemini.StatusSensitiveInput) + if !ok { + break + } + req.URL.ForceQuery = true + req.URL.RawQuery = gemini.QueryEscape(input) + return do(req, via) + + case gemini.StatusClassRedirect: + via = append(via, req) + if len(via) > 5 { + return resp, errors.New("too many redirects") } - fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) - fmt.Printf(trustPrompt, hostname, fingerprint.Hex) - scanner.Scan() - switch scanner.Text() { - case "t": - knownHosts.Add(hostname, fingerprint) - knownHosts.Write(hostname, fingerprint) - return nil - case "o": - knownHosts.Add(hostname, fingerprint) - return nil - default: - return errors.New("certificate not trusted") + target, err := url.Parse(resp.Meta) + if err != nil { + return resp, err } + target = req.URL.ResolveReference(target) + redirect := *req + redirect.URL = target + return do(&redirect, via) } - client.GetInput = func(prompt string, sensitive bool) (string, bool) { - fmt.Printf("%s ", prompt) - scanner.Scan() - return scanner.Text(), true + + return resp, err +} + +func main() { + if len(os.Args) < 2 { + fmt.Printf("usage: %s <url> [host]\n", os.Args[0]) + os.Exit(1) } // Do the request @@ -79,7 +126,7 @@ func main() { if len(os.Args) == 3 { req.Host = os.Args[2] } - resp, err := client.Do(req) + resp, err := do(req, nil) if err != nil { fmt.Println(err) os.Exit(1) |