diff --git a/helpers_test.go b/helpers_test.go index bde3703..3c11a08 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -9,21 +9,21 @@ import ( "testing" ) -func newServer() *GracefulServer { - return NewWithServer(new(http.Server)) -} - // a simple step-controllable http client type client struct { tls bool addr net.Addr connected chan error sendrequest chan bool - idle chan error - idlerelease chan bool + response chan *rawResponse closed chan bool } +type rawResponse struct { + body []string + err error +} + func (c *client) Run() { go func() { var err error @@ -39,19 +39,21 @@ func (c *client) Run() { for <-c.sendrequest { _, err = conn.Write([]byte("GET / HTTP/1.1\nHost: localhost:8000\n\n")) if err != nil { - c.idle <- err + c.response <- &rawResponse{err: err} } // Read response; no content scanner := bufio.NewScanner(conn) + var lines []string for scanner.Scan() { // our null handler doesn't send a body, so we know the request is // done when we reach the blank line after the headers - if scanner.Text() == "" { + line := scanner.Text() + if line == "" { break } + lines = append(lines, line) } - c.idle <- scanner.Err() - <-c.idlerelease + c.response <- &rawResponse{lines, scanner.Err()} } conn.Close() ioutil.ReadAll(conn) @@ -65,8 +67,7 @@ func newClient(addr net.Addr, tls bool) *client { tls: tls, connected: make(chan error), sendrequest: make(chan bool), - idle: make(chan error), - idlerelease: make(chan bool), + response: make(chan *rawResponse), closed: make(chan bool), } } diff --git a/server.go b/server.go index ec724dd..dfd3b87 100644 --- a/server.go +++ b/server.go @@ -63,13 +63,19 @@ type GracefulServer struct { shutdown chan bool shutdownFinished chan bool wg waitGroup + routinesCount int - lcsmu sync.RWMutex - lastConnState map[net.Conn]http.ConnState + lcsmu sync.RWMutex + connections map[net.Conn]bool up chan net.Listener // Only used by test code. } +// NewServer creates a new GracefulServer. +func NewServer() *GracefulServer { + return NewWithServer(new(http.Server)) +} + // NewWithServer wraps an existing http.Server object and returns a // GracefulServer that supports all of the original Server operations. func NewWithServer(s *http.Server) *GracefulServer { @@ -78,7 +84,8 @@ func NewWithServer(s *http.Server) *GracefulServer { shutdown: make(chan bool), shutdownFinished: make(chan bool, 1), wg: new(sync.WaitGroup), - lastConnState: make(map[net.Conn]http.ConnState), + routinesCount: 0, + connections: make(map[net.Conn]bool), } } @@ -142,63 +149,64 @@ func (s *GracefulServer) ListenAndServeTLS(certFile, keyFile string) error { // Serve provides a graceful equivalent net/http.Server.Serve. func (s *GracefulServer) Serve(listener net.Listener) error { - var closing int32 - + // Wrap the server HTTP handler into graceful one, that will close kept + // alive connections if a new request is received after shutdown. + gracefulHandler := newGracefulHandler(s.Server.Handler) + s.Server.Handler = gracefulHandler + + // Start a goroutine that waits for a shutdown signal and will stop the + // listener when it receives the signal. That in turn will result in + // unblocking of the http.Serve call. go func() { s.shutdown <- true close(s.shutdown) - atomic.StoreInt32(&closing, 1) + gracefulHandler.Close() s.Server.SetKeepAlivesEnabled(false) listener.Close() - s.shutdownFinished <- true }() originalConnState := s.Server.ConnState - // s.ConnState is invoked by the net/http.Server every time a connectiion + // s.ConnState is invoked by the net/http.Server every time a connection // changes state. It keeps track of each connection's state over time, // enabling manners to handle persisted connections correctly. s.ConnState = func(conn net.Conn, newState http.ConnState) { s.lcsmu.RLock() - lastConnState := s.lastConnState[conn] + protected := s.connections[conn] s.lcsmu.RUnlock() switch newState { - // New connection -> StateNew case http.StateNew: + // New connection -> StateNew + protected = true s.StartRoutine() - // (StateNew, StateIdle) -> StateActive case http.StateActive: - // The connection transitioned from idle back to active - if lastConnState == http.StateIdle { - s.StartRoutine() + // (StateNew, StateIdle) -> StateActive + if gracefulHandler.IsClosed() { + conn.Close() + break } - // StateActive -> StateIdle - // Immediately close newly idle connections; if not they may make - // one more request before SetKeepAliveEnabled(false) takes effect. - case http.StateIdle: - if atomic.LoadInt32(&closing) == 1 { - conn.Close() + if !protected { + protected = true + s.StartRoutine() } - s.FinishRoutine() - // (StateNew, StateActive, StateIdle) -> (StateClosed, StateHiJacked) - // If the connection was idle we do not need to decrement the counter. - case http.StateClosed, http.StateHijacked: - if lastConnState != http.StateIdle { + default: + // (StateNew, StateActive) -> (StateIdle, StateClosed, StateHiJacked) + if protected { s.FinishRoutine() + protected = false } - } s.lcsmu.Lock() if newState == http.StateClosed || newState == http.StateHijacked { - delete(s.lastConnState, conn) + delete(s.connections, conn) } else { - s.lastConnState[conn] = newState + s.connections[conn] = protected } s.lcsmu.Unlock() @@ -214,14 +222,14 @@ func (s *GracefulServer) Serve(listener net.Listener) error { } err := s.Server.Serve(listener) - - // This block is reached when the server has received a shut down command - // or a real error happened. - if err == nil || atomic.LoadInt32(&closing) == 1 { - s.wg.Wait() - return nil + // An error returned on shutdown is not worth reporting. + if err != nil && gracefulHandler.IsClosed() { + err = nil } + // Wait for pending requests to complete regardless the Serve result. + s.wg.Wait() + s.shutdownFinished <- true return err } @@ -229,11 +237,56 @@ func (s *GracefulServer) Serve(listener net.Listener) error { // starts more goroutines and these goroutines are not guaranteed to finish // before the request. func (s *GracefulServer) StartRoutine() { + s.lcsmu.Lock() + defer s.lcsmu.Unlock() s.wg.Add(1) + s.routinesCount++ } // FinishRoutine decrements the server's WaitGroup. Use this to complement // StartRoutine(). func (s *GracefulServer) FinishRoutine() { + s.lcsmu.Lock() + defer s.lcsmu.Unlock() s.wg.Done() + s.routinesCount-- +} + +// RoutinesCount returns the number of currently running routines +func (s *GracefulServer) RoutinesCount() int { + s.lcsmu.RLock() + defer s.lcsmu.RUnlock() + return s.routinesCount +} + +// gracefulHandler is used by GracefulServer to prevent calling ServeHTTP on +// to be closed kept-alive connections during the server shutdown. +type gracefulHandler struct { + closed int32 // accessed atomically. + wrapped http.Handler +} + +func newGracefulHandler(wrapped http.Handler) *gracefulHandler { + return &gracefulHandler{ + wrapped: wrapped, + } +} + +func (gh *gracefulHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&gh.closed) == 0 { + gh.wrapped.ServeHTTP(w, r) + return + } + r.Body.Close() + // Server is shutting down at this moment, and the connection that this + // handler is being called on is about to be closed. So we do not need to + // actually execute the handler logic. +} + +func (gh *gracefulHandler) Close() { + atomic.StoreInt32(&gh.closed, 1) +} + +func (gh *gracefulHandler) IsClosed() bool { + return atomic.LoadInt32(&gh.closed) == 1 } diff --git a/server_test.go b/server_test.go index 2f54eaf..9942842 100644 --- a/server_test.go +++ b/server_test.go @@ -1,17 +1,37 @@ package manners import ( - helpers "github.com/braintree/manners/test_helpers" "net" "net/http" "testing" "time" + + helpers "github.com/braintree/manners/test_helpers" ) +type httpInterface interface { + ListenAndServe() error + ListenAndServeTLS(certFile, keyFile string) error + Serve(listener net.Listener) error +} + +// Test that the method signatures of the methods we override from net/http/Server match those of the original. +func TestInterface(t *testing.T) { + var original, ours interface{} + original = &http.Server{} + ours = &GracefulServer{} + if _, ok := original.(httpInterface); !ok { + t.Errorf("httpInterface definition does not match the canonical server!") + } + if _, ok := ours.(httpInterface); !ok { + t.Errorf("GracefulServer does not implement httpInterface") + } +} + // Tests that the server allows in-flight requests to complete // before shutting down. func TestGracefulness(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg statechanged := make(chan http.ConnState) @@ -24,10 +44,9 @@ func TestGracefulness(t *testing.T) { if err := <-client.connected; err != nil { t.Fatal("Client failed to connect to server", err) } - // avoid a race between the client connection and the server accept - if state := <-statechanged; state != http.StateNew { - t.Fatal("Unexpected state", state) - } + // Even though the client is connected, the server ConnState handler may + // not know about that yet. So wait until it is called. + waitForState(t, statechanged, http.StateNew, "Request not received") server.Close() @@ -45,10 +64,22 @@ func TestGracefulness(t *testing.T) { } } +// Tests that starting the server and closing in 2 new, separate goroutines doesnot +// get flagged by the race detector (need to run 'go test' w/the -race flag) +func TestRacyClose(t *testing.T) { + go func() { + ListenAndServe(":9000", nil) + }() + + go func() { + Close() + }() +} + // Tests that the server begins to shut down when told to and does not accept // new requests once shutdown has begun func TestShutdown(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg statechanged := make(chan http.ConnState) @@ -61,10 +92,9 @@ func TestShutdown(t *testing.T) { if err := <-client1.connected; err != nil { t.Fatal("Client failed to connect to server", err) } - // avoid a race between the client connection and the server accept - if state := <-statechanged; state != http.StateNew { - t.Fatal("Unexpected state", state) - } + // Even though the client is connected, the server ConnState handler may + // not know about that yet. So wait until it is called. + waitForState(t, statechanged, http.StateNew, "Request not received") // start the shutdown; once it hits waitgroup.Wait() // the listener should of been closed, though client1 is still connected @@ -94,36 +124,32 @@ func TestShutdown(t *testing.T) { <-exitchan } -// Test that a connection is closed upon reaching an idle state if and only if the server -// is shutting down. -func TestCloseOnIdle(t *testing.T) { - server := newServer() - wg := helpers.NewWaitGroup() - server.wg = wg - fl := helpers.NewListener() - runner := func() error { - return server.Serve(fl) - } +// If a request is sent to a closed server via a kept alive connection then +// the server closes the connection upon receiving the request. +func TestRequestAfterClose(t *testing.T) { + // Given + server := NewServer() + srvStateChangedCh := make(chan http.ConnState, 100) + listener, srvClosedCh := startServer(t, server, srvStateChangedCh) - startGenericServer(t, server, nil, runner) - - // Change to idle state while server is not closing; Close should not be called - conn := &helpers.Conn{} - server.ConnState(conn, http.StateIdle) - if conn.CloseCalled { - t.Error("Close was called unexpected") - } + client := newClient(listener.Addr(), false) + client.Run() + <-client.connected + client.sendrequest <- true + <-client.response server.Close() + if err := <-srvClosedCh; err != nil { + t.Error("Unexpected error during shutdown", err) + } - // wait until the server calls Close() on the listener - // by that point the atomic closing variable will have been updated, avoiding a race. - <-fl.CloseCalled + // When + client.sendrequest <- true + rr := <-client.response - conn = &helpers.Conn{} - server.ConnState(conn, http.StateIdle) - if !conn.CloseCalled { - t.Error("Close was not called") + // Then + if rr.body != nil || rr.err != nil { + t.Errorf("Request should be rejected, body=%v, err=%v", rr.body, rr.err) } } @@ -143,7 +169,7 @@ func waitForState(t *testing.T, waiter chan http.ConnState, state http.ConnState // Test that a request moving from active->idle->active using an actual // network connection still results in a corect shutdown func TestStateTransitionActiveIdleActive(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() statechanged := make(chan http.ConnState) server.wg = wg @@ -160,8 +186,7 @@ func TestStateTransitionActiveIdleActive(t *testing.T) { for i := 0; i < 2; i++ { client.sendrequest <- true waitForState(t, statechanged, http.StateActive, "Client failed to reach active state") - <-client.idle - client.idlerelease <- true + <-client.response waitForState(t, statechanged, http.StateIdle, "Client failed to reach idle state") } @@ -196,7 +221,7 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { } for _, withTLS := range []bool{false, true} { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() statechanged := make(chan http.ConnState) server.wg = wg @@ -217,12 +242,11 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { client.sendrequest <- true waitForState(t, statechanged, http.StateActive, "Client failed to reach active state") - err := <-client.idle - if err != nil { - t.Fatalf("tls=%t unexpected error from client %s", withTLS, err) + rr := <-client.response + if rr.err != nil { + t.Fatalf("tls=%t unexpected error from client %s", withTLS, rr.err) } - client.idlerelease <- true waitForState(t, statechanged, http.StateIdle, "Client failed to reach idle state") // client is now in an idle state @@ -241,3 +265,25 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { } } } + +func TestRoutinesCount(t *testing.T) { + var count int + server := NewServer() + + count = server.RoutinesCount() + if count != 0 { + t.Errorf("Expected the routines count to equal 0; actually %d", count) + } + + server.StartRoutine() + count = server.RoutinesCount() + if count != 1 { + t.Errorf("Expected the routines count to equal 1; actually %d", count) + } + + server.FinishRoutine() + count = server.RoutinesCount() + if count != 0 { + t.Errorf("Expected the routines count to equal 0; actually %d", count) + } +} diff --git a/static.go b/static.go index 2a74b09..b539506 100644 --- a/static.go +++ b/static.go @@ -3,14 +3,23 @@ package manners import ( "net" "net/http" + "sync" ) -var defaultServer *GracefulServer +var ( + defaultServer *GracefulServer + defaultServerLock = &sync.Mutex{} +) + +func init() { + defaultServerLock.Lock() +} // ListenAndServe provides a graceful version of the function provided by the // net/http package. Call Close() to stop the server. func ListenAndServe(addr string, handler http.Handler) error { defaultServer = NewWithServer(&http.Server{Addr: addr, Handler: handler}) + defaultServerLock.Unlock() return defaultServer.ListenAndServe() } @@ -18,6 +27,7 @@ func ListenAndServe(addr string, handler http.Handler) error { // net/http package. Call Close() to stop the server. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { defaultServer = NewWithServer(&http.Server{Addr: addr, Handler: handler}) + defaultServerLock.Unlock() return defaultServer.ListenAndServeTLS(certFile, keyFile) } @@ -25,11 +35,13 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler htt // package. Call Close() to stop the server. func Serve(l net.Listener, handler http.Handler) error { defaultServer = NewWithServer(&http.Server{Handler: handler}) + defaultServerLock.Unlock() return defaultServer.Serve(l) } // Shuts down the default server used by ListenAndServe, ListenAndServeTLS and // Serve. It returns true if it's the first time Close is called. func Close() bool { + defaultServerLock.Lock() return defaultServer.Close() } diff --git a/test_helpers/listener.go b/test_helpers/listener.go index a74ac11..e3af35a 100644 --- a/test_helpers/listener.go +++ b/test_helpers/listener.go @@ -1,8 +1,8 @@ package test_helpers import ( - "net" - "errors" + "errors" + "net" ) type Listener struct { @@ -11,10 +11,10 @@ type Listener struct { } func NewListener() *Listener { - return &Listener{ - make(chan bool, 1), - make(chan bool, 1), - } + return &Listener{ + make(chan bool, 1), + make(chan bool, 1), + } } func (l *Listener) Addr() net.Addr { diff --git a/transition_test.go b/transition_test.go index 34fe5c6..5d39851 100644 --- a/transition_test.go +++ b/transition_test.go @@ -31,7 +31,7 @@ type transitionTest struct { } func testStateTransition(t *testing.T, test transitionTest) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg startServer(t, server, nil)