aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-23 15:34:53 -0500
committerAdnan Maolood <[email protected]>2021-02-23 15:52:47 -0500
commit9974071657fc2b5195dd00e2af074c5babc65bb7 (patch)
tree4a780b70c39e517b62d579c3ea91f22bc2c2e337 /client.go
parentexamples/stream: Simplify (diff)
downloadgo-gemini-9974071657fc2b5195dd00e2af074c5babc65bb7.tar.xz
go-gemini-9974071657fc2b5195dd00e2af074c5babc65bb7.zip
client: Cancel context on IO errors
Also close the connection when the context expires.
Diffstat (limited to 'client.go')
-rw-r--r--client.go76
1 files changed, 72 insertions, 4 deletions
diff --git a/client.go b/client.go
index 8c94f23..4372464 100644
--- a/client.go
+++ b/client.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
+ "io"
"net"
"net/url"
"time"
@@ -124,7 +125,22 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
res := make(chan result, 1)
go func() {
- resp, err := c.do(conn, req)
+ ctx, cancel := context.WithCancel(ctx)
+ done := ctx.Done()
+ cw := &contextWriter{
+ ctx: ctx,
+ done: done,
+ cancel: cancel,
+ wc: conn,
+ }
+ cr := &contextReader{
+ ctx: ctx,
+ done: done,
+ cancel: cancel,
+ rc: conn,
+ }
+
+ resp, err := c.do(cw, cr, req)
res <- result{resp, err}
}()
@@ -137,14 +153,14 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
}
}
-func (c *Client) do(conn net.Conn, req *Request) (*Response, error) {
+func (c *Client) do(w io.Writer, rc io.ReadCloser, req *Request) (*Response, error) {
// Write the request
- if err := req.Write(conn); err != nil {
+ if err := req.Write(w); err != nil {
return nil, err
}
// Read the response
- resp, err := ReadResponse(conn)
+ resp, err := ReadResponse(rc)
if err != nil {
return nil, err
}
@@ -206,3 +222,55 @@ func punycodeHostname(hostname string) (string, error) {
}
return idna.Lookup.ToASCII(hostname)
}
+
+type contextReader struct {
+ ctx context.Context
+ done <-chan struct{}
+ cancel func()
+ rc io.ReadCloser
+}
+
+func (r *contextReader) Read(p []byte) (int, error) {
+ select {
+ case <-r.done:
+ r.rc.Close()
+ return 0, r.ctx.Err()
+ default:
+ }
+ n, err := r.rc.Read(p)
+ if err != nil {
+ r.cancel()
+ }
+ return n, err
+}
+
+func (r *contextReader) Close() error {
+ r.cancel()
+ return r.rc.Close()
+}
+
+type contextWriter struct {
+ ctx context.Context
+ done <-chan struct{}
+ cancel func()
+ wc io.WriteCloser
+}
+
+func (w *contextWriter) Write(b []byte) (int, error) {
+ select {
+ case <-w.done:
+ w.wc.Close()
+ return 0, w.ctx.Err()
+ default:
+ }
+ n, err := w.wc.Write(b)
+ if err != nil {
+ w.cancel()
+ }
+ return n, err
+}
+
+func (w *contextWriter) Close() error {
+ w.cancel()
+ return w.wc.Close()
+}