aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-14 23:58:28 -0500
committerAdnan Maolood <[email protected]>2021-02-14 23:58:33 -0500
commit3f2d540579b3051d398365e34de43113053c3b70 (patch)
treecd325a09a36bcf8edf7b0004e61b55c26db56a19
parentTweak returned error for requests that are too long (diff)
downloadgo-gemini-3f2d540579b3051d398365e34de43113053c3b70.tar.xz
go-gemini-3f2d540579b3051d398365e34de43113053c3b70.zip
server: Implement Close and Shutdown methods
-rw-r--r--gemini.go9
-rw-r--r--server.go145
2 files changed, 151 insertions, 3 deletions
diff --git a/gemini.go b/gemini.go
index 86b31ae..0462dfb 100644
--- a/gemini.go
+++ b/gemini.go
@@ -11,5 +11,12 @@ var (
ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidRequest = errors.New("gemini: invalid request")
ErrInvalidResponse = errors.New("gemini: invalid response")
- ErrBodyNotAllowed = errors.New("gemini: response body not allowed")
+
+ // ErrBodyNotAllowed is returned by ResponseWriter.Write calls
+ // when the response status code does not permit a body.
+ ErrBodyNotAllowed = errors.New("gemini: response status code does not allow body")
+
+ // ErrServerClosed is returned by the Server's Serve and ListenAndServe
+ // methods after a call to Shutdown or Close.
+ ErrServerClosed = errors.New("gemini: server closed")
)
diff --git a/server.go b/server.go
index 806ede6..daeb097 100644
--- a/server.go
+++ b/server.go
@@ -1,11 +1,14 @@
package gemini
import (
+ "context"
"crypto/tls"
"errors"
"log"
"net"
"strings"
+ "sync"
+ "sync/atomic"
"time"
"git.sr.ht/~adnano/go-gemini/certificate"
@@ -47,6 +50,11 @@ type Server struct {
// registered handlers
handlers map[handlerKey]Handler
hosts map[string]bool
+
+ listeners map[*net.Listener]struct{}
+ conns map[*net.Conn]struct{}
+ done int32
+ mu sync.Mutex
}
type handlerKey struct {
@@ -62,6 +70,9 @@ type handlerKey struct {
// Wildcard patterns are supported (e.g. "*.example.com").
// To handle any hostname, use the wildcard pattern "*".
func (srv *Server) Handle(pattern string, handler Handler) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+
if pattern == "" {
panic("gemini: invalid pattern")
}
@@ -101,7 +112,6 @@ func (srv *Server) HandleFunc(pattern string, handler func(ResponseWriter, *Requ
//
// If srv.Addr is blank, ":1965" is used.
//
-// TODO:
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
func (srv *Server) ListenAndServe() error {
@@ -123,19 +133,45 @@ func (srv *Server) ListenAndServe() error {
}))
}
+func (srv *Server) trackListener(l *net.Listener) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ if srv.listeners == nil {
+ srv.listeners = make(map[*net.Listener]struct{})
+ }
+ srv.listeners[l] = struct{}{}
+}
+
+func (srv *Server) deleteListener(l *net.Listener) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ delete(srv.listeners, l)
+}
+
// Serve accepts incoming connections on the Listener l, creating a new
// service goroutine for each. The service goroutines read requests and
// then calls the appropriate Handler to reply to them.
//
-// TODO:
// Serve always returns a non-nil error and closes l. After Shutdown or Close,
// the returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
+ defer l.Close()
+
+ if atomic.LoadInt32(&srv.done) == 1 {
+ return ErrServerClosed
+ }
+
+ srv.trackListener(&l)
+ defer srv.deleteListener(&l)
+
var tempDelay time.Duration // how long to sleep on accept failure
for {
rw, err := l.Accept()
if err != nil {
+ if atomic.LoadInt32(&srv.done) == 1 {
+ return ErrServerClosed
+ }
// If this is a temporary error, sleep
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
@@ -160,6 +196,92 @@ func (srv *Server) Serve(l net.Listener) error {
}
}
+func (srv *Server) closeListenersLocked() error {
+ var err error
+ for ln := range srv.listeners {
+ if cerr := (*ln).Close(); cerr != nil && err == nil {
+ err = cerr
+ }
+ delete(srv.listeners, ln)
+ }
+ return err
+}
+
+// Close immediately closes all active net.Listeners and connections.
+// For a graceful shutdown, use Shutdown.
+//
+// Close returns any error returned from closing the Server's
+// underlying Listener(s).
+func (srv *Server) Close() error {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
+ return ErrServerClosed
+ }
+ err := srv.closeListenersLocked()
+
+ // Close active connections
+ for conn := range srv.conns {
+ (*conn).Close()
+ delete(srv.conns, conn)
+ }
+ return err
+}
+
+func (srv *Server) numConns() int {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ return len(srv.conns)
+}
+
+// shutdownPollInterval is how often we poll for quiescence
+// during Server.Shutdown. This is lower during tests, to
+// speed up tests.
+// Ideally we could find a solution that doesn't involve polling,
+// but which also doesn't have a high runtime cost (and doesn't
+// involve any contentious mutexes), but that is left as an
+// exercise for the reader.
+var shutdownPollInterval = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners and then waiting indefinitely for connections
+// to close and then shut down.
+// If the provided context expires before the shutdown is complete,
+// Shutdown returns the context's error, otherwise it returns any
+// error returned from closing the Server's underlying Listener(s).
+//
+// When Shutdown is called, Serve, ListenAndServe, and
+// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
+// program doesn't exit and waits instead for Shutdown to return.
+//
+// Once Shutdown has been called on a server, it may not be reused;
+// future calls to methods such as Serve will return ErrServerClosed.
+func (srv *Server) Shutdown(ctx context.Context) error {
+ if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
+ return ErrServerClosed
+ }
+
+ srv.mu.Lock()
+ err := srv.closeListenersLocked()
+ srv.mu.Unlock()
+
+ // Wait for active connections to close
+ ticker := time.NewTicker(shutdownPollInterval)
+ defer ticker.Stop()
+ for {
+ if srv.numConns() == 0 {
+ return err
+ }
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ }
+ }
+}
+
// getCertificate retrieves a certificate for the given client hello.
func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := srv.lookupCertificate(h.ServerName, h.ServerName)
@@ -207,9 +329,28 @@ func (srv *Server) lookupCertificate(pattern, hostname string) (*tls.Certificate
return &cert, nil
}
+func (srv *Server) trackConn(conn *net.Conn) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ if srv.conns == nil {
+ srv.conns = make(map[*net.Conn]struct{})
+ }
+ srv.conns[conn] = struct{}{}
+}
+
+func (srv *Server) deleteConn(conn *net.Conn) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ delete(srv.conns, conn)
+}
+
// respond responds to a connection.
func (srv *Server) respond(conn net.Conn) {
defer conn.Close()
+
+ srv.trackConn(&conn)
+ defer srv.deleteConn(&conn)
+
if d := srv.ReadTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}