aboutsummaryrefslogtreecommitdiff
path: root/server.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-22 21:13:42 -0500
committerAdnan Maolood <[email protected]>2021-02-22 21:13:44 -0500
commit35f79580835303d9dd11897b54ea1a2c92f74ed0 (patch)
tree71c8bff6ad6a62a5c36d0f5fcb19bb5c159c5129 /server.go
parentexamples/stream: Remove usage of Flusher (diff)
downloadarchived-go-gemini-35f79580835303d9dd11897b54ea1a2c92f74ed0.tar.xz
archived-go-gemini-35f79580835303d9dd11897b54ea1a2c92f74ed0.zip
server: Revert to closing contexts on Shutdown
Diffstat (limited to 'server.go')
-rw-r--r--server.go162
1 files changed, 89 insertions, 73 deletions
diff --git a/server.go b/server.go
index a9dbba5..9fc267f 100644
--- a/server.go
+++ b/server.go
@@ -45,11 +45,12 @@ type Server struct {
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
- listeners map[*net.Listener]struct{}
- conns map[*net.Conn]struct{}
- closedChan chan struct{} // closed when the server is closed
- doneChan chan struct{} // closed when no more connections are open
- mu sync.Mutex
+ listeners map[*net.Listener]context.CancelFunc
+ conns map[*net.Conn]context.CancelFunc
+ closed bool // true if Closed or Shutdown called
+ shutdown bool // true if Shutdown called
+ doneChan chan struct{}
+ mu sync.Mutex
}
const (
@@ -58,18 +59,10 @@ const (
serverClosed
)
-// closed returns a channel that's closed when the server is closed.
-func (srv *Server) closed() chan struct{} {
+func (srv *Server) isClosed() bool {
srv.mu.Lock()
defer srv.mu.Unlock()
- return srv.closedLocked()
-}
-
-func (srv *Server) closedLocked() chan struct{} {
- if srv.closedChan == nil {
- srv.closedChan = make(chan struct{})
- }
- return srv.closedChan
+ return srv.closed
}
// done returns a channel that's closed when the server is closed and
@@ -87,14 +80,16 @@ func (srv *Server) doneLocked() chan struct{} {
return srv.doneChan
}
-// tryFinishShutdown closes srv.done() if the server is closed and
+// tryCloseDone closes srv.done() if the server is closed and
// there are no active listeners or connections.
-func (srv *Server) tryFinishShutdown() {
+func (srv *Server) tryCloseDone() {
srv.mu.Lock()
defer srv.mu.Unlock()
- select {
- case <-srv.closedLocked():
- default:
+ srv.tryCloseDoneLocked()
+}
+
+func (srv *Server) tryCloseDoneLocked() {
+ if !srv.closed {
return
}
if len(srv.listeners) == 0 && len(srv.conns) == 0 {
@@ -107,23 +102,27 @@ func (srv *Server) tryFinishShutdown() {
}
}
-// Close immediately closes all active net.Listeners and connections.
+// Close immediately closes all active net.Listeners and connections
+// by cancelling their contexts.
// For a graceful shutdown, use Shutdown.
func (srv *Server) Close() error {
- ch := srv.closed()
- select {
- case <-ch:
- return nil
- default:
- close(ch)
- }
+ srv.mu.Lock()
+ {
+ if srv.closed {
+ srv.mu.Unlock()
+ return nil
+ }
+ srv.closed = true
- srv.tryFinishShutdown()
+ srv.tryCloseDoneLocked()
- // Force all active connections to close.
- srv.mu.Lock()
- for conn := range srv.conns {
- (*conn).Close()
+ // Close all active connections and listeners.
+ for _, cancel := range srv.listeners {
+ cancel()
+ }
+ for _, cancel := range srv.conns {
+ cancel()
+ }
}
srv.mu.Unlock()
@@ -134,28 +133,34 @@ func (srv *Server) Close() error {
}
// Shutdown gracefully shuts down the server without interrupting any
-// active connections. Shutdown works by first closing all open
-// listeners and then waiting indefinitely for connections
-// to close and then shut down.
+// active connections. Shutdown works by first cancelling the contexts
+// of all open listeners and then waiting indefinitely for connections
+// to close.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error.
//
// When Shutdown is called, Serve and ListenAndServe immediately
-// return ErrServerClosed. Make sure the program doesn't exit and
-// waits instead for Shutdown to return.
+// return an error. Make sure the program doesn't exit and waits instead for
+// Shutdown to return.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
- ch := srv.closed()
- select {
- case <-ch:
- return nil
- default:
- close(ch)
- }
+ srv.mu.Lock()
+ {
+ if srv.closed {
+ srv.mu.Unlock()
+ return nil
+ }
+ srv.closed = true
+ srv.shutdown = true
- srv.tryFinishShutdown()
+ // Close all active listeners.
+ for _, cancel := range srv.listeners {
+ cancel()
+ }
+ }
+ srv.mu.Unlock()
// Wait for active connections to finish.
select {
@@ -172,13 +177,13 @@ func (srv *Server) Shutdown(ctx context.Context) error {
//
// If srv.Addr is blank, ":1965" is used.
//
-// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
-// returned error is ErrServerClosed.
+// ListenAndServe always returns a non-nil error.
func (srv *Server) ListenAndServe(ctx context.Context) error {
- select {
- case <-srv.closed():
- return ErrServerClosed
- default:
+ if srv.isClosed() {
+ // Cancel context
+ ctx, cancel := context.WithCancel(ctx)
+ cancel()
+ return ctx.Err()
}
addr := srv.Addr
@@ -206,18 +211,16 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err
return srv.GetCertificate(h.ServerName)
}
-func (srv *Server) trackListener(l *net.Listener) bool {
+func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
- select {
- case <-srv.closedLocked():
+ if srv.closed {
return false
- default:
}
if srv.listeners == nil {
- srv.listeners = make(map[*net.Listener]struct{})
+ srv.listeners = make(map[*net.Listener]context.CancelFunc)
}
- srv.listeners[l] = struct{}{}
+ srv.listeners[l] = cancel
return true
}
@@ -231,15 +234,19 @@ func (srv *Server) deleteListener(l *net.Listener) {
// service goroutine for each. The service goroutines read requests and
// then calls the appropriate Handler to reply to them.
//
-// Serve always returns a non-nil error and closes l. After Shutdown or Close,
-// the returned error is ErrServerClosed.
+// Serve always returns a non-nil error and closes l.
func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
defer l.Close()
- if !srv.trackListener(&l) {
- return ErrServerClosed
+ lnctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ if !srv.trackListener(&l, cancel) {
+ // Cancel context
+ cancel()
+ return lnctx.Err()
}
- defer srv.tryFinishShutdown()
+ defer srv.tryCloseDone()
defer srv.deleteListener(&l)
errch := make(chan error, 1)
@@ -248,12 +255,10 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
}()
select {
- case <-ctx.Done():
- return ctx.Err()
+ case <-lnctx.Done():
+ return lnctx.Err()
case err := <-errch:
return err
- case <-srv.closed():
- return ErrServerClosed
}
}
@@ -283,13 +288,17 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error {
}
}
-func (srv *Server) trackConn(conn *net.Conn) {
+func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
+ if srv.closed && !srv.shutdown {
+ return false
+ }
if srv.conns == nil {
- srv.conns = make(map[*net.Conn]struct{})
+ srv.conns = make(map[*net.Conn]context.CancelFunc)
}
- srv.conns[conn] = struct{}{}
+ srv.conns[conn] = cancel
+ return true
}
func (srv *Server) deleteConn(conn *net.Conn) {
@@ -300,12 +309,19 @@ func (srv *Server) deleteConn(conn *net.Conn) {
// ServeConn serves a Gemini response over the provided connection.
// It closes the connection when the response has been completed.
-// ServeConn can be used even after Shutdown or Close have been called.
+// Note that ServeConn will succeed even if a call to Shutdown is ongoing.
func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error {
defer conn.Close()
- srv.trackConn(&conn)
- defer srv.tryFinishShutdown()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ if !srv.trackConn(&conn, cancel) {
+ // Cancel context
+ cancel()
+ return ctx.Err()
+ }
+ defer srv.tryCloseDone()
defer srv.deleteConn(&conn)
if d := srv.ReadTimeout; d != 0 {