aboutsummaryrefslogtreecommitdiff
path: root/server.go
diff options
context:
space:
mode:
authorAdnan Maolood <[email protected]>2021-02-21 11:53:15 -0500
committerAdnan Maolood <[email protected]>2021-02-21 11:53:15 -0500
commit49dac34afff8cea8d5a4feaa9a01d763b0e5d05b (patch)
tree2b787f6beed109cbc78ae51b43e4f7095fc6c8c9 /server.go
parentserver: Don't recover from panics (diff)
downloadgo-gemini-49dac34afff8cea8d5a4feaa9a01d763b0e5d05b.tar.xz
go-gemini-49dac34afff8cea8d5a4feaa9a01d763b0e5d05b.zip
server: Export ServeConn method
Diffstat (limited to 'server.go')
-rw-r--r--server.go93
1 files changed, 50 insertions, 43 deletions
diff --git a/server.go b/server.go
index 8484eec..7574782 100644
--- a/server.go
+++ b/server.go
@@ -7,7 +7,6 @@ import (
"log"
"net"
"sync"
- "sync/atomic"
"time"
)
@@ -49,7 +48,7 @@ type Server struct {
listeners map[*net.Listener]context.CancelFunc
conns map[*net.Conn]context.CancelFunc
doneChan chan struct{}
- closed int32
+ closed bool
mu sync.Mutex
}
@@ -67,6 +66,22 @@ func (srv *Server) doneLocked() chan struct{} {
return srv.doneChan
}
+func (srv *Server) isClosed() bool {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ return srv.closed
+}
+
+func (srv *Server) tryClose() bool {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ if srv.closed {
+ return false
+ }
+ srv.closed = true
+ return true
+}
+
// tryFinishShutdown closes srv.done() if there are no active listeners or requests.
func (srv *Server) tryFinishShutdown() {
srv.mu.Lock()
@@ -84,7 +99,7 @@ func (srv *Server) tryFinishShutdown() {
// Close immediately closes all active net.Listeners and connections.
// For a graceful shutdown, use Shutdown.
func (srv *Server) Close() error {
- if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
+ if !srv.tryClose() {
return ErrServerClosed
}
@@ -118,7 +133,7 @@ func (srv *Server) Close() error {
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
- if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
+ if !srv.tryClose() {
return ErrServerClosed
}
@@ -147,7 +162,7 @@ func (srv *Server) Shutdown(ctx context.Context) error {
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
func (srv *Server) ListenAndServe(ctx context.Context) error {
- if atomic.LoadInt32(&srv.closed) == 1 {
+ if srv.isClosed() {
return ErrServerClosed
}
@@ -176,13 +191,17 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err
return srv.GetCertificate(h.ServerName)
}
-func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) {
+func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
+ if srv.closed {
+ return false
+ }
if srv.listeners == nil {
srv.listeners = make(map[*net.Listener]context.CancelFunc)
}
srv.listeners[l] = cancel
+ return true
}
func (srv *Server) deleteListener(l *net.Listener) {
@@ -200,14 +219,12 @@ func (srv *Server) deleteListener(l *net.Listener) {
func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
defer l.Close()
- if atomic.LoadInt32(&srv.closed) == 1 {
- return ErrServerClosed
- }
-
lnctx, cancel := context.WithCancel(ctx)
defer cancel()
- srv.trackListener(&l, cancel)
+ if !srv.trackListener(&l, cancel) {
+ return ErrServerClosed
+ }
defer srv.tryFinishShutdown()
defer srv.deleteListener(&l)
@@ -218,7 +235,7 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
select {
case <-lnctx.Done():
- if atomic.LoadInt32(&srv.closed) == 1 {
+ if srv.isClosed() {
return ErrServerClosed
}
return lnctx.Err()
@@ -228,21 +245,10 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
}
func (srv *Server) serve(ctx context.Context, l net.Listener) error {
- // how long to sleep on accept failure
- var tempDelay time.Duration
-
+ var tempDelay time.Duration // how long to sleep on accept failure
for {
rw, err := l.Accept()
if err != nil {
- select {
- case <-ctx.Done():
- if atomic.LoadInt32(&srv.closed) == 1 {
- return ErrServerClosed
- }
- return ctx.Err()
- default:
- }
-
// If this is a temporary error, sleep
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
@@ -257,22 +263,24 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error {
time.Sleep(tempDelay)
continue
}
-
return err
}
-
tempDelay = 0
- go srv.serveConn(ctx, rw)
+ go srv.ServeConn(ctx, rw)
}
}
-func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) {
+func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
+ if srv.closed {
+ return false
+ }
if srv.conns == nil {
srv.conns = make(map[*net.Conn]context.CancelFunc)
}
srv.conns[conn] = cancel
+ return true
}
func (srv *Server) deleteConn(conn *net.Conn) {
@@ -281,19 +289,17 @@ func (srv *Server) deleteConn(conn *net.Conn) {
delete(srv.conns, conn)
}
-// serveConn serves a Gemini response over the provided connection.
+// ServeConn serves a Gemini response over the provided connection.
// It closes the connection when the response has been completed.
-func (srv *Server) serveConn(ctx context.Context, conn net.Conn) {
+func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error {
defer conn.Close()
- if atomic.LoadInt32(&srv.closed) == 1 {
- return
- }
-
ctx, cancel := context.WithCancel(ctx)
defer cancel()
- srv.trackConn(&conn, cancel)
+ if !srv.trackConn(&conn, cancel) {
+ return ErrServerClosed
+ }
defer srv.tryFinishShutdown()
defer srv.deleteConn(&conn)
@@ -304,26 +310,26 @@ func (srv *Server) serveConn(ctx context.Context, conn net.Conn) {
conn.SetWriteDeadline(time.Now().Add(d))
}
- done := make(chan struct{})
+ errch := make(chan error, 1)
go func() {
- srv.respond(ctx, conn)
- close(done)
+ errch <- srv.serveConn(ctx, conn)
}()
select {
case <-ctx.Done():
- case <-done:
+ return ctx.Err()
+ case err := <-errch:
+ return err
}
}
-func (srv *Server) respond(ctx context.Context, conn net.Conn) {
+func (srv *Server) serveConn(ctx context.Context, conn net.Conn) error {
w := newResponseWriter(conn)
- defer w.Flush()
req, err := ReadRequest(conn)
if err != nil {
w.WriteHeader(StatusBadRequest, "Bad request")
- return
+ return w.Flush()
}
// Store the TLS connection state
@@ -339,10 +345,11 @@ func (srv *Server) respond(ctx context.Context, conn net.Conn) {
h := srv.Handler
if h == nil {
w.WriteHeader(StatusNotFound, "Not found")
- return
+ return w.Flush()
}
h.ServeGemini(ctx, w, req)
+ return w.Flush()
}
func (srv *Server) logf(format string, args ...interface{}) {