aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-20 13:37:08 -0500
committerAdnan Maolood <[email protected]>2021-02-20 15:34:21 -0500
commit3f4fd10b6d92ec45dad3788770a90127da57ca17 (patch)
tree7d989682ab2c33032349e49b211c036e1c3b2f4f /client.go
parentserver: Make Request.RemoteAddr a string (diff)
downloadgo-gemini-3f4fd10b6d92ec45dad3788770a90127da57ca17.tar.xz
go-gemini-3f4fd10b6d92ec45dad3788770a90127da57ca17.zip
client: Make Get and Do accept a Context
This removes the need for Request.Context.
Diffstat (limited to 'client.go')
-rw-r--r--client.go167
1 files changed, 90 insertions, 77 deletions
diff --git a/client.go b/client.go
index 1969f1e..5a94031 100644
--- a/client.go
+++ b/client.go
@@ -5,8 +5,8 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
- "fmt"
"net"
+ "net/url"
"time"
)
@@ -28,6 +28,10 @@ type Client struct {
//
// A Timeout of zero means no timeout.
Timeout time.Duration
+
+ // DialContext specifies the dial function for creating TCP connections.
+ // If DialContext is nil, the client dials using package net.
+ DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}
// Get sends a Gemini request for the given URL.
@@ -39,12 +43,12 @@ type Client struct {
// which the user is expected to close.
//
// For more control over requests, use NewRequest and Client.Do.
-func (c *Client) Get(url string) (*Response, error) {
+func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
req, err := NewRequest(url)
if err != nil {
return nil, err
}
- return c.Do(req)
+ return c.Do(ctx, req)
}
// Do sends a Gemini request and returns a Gemini response, following
@@ -57,48 +61,56 @@ func (c *Client) Get(url string) (*Response, error) {
// which the user is expected to close.
//
// Generally Get will be used instead of Do.
-func (c *Client) Do(req *Request) (*Response, error) {
- // Punycode request URL host
- hostname, port, err := net.SplitHostPort(req.URL.Host)
- if err != nil {
- // Likely no port
- hostname = req.URL.Host
- port = "1965"
+func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
+ if ctx == nil {
+ panic("nil context")
}
- punycode, err := punycodeHostname(hostname)
+
+ // Punycode request URL host
+ host, port := splitHostPort(req.URL.Host)
+ punycode, err := punycodeHostname(host)
if err != nil {
return nil, err
}
- if hostname != punycode {
- hostname = punycode
+ if host != punycode {
+ host = punycode
// Make a copy of the request
- _req := *req
- req = &_req
- _url := *req.URL
- req.URL = &_url
+ r2 := new(Request)
+ *r2 = *req
+ r2.URL = new(url.URL)
+ *r2.URL = *req.URL
+ req = r2
// Set the host
- req.URL.Host = net.JoinHostPort(hostname, port)
+ req.URL.Host = net.JoinHostPort(host, port)
}
// Use request host if provided
if req.Host != "" {
- hostname, port, err = net.SplitHostPort(req.Host)
- if err != nil {
- // Likely no port
- hostname = req.Host
- port = "1965"
- }
- // Punycode hostname
- hostname, err = punycodeHostname(hostname)
+ host, port = splitHostPort(req.Host)
+ host, err = punycodeHostname(host)
if err != nil {
return nil, err
}
}
+ addr := net.JoinHostPort(host, port)
+
// Connect to the host
- config := &tls.Config{
+ start := time.Now()
+ conn, err := c.dialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set the connection deadline
+ if c.Timeout != 0 {
+ conn.SetDeadline(start.Add(c.Timeout))
+ }
+
+ // Setup TLS
+ conn = tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
@@ -108,83 +120,84 @@ func (c *Client) Do(req *Request) (*Response, error) {
return &tls.Certificate{}, nil
},
VerifyConnection: func(cs tls.ConnectionState) error {
- return c.verifyConnection(hostname, punycode, cs)
+ return c.verifyConnection(cs, host)
},
- ServerName: hostname,
- }
-
- ctx := req.Context
- if ctx == nil {
- ctx = context.Background()
- }
-
- start := time.Now()
- dialer := net.Dialer{
- Timeout: c.Timeout,
- }
-
- address := net.JoinHostPort(hostname, port)
- netConn, err := dialer.DialContext(ctx, "tcp", address)
- if err != nil {
- return nil, err
- }
-
- conn := tls.Client(netConn, config)
-
- // Set connection deadline
- if c.Timeout != 0 {
- err := conn.SetDeadline(start.Add(c.Timeout))
- if err != nil {
- return nil, fmt.Errorf("failed to set connection deadline: %w", err)
- }
- }
+ ServerName: host,
+ })
- resp, err := c.do(conn, req)
- if err != nil {
- // If we fail to perform the request/response we have
- // to take responsibility for closing the connection.
- _ = conn.Close()
+ res := make(chan result, 1)
+ go func() {
+ res <- c.do(conn, req)
+ }()
- return nil, err
+ select {
+ case <-ctx.Done():
+ conn.Close()
+ return nil, ctx.Err()
+ case r := <-res:
+ return r.resp, r.err
}
+}
- // Store connection state
- state := conn.ConnectionState()
- resp.TLS = &state
-
- return resp, nil
+type result struct {
+ resp *Response
+ err error
}
-func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
+func (c *Client) do(conn net.Conn, req *Request) result {
// Write the request
- err := req.Write(conn)
- if err != nil {
- return nil, fmt.Errorf("failed to write request: %w", err)
+ if err := req.Write(conn); err != nil {
+ return result{nil, err}
}
// Read the response
resp, err := ReadResponse(conn)
if err != nil {
- return nil, err
+ return result{nil, err}
+ }
+
+ // Store TLS connection state
+ if tlsConn, ok := conn.(*tls.Conn); ok {
+ state := tlsConn.ConnectionState()
+ resp.TLS = &state
}
- return resp, nil
+ return result{resp, nil}
+}
+
+func (c *Client) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ if c.DialContext != nil {
+ return c.DialContext(ctx, network, addr)
+ }
+ return (&net.Dialer{
+ Timeout: c.Timeout,
+ }).DialContext(ctx, network, addr)
}
-func (c *Client) verifyConnection(hostname, punycode string, cs tls.ConnectionState) error {
+func (c *Client) verifyConnection(cs tls.ConnectionState, hostname string) error {
cert := cs.PeerCertificates[0]
- // Verify punycoded hostname
- if err := verifyHostname(cert, punycode); err != nil {
+ // Verify hostname
+ if err := verifyHostname(cert, hostname); err != nil {
return err
}
// Check expiration date
if !time.Now().Before(cert.NotAfter) {
return errors.New("gemini: certificate expired")
}
-
// See if the client trusts the certificate
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert)
}
return nil
}
+
+func splitHostPort(hostport string) (host, port string) {
+ var err error
+ host, port, err = net.SplitHostPort(hostport)
+ if err != nil {
+ // Likely no port
+ host = hostport
+ port = "1965"
+ }
+ return
+}