aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-21 00:21:29 -0500
committerAdnan Maolood <[email protected]>2021-02-21 00:21:31 -0500
commitf6505ae4c4e93b9ad2e280e7d97a08a4faf206b6 (patch)
tree67ecaa45cab45f63e0e403e4fefa2b578f795896
parentclient: Inline result type (diff)
downloadgo-gemini-f6505ae4c4e93b9ad2e280e7d97a08a4faf206b6.tar.xz
go-gemini-f6505ae4c4e93b9ad2e280e7d97a08a4faf206b6.zip
server: Use explicit context arguments
Replace the Server.Context field with explicit context.Context arguments to most Server functions.
-rw-r--r--server.go305
1 files changed, 167 insertions, 138 deletions
diff --git a/server.go b/server.go
index 779c003..491f516 100644
--- a/server.go
+++ b/server.go
@@ -47,16 +47,98 @@ type Server struct {
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
- // Context is the base context to use.
- // If nil, context.Background is used.
- Context context.Context
-
- listeners map[*net.Listener]struct{}
- conns map[*net.Conn]struct{}
- done int32
+ listeners map[*net.Listener]context.CancelFunc
+ conns map[*net.Conn]context.CancelFunc
+ doneChan chan struct{}
+ closed int32
mu sync.Mutex
}
+// done returns a channel that's closed when the server has finished closing.
+func (srv *Server) done() chan struct{} {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ return srv.doneLocked()
+}
+
+func (srv *Server) doneLocked() chan struct{} {
+ if srv.doneChan == nil {
+ srv.doneChan = make(chan struct{})
+ }
+ return srv.doneChan
+}
+
+// tryFinishShutdown closes srv.done() if there are no active listeners or requests.
+func (srv *Server) tryFinishShutdown() {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ if len(srv.listeners) == 0 && len(srv.conns) == 0 {
+ done := srv.doneLocked()
+ select {
+ case <-done:
+ default:
+ close(done)
+ }
+ }
+}
+
+// Close immediately closes all active net.Listeners and connections.
+// For a graceful shutdown, use Shutdown.
+func (srv *Server) Close() error {
+ if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
+ return ErrServerClosed
+ }
+
+ // Close active listeners and connections.
+ srv.mu.Lock()
+ for _, cancel := range srv.listeners {
+ cancel()
+ }
+ for _, cancel := range srv.conns {
+ cancel()
+ }
+ srv.mu.Unlock()
+
+ select {
+ case <-srv.done():
+ return nil
+ }
+}
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners and then waiting indefinitely for connections
+// to close and then shut down.
+// If the provided context expires before the shutdown is complete,
+// Shutdown returns the context's error.
+//
+// When Shutdown is called, Serve and ListenAndServer immediately
+// return ErrServerClosed. Make sure the program doesn't exit and
+// waits instead for Shutdown to return.
+//
+// Once Shutdown has been called on a server, it may not be reused;
+// future calls to methods such as Serve will return ErrServerClosed.
+func (srv *Server) Shutdown(ctx context.Context) error {
+ if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
+ return ErrServerClosed
+ }
+
+ // Close active listeners.
+ srv.mu.Lock()
+ for _, cancel := range srv.listeners {
+ cancel()
+ }
+ srv.mu.Unlock()
+
+ // Wait for active connections to finish.
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-srv.done():
+ return nil
+ }
+}
+
// ListenAndServe listens for requests at the server's configured address.
// ListenAndServe listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming connections.
@@ -65,8 +147,8 @@ type Server struct {
//
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
-func (srv *Server) ListenAndServe() error {
- if atomic.LoadInt32(&srv.done) == 1 {
+func (srv *Server) ListenAndServe(ctx context.Context) error {
+ if atomic.LoadInt32(&srv.closed) == 1 {
return ErrServerClosed
}
@@ -75,26 +157,33 @@ func (srv *Server) ListenAndServe() error {
addr = ":1965"
}
- ln, err := net.Listen("tcp", addr)
+ l, err := net.Listen("tcp", addr)
if err != nil {
return err
}
- defer ln.Close()
- return srv.Serve(tls.NewListener(ln, &tls.Config{
+ l = tls.NewListener(l, &tls.Config{
ClientAuth: tls.RequestClientCert,
MinVersion: tls.VersionTLS12,
GetCertificate: srv.getCertificate,
- }))
+ })
+ return srv.Serve(ctx, l)
+}
+
+func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if srv.GetCertificate == nil {
+ return nil, errors.New("gemini: GetCertificate is nil")
+ }
+ return srv.GetCertificate(h.ServerName)
}
-func (srv *Server) trackListener(l *net.Listener) {
+func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.listeners == nil {
- srv.listeners = make(map[*net.Listener]struct{})
+ srv.listeners = make(map[*net.Listener]context.CancelFunc)
}
- srv.listeners[l] = struct{}{}
+ srv.listeners[l] = cancel
}
func (srv *Server) deleteListener(l *net.Listener) {
@@ -109,24 +198,46 @@ func (srv *Server) deleteListener(l *net.Listener) {
//
// Serve always returns a non-nil error and closes l. After Shutdown or Close,
// the returned error is ErrServerClosed.
-func (srv *Server) Serve(l net.Listener) error {
+func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
defer l.Close()
- srv.trackListener(&l)
+ if atomic.LoadInt32(&srv.closed) == 1 {
+ return ErrServerClosed
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ srv.trackListener(&l, cancel)
+ defer srv.tryFinishShutdown()
defer srv.deleteListener(&l)
- if atomic.LoadInt32(&srv.done) == 1 {
- return ErrServerClosed
+ errch := make(chan error, 1)
+ go func() {
+ errch <- srv.serve(ctx, l)
+ }()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errch:
+ return err
}
+}
- var tempDelay time.Duration // how long to sleep on accept failure
+func (srv *Server) serve(ctx context.Context, l net.Listener) error {
+ // how long to sleep on accept failure
+ var tempDelay time.Duration
for {
rw, err := l.Accept()
if err != nil {
- if atomic.LoadInt32(&srv.done) == 1 {
- return ErrServerClosed
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
}
+
// If this is a temporary error, sleep
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
@@ -142,115 +253,21 @@ func (srv *Server) Serve(l net.Listener) error {
continue
}
- // Otherwise, return the error
return err
}
tempDelay = 0
- go srv.respond(rw)
+ go srv.serveConn(ctx, rw)
}
}
-func (srv *Server) closeListenersLocked() error {
- var err error
- for ln := range srv.listeners {
- if cerr := (*ln).Close(); cerr != nil && err == nil {
- err = cerr
- }
- delete(srv.listeners, ln)
- }
- return err
-}
-
-// Close immediately closes all active net.Listeners and connections.
-// For a graceful shutdown, use Shutdown.
-//
-// Close returns any error returned from closing the Server's
-// underlying Listener(s).
-func (srv *Server) Close() error {
- srv.mu.Lock()
- defer srv.mu.Unlock()
- if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
- return ErrServerClosed
- }
- err := srv.closeListenersLocked()
-
- // Close active connections
- for conn := range srv.conns {
- (*conn).Close()
- delete(srv.conns, conn)
- }
- return err
-}
-
-func (srv *Server) numConns() int {
- srv.mu.Lock()
- defer srv.mu.Unlock()
- return len(srv.conns)
-}
-
-// shutdownPollInterval is how often we poll for quiescence
-// during Server.Shutdown. This is lower during tests, to
-// speed up tests.
-// Ideally we could find a solution that doesn't involve polling,
-// but which also doesn't have a high runtime cost (and doesn't
-// involve any contentious mutexes), but that is left as an
-// exercise for the reader.
-var shutdownPollInterval = 500 * time.Millisecond
-
-// Shutdown gracefully shuts down the server without interrupting any
-// active connections. Shutdown works by first closing all open
-// listeners and then waiting indefinitely for connections
-// to close and then shut down.
-// If the provided context expires before the shutdown is complete,
-// Shutdown returns the context's error, otherwise it returns any
-// error returned from closing the Server's underlying Listener(s).
-//
-// When Shutdown is called, Serve, ListenAndServe, and
-// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
-// program doesn't exit and waits instead for Shutdown to return.
-//
-// Once Shutdown has been called on a server, it may not be reused;
-// future calls to methods such as Serve will return ErrServerClosed.
-func (srv *Server) Shutdown(ctx context.Context) error {
- if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
- return ErrServerClosed
- }
-
- srv.mu.Lock()
- err := srv.closeListenersLocked()
- srv.mu.Unlock()
-
- // Wait for active connections to close
- ticker := time.NewTicker(shutdownPollInterval)
- defer ticker.Stop()
- for {
- if srv.numConns() == 0 {
- return err
- }
-
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-ticker.C:
- }
- }
-}
-
-func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
- if srv.GetCertificate == nil {
- return nil, errors.New("gemini: GetCertificate is nil")
- }
- return srv.GetCertificate(h.ServerName)
-}
-
-func (srv *Server) trackConn(conn *net.Conn) {
+func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.conns == nil {
- srv.conns = make(map[*net.Conn]struct{})
+ srv.conns = make(map[*net.Conn]context.CancelFunc)
}
- srv.conns[conn] = struct{}{}
+ srv.conns[conn] = cancel
}
func (srv *Server) deleteConn(conn *net.Conn) {
@@ -259,10 +276,22 @@ func (srv *Server) deleteConn(conn *net.Conn) {
delete(srv.conns, conn)
}
-// respond responds to a connection.
-func (srv *Server) respond(conn net.Conn) {
+// serveConn serves a Gemini response over the provided connection.
+// It closes the connection when the response has been completed.
+func (srv *Server) serveConn(ctx context.Context, conn net.Conn) {
defer conn.Close()
+ if atomic.LoadInt32(&srv.closed) == 1 {
+ return
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ srv.trackConn(&conn, cancel)
+ defer srv.tryFinishShutdown()
+ defer srv.deleteConn(&conn)
+
defer func() {
if err := recover(); err != nil && err != ErrAbortHandler {
const size = 64 << 10
@@ -272,9 +301,6 @@ func (srv *Server) respond(conn net.Conn) {
}
}()
- srv.trackConn(&conn)
- defer srv.deleteConn(&conn)
-
if d := srv.ReadTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}
@@ -282,16 +308,29 @@ func (srv *Server) respond(conn net.Conn) {
conn.SetWriteDeadline(time.Now().Add(d))
}
+ done := make(chan struct{})
+ go func() {
+ srv.respond(ctx, conn)
+ close(done)
+ }()
+
+ select {
+ case <-ctx.Done():
+ case <-done:
+ }
+}
+
+func (srv *Server) respond(ctx context.Context, conn net.Conn) {
w := newResponseWriter(conn)
+ defer w.Flush()
req, err := ReadRequest(conn)
if err != nil {
w.WriteHeader(StatusBadRequest, "Bad request")
- w.Flush()
return
}
- // Store information about the TLS connection
+ // Store the TLS connection state
if tlsConn, ok := conn.(*tls.Conn); ok {
state := tlsConn.ConnectionState()
req.TLS = &state
@@ -304,20 +343,10 @@ func (srv *Server) respond(conn net.Conn) {
h := srv.Handler
if h == nil {
w.WriteHeader(StatusNotFound, "Not found")
- w.Flush()
return
}
- ctx := srv.context()
h.ServeGemini(ctx, w, req)
- w.Flush()
-}
-
-func (srv *Server) context() context.Context {
- if srv.Context != nil {
- return srv.Context
- }
- return context.Background()
}
func (srv *Server) logf(format string, args ...interface{}) {