From db98d4b07ccb6bd189d1dd63203b0be239368490 Mon Sep 17 00:00:00 2001 From: Gyu-Ho Lee Date: Thu, 14 Nov 2013 21:52:48 -0800 Subject: [PATCH 01/33] remove hex assignment and just return the value directly from fmt.Sprintf return fmt.Sprintf("%02x", hm.Sum(nil)) --- web.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/web.go b/web.go index 41bc8ec4..1adad59a 100644 --- a/web.go +++ b/web.go @@ -117,8 +117,7 @@ func getCookieSig(key string, val []byte, timestamp string) string { hm.Write(val) hm.Write([]byte(timestamp)) - hex := fmt.Sprintf("%02x", hm.Sum(nil)) - return hex + return fmt.Sprintf("%02x", hm.Sum(nil)) } func (ctx *Context) SetSecureCookie(name string, val string, age int64) { From 2863de8211bd9ac775790b9e39a2b88436dddb52 Mon Sep 17 00:00:00 2001 From: Daniel Hernik Date: Sun, 24 Nov 2013 21:12:03 +0100 Subject: [PATCH 02/33] Added starting message to RunTLS function --- server.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server.go b/server.go index 184f01c2..bd1c2d0c 100644 --- a/server.go +++ b/server.go @@ -175,6 +175,9 @@ func (s *Server) RunTLS(addr string, config *tls.Config) error { s.initServer() mux := http.NewServeMux() mux.Handle("/", s) + + s.Logger.Printf("web.go serving %s\n", addr) + l, err := tls.Listen("tcp", addr, config) if err != nil { log.Fatal("Listen:", err) From 5ae152733435d12b8d8bc2c68886580b9530880e Mon Sep 17 00:00:00 2001 From: dbowring Date: Wed, 4 Dec 2013 12:58:30 +1100 Subject: [PATCH 03/33] Fix 'serving' message displaying wrong address when using 0 as port --- server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 184f01c2..81b8b08b 100644 --- a/server.go +++ b/server.go @@ -145,12 +145,13 @@ func (s *Server) Run(addr string) { } mux.Handle("/", s) - s.Logger.Printf("web.go serving %s\n", addr) - l, err := net.Listen("tcp", addr) if err != nil { log.Fatal("ListenAndServe:", err) } + + s.Logger.Printf("web.go serving %s\n", l.Addr().String()) + s.l = l err = http.Serve(s.l, mux) s.l.Close() From d48916f617e92aed673cfe9d7e9d7b92d74e2fd8 Mon Sep 17 00:00:00 2001 From: mashuai Date: Thu, 24 Sep 2015 17:52:00 +0800 Subject: [PATCH 04/33] change 4 to runtime.NumCPU --- web_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web_test.go b/web_test.go index 609ca509..9f6f1b65 100644 --- a/web_test.go +++ b/web_test.go @@ -18,7 +18,7 @@ import ( ) func init() { - runtime.GOMAXPROCS(4) + runtime.GOMAXPROCS(runtime.NumCPU()) } // ioBuffer is a helper that implements io.ReadWriteCloser, From 81a3e63a574c9de77045a30d73fab7c5d7095b48 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 27 Jan 2016 17:37:42 -0600 Subject: [PATCH 05/33] replaced reference to depricated code.google.com --- server.go | 32 ++++++++++++++++---------------- web.go | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 184f01c2..6ab8e802 100644 --- a/server.go +++ b/server.go @@ -1,22 +1,22 @@ package web import ( - "bytes" - "code.google.com/p/go.net/websocket" - "crypto/tls" - "fmt" - "log" - "net" - "net/http" - "net/http/pprof" - "os" - "path" - "reflect" - "regexp" - "runtime" - "strconv" - "strings" - "time" + "bytes" + "crypto/tls" + "fmt" + "golang.org/x/net/websocket" + "log" + "net" + "net/http" + "net/http/pprof" + "os" + "path" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "time" ) // ServerConfig is configuration for server objects. diff --git a/web.go b/web.go index 41bc8ec4..7e09a8b9 100644 --- a/web.go +++ b/web.go @@ -4,7 +4,7 @@ package web import ( "bytes" - "code.google.com/p/go.net/websocket" + "golang.org/x/net/websocket" "crypto/hmac" "crypto/sha1" "crypto/tls" From e52eb83f06cc7c8673f6eab77012ae3bd1bc398d Mon Sep 17 00:00:00 2001 From: Sheldon Rupp Date: Thu, 21 Apr 2016 01:38:24 +0200 Subject: [PATCH 06/33] Change HTTP to HTTPS --- examples/tls.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tls.go b/examples/tls.go index 7517fe61..1a4fa5aa 100644 --- a/examples/tls.go +++ b/examples/tls.go @@ -62,7 +62,7 @@ func main() { return } - // you must access the server with an HTTP address, i.e https://localhost:9999/world + // you must access the server with an HTTPS address, i.e https://localhost:9999/world web.Get("/(.*)", hello) web.RunTLS("0.0.0.0:9999", &config) } From e5d74875471b0b313cdd77d3996ba2ae5712b82f Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Thu, 28 Jul 2016 13:33:39 -0700 Subject: [PATCH 07/33] Fix links in Readme.md The webgo.io domain name has expired. Change the docs link to Github Pages. Also, remove my home page link. --- Readme.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Readme.md b/Readme.md index 9750f4df..2c05fe9d 100644 --- a/Readme.md +++ b/Readme.md @@ -77,12 +77,12 @@ In this example, if you visit `http://localhost:9999/?a=1&b=2`, you'll see the f ## Documentation -API docs are hosted at http://webgo.io +API docs are hosted at https://hoisie.github.io/web/ If you use web.go, I'd greatly appreciate a quick message about what you're building with it. This will help me get a sense of usage patterns, and helps me focus development efforts on features that people will actually use. ## About -web.go was written by [Michael Hoisie](http://hoisie.com). +web.go was written by Michael Hoisie From 9b25872c812bb12b68b6e8c47cc5987a93a97808 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Fri, 29 Jul 2016 10:30:11 -0700 Subject: [PATCH 08/33] Add drone.io build status badge --- Readme.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Readme.md b/Readme.md index 2c05fe9d..93ce684d 100644 --- a/Readme.md +++ b/Readme.md @@ -1,3 +1,5 @@ +[![Build Status](https://drone.io/github.com/hoisie/web/status.png)](https://drone.io/github.com/hoisie/web/latest) + # web.go web.go is the simplest way to write web applications in the Go programming language. It's ideal for writing simple, performant backend web services. From 2ba703e1b81f80c8f604cf15127b94581290be78 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Fri, 29 Jul 2016 11:27:40 -0700 Subject: [PATCH 09/33] Clean up initial log statements 1. In RunTLS, only print the initial log statement if `Listen` is successful. Also, print the actual address instead of the one passed in. 2. Change `l.Addr().String()` to `l.Addr()`. The Addr interface has a `String` method, which will be called by fmt. --- server.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server.go b/server.go index a53d46d9..93dd9ba1 100644 --- a/server.go +++ b/server.go @@ -150,7 +150,7 @@ func (s *Server) Run(addr string) { log.Fatal("ListenAndServe:", err) } - s.Logger.Printf("web.go serving %s\n", l.Addr().String()) + s.Logger.Printf("web.go serving %s\n", l.Addr()) s.l = l err = http.Serve(s.l, mux) @@ -177,13 +177,12 @@ func (s *Server) RunTLS(addr string, config *tls.Config) error { mux := http.NewServeMux() mux.Handle("/", s) - s.Logger.Printf("web.go serving %s\n", addr) - l, err := tls.Listen("tcp", addr, config) if err != nil { log.Fatal("Listen:", err) return err } + s.Logger.Printf("web.go serving %s\n", l.Addr()) s.l = l return http.Serve(s.l, mux) From 2de8b2a91a52bbcf11f14fdd1c935d792a23268e Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Fri, 29 Jul 2016 11:03:08 -0700 Subject: [PATCH 10/33] Use standard tab indentation for web.go source files This is a whitespace change only. Run all the source files through `gofmt`. Previously, all web.go files were indented using four spaces. Convert all the spaces into tabs. Also, remove `Makefile`, which contained the old formatting command. --- Makefile | 6 - examples/arcchallenge.go | 44 +- examples/cookie.go | 42 +- examples/hello.go | 6 +- examples/logger.go | 24 +- examples/multipart.go | 52 +-- examples/multiserver.go | 16 +- examples/params.go | 12 +- examples/streaming.go | 28 +- examples/tls.go | 30 +- fcgi.go | 34 +- helpers.go | 142 +++--- scgi.go | 268 +++++------ server.go | 536 +++++++++++----------- status.go | 80 ++-- web.go | 242 +++++----- web_test.go | 954 +++++++++++++++++++-------------------- 17 files changed, 1255 insertions(+), 1261 deletions(-) delete mode 100644 Makefile diff --git a/Makefile b/Makefile deleted file mode 100644 index 25b06108..00000000 --- a/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -GOFMT=gofmt -s -tabs=false -tabwidth=4 - -GOFILES=$(wildcard *.go **/*.go) - -format: - ${GOFMT} -w ${GOFILES} diff --git a/examples/arcchallenge.go b/examples/arcchallenge.go index 1de600c2..ea9feb4c 100644 --- a/examples/arcchallenge.go +++ b/examples/arcchallenge.go @@ -1,10 +1,10 @@ package main import ( - "fmt" - "github.com/hoisie/web" - "math/rand" - "time" + "fmt" + "github.com/hoisie/web" + "math/rand" + "time" ) var form = `
` @@ -12,22 +12,22 @@ var form = `
Click Here` - }) - web.Get("/final", func(ctx *web.Context) string { - uid, _ := ctx.GetSecureCookie("user") - return "You said " + users[uid] - }) - web.Run("0.0.0.0:9999") + rand.Seed(time.Now().UnixNano()) + web.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + web.Get("/", func(ctx *web.Context) string { + ctx.Redirect(302, "/said") + return "" + }) + web.Get("/said", func() string { return form }) + web.Post("/say", func(ctx *web.Context) string { + uid := fmt.Sprintf("%d\n", rand.Int63()) + ctx.SetSecureCookie("user", uid, 3600) + users[uid] = ctx.Params["said"] + return `Click Here` + }) + web.Get("/final", func(ctx *web.Context) string { + uid, _ := ctx.GetSecureCookie("user") + return "You said " + users[uid] + }) + web.Run("0.0.0.0:9999") } diff --git a/examples/cookie.go b/examples/cookie.go index e686c767..a4bb8a3d 100644 --- a/examples/cookie.go +++ b/examples/cookie.go @@ -1,9 +1,9 @@ package main import ( - "fmt" - "github.com/hoisie/web" - "html" + "fmt" + "github.com/hoisie/web" + "html" ) var cookieName = "cookie" @@ -24,28 +24,28 @@ var form = ` ` func index(ctx *web.Context) string { - cookie, _ := ctx.Request.Cookie(cookieName) - var top string - if cookie == nil { - top = fmt.Sprintf(notice, "The cookie has not been set") - } else { - var val = html.EscapeString(cookie.Value) - top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") - } - return top + form + cookie, _ := ctx.Request.Cookie(cookieName) + var top string + if cookie == nil { + top = fmt.Sprintf(notice, "The cookie has not been set") + } else { + var val = html.EscapeString(cookie.Value) + top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") + } + return top + form } func update(ctx *web.Context) { - if ctx.Params["submit"] == "Delete" { - ctx.SetCookie(web.NewCookie(cookieName, "", -1)) - } else { - ctx.SetCookie(web.NewCookie(cookieName, ctx.Params["cookie"], 0)) - } - ctx.Redirect(301, "/") + if ctx.Params["submit"] == "Delete" { + ctx.SetCookie(web.NewCookie(cookieName, "", -1)) + } else { + ctx.SetCookie(web.NewCookie(cookieName, ctx.Params["cookie"], 0)) + } + ctx.Redirect(301, "/") } func main() { - web.Get("/", index) - web.Post("/update", update) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/update", update) + web.Run("0.0.0.0:9999") } diff --git a/examples/hello.go b/examples/hello.go index d502391a..29bf8d10 100644 --- a/examples/hello.go +++ b/examples/hello.go @@ -1,12 +1,12 @@ package main import ( - "github.com/hoisie/web" + "github.com/hoisie/web" ) func hello(val string) string { return "hello " + val } func main() { - web.Get("/(.*)", hello) - web.Run("0.0.0.0:9999") + web.Get("/(.*)", hello) + web.Run("0.0.0.0:9999") } diff --git a/examples/logger.go b/examples/logger.go index f684261a..e4b134a3 100644 --- a/examples/logger.go +++ b/examples/logger.go @@ -1,21 +1,21 @@ package main import ( - "github.com/hoisie/web" - "log" - "os" + "github.com/hoisie/web" + "log" + "os" ) func hello(val string) string { return "hello " + val } func main() { - f, err := os.Create("server.log") - if err != nil { - println(err.Error()) - return - } - logger := log.New(f, "", log.Ldate|log.Ltime) - web.Get("/(.*)", hello) - web.SetLogger(logger) - web.Run("0.0.0.0:9999") + f, err := os.Create("server.log") + if err != nil { + println(err.Error()) + return + } + logger := log.New(f, "", log.Ldate|log.Ltime) + web.Get("/(.*)", hello) + web.SetLogger(logger) + web.Run("0.0.0.0:9999") } diff --git a/examples/multipart.go b/examples/multipart.go index d89e514e..68df57e8 100644 --- a/examples/multipart.go +++ b/examples/multipart.go @@ -1,17 +1,17 @@ package main import ( - "bytes" - "crypto/md5" - "fmt" - "github.com/hoisie/web" - "io" + "bytes" + "crypto/md5" + "fmt" + "github.com/hoisie/web" + "io" ) func Md5(r io.Reader) string { - hash := md5.New() - io.Copy(hash, r) - return fmt.Sprintf("%x", hash.Sum(nil)) + hash := md5.New() + io.Copy(hash, r) + return fmt.Sprintf("%x", hash.Sum(nil)) } var page = ` @@ -38,25 +38,25 @@ var page = ` func index() string { return page } func multipart(ctx *web.Context) string { - ctx.Request.ParseMultipartForm(10 * 1024 * 1024) - form := ctx.Request.MultipartForm - var output bytes.Buffer - output.WriteString("

input1: " + form.Value["input1"][0] + "

") - output.WriteString("

input2: " + form.Value["input2"][0] + "

") - - fileHeader := form.File["file"][0] - filename := fileHeader.Filename - file, err := fileHeader.Open() - if err != nil { - return err.Error() - } - - output.WriteString("

file: " + filename + " " + Md5(file) + "

") - return output.String() + ctx.Request.ParseMultipartForm(10 * 1024 * 1024) + form := ctx.Request.MultipartForm + var output bytes.Buffer + output.WriteString("

input1: " + form.Value["input1"][0] + "

") + output.WriteString("

input2: " + form.Value["input2"][0] + "

") + + fileHeader := form.File["file"][0] + filename := fileHeader.Filename + file, err := fileHeader.Open() + if err != nil { + return err.Error() + } + + output.WriteString("

file: " + filename + " " + Md5(file) + "

") + return output.String() } func main() { - web.Get("/", index) - web.Post("/multipart", multipart) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/multipart", multipart) + web.Run("0.0.0.0:9999") } diff --git a/examples/multiserver.go b/examples/multiserver.go index 354b563e..4fc479ed 100644 --- a/examples/multiserver.go +++ b/examples/multiserver.go @@ -1,7 +1,7 @@ package main import ( - "github.com/hoisie/web" + "github.com/hoisie/web" ) func hello1(val string) string { return "hello1 " + val } @@ -9,12 +9,12 @@ func hello1(val string) string { return "hello1 " + val } func hello2(val string) string { return "hello2 " + val } func main() { - var server1 web.Server - var server2 web.Server + var server1 web.Server + var server2 web.Server - server1.Get("/(.*)", hello1) - go server1.Run("0.0.0.0:9999") - server2.Get("/(.*)", hello2) - go server2.Run("0.0.0.0:8999") - <-make(chan int) + server1.Get("/(.*)", hello1) + go server1.Run("0.0.0.0:9999") + server2.Get("/(.*)", hello2) + go server2.Run("0.0.0.0:8999") + <-make(chan int) } diff --git a/examples/params.go b/examples/params.go index dc396a1b..ac70654d 100644 --- a/examples/params.go +++ b/examples/params.go @@ -1,8 +1,8 @@ package main import ( - "fmt" - "github.com/hoisie/web" + "fmt" + "github.com/hoisie/web" ) var page = ` @@ -33,11 +33,11 @@ var page = ` func index() string { return page } func process(ctx *web.Context) string { - return fmt.Sprintf("%v\n", ctx.Params) + return fmt.Sprintf("%v\n", ctx.Params) } func main() { - web.Get("/", index) - web.Post("/process", process) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/process", process) + web.Run("0.0.0.0:9999") } diff --git a/examples/streaming.go b/examples/streaming.go index e17ab372..2a5b948f 100644 --- a/examples/streaming.go +++ b/examples/streaming.go @@ -1,24 +1,24 @@ package main import ( - "github.com/hoisie/web" - "net/http" - "strconv" - "time" + "github.com/hoisie/web" + "net/http" + "strconv" + "time" ) func hello(ctx *web.Context, num string) { - flusher, _ := ctx.ResponseWriter.(http.Flusher) - flusher.Flush() - n, _ := strconv.ParseInt(num, 10, 64) - for i := int64(0); i < n; i++ { - ctx.WriteString("
hello world
") - flusher.Flush() - time.Sleep(1e9) - } + flusher, _ := ctx.ResponseWriter.(http.Flusher) + flusher.Flush() + n, _ := strconv.ParseInt(num, 10, 64) + for i := int64(0); i < n; i++ { + ctx.WriteString("
hello world
") + flusher.Flush() + time.Sleep(1e9) + } } func main() { - web.Get("/([0-9]+)", hello) - web.Run("0.0.0.0:9999") + web.Get("/([0-9]+)", hello) + web.Run("0.0.0.0:9999") } diff --git a/examples/tls.go b/examples/tls.go index 1a4fa5aa..dc189eac 100644 --- a/examples/tls.go +++ b/examples/tls.go @@ -1,8 +1,8 @@ package main import ( - "crypto/tls" - "github.com/hoisie/web" + "crypto/tls" + "github.com/hoisie/web" ) // an arbitrary self-signed certificate, generated with @@ -50,19 +50,19 @@ gWrxykqyLToIiAuL+pvC3Jv8IOPIiVFsY032rOqcwSGdVUyhTsG28+7KnR6744tM func hello(val string) string { return "hello " + val } func main() { - config := tls.Config{ - Time: nil, - } + config := tls.Config{ + Time: nil, + } - config.Certificates = make([]tls.Certificate, 1) - var err error - config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey)) - if err != nil { - println(err.Error()) - return - } + config.Certificates = make([]tls.Certificate, 1) + var err error + config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey)) + if err != nil { + println(err.Error()) + return + } - // you must access the server with an HTTPS address, i.e https://localhost:9999/world - web.Get("/(.*)", hello) - web.RunTLS("0.0.0.0:9999", &config) + // you must access the server with an HTTPS address, i.e https://localhost:9999/world + web.Get("/(.*)", hello) + web.RunTLS("0.0.0.0:9999", &config) } diff --git a/fcgi.go b/fcgi.go index fa380dc9..017b8024 100644 --- a/fcgi.go +++ b/fcgi.go @@ -1,27 +1,27 @@ package web import ( - "net" - "net/http/fcgi" + "net" + "net/http/fcgi" ) func (s *Server) listenAndServeFcgi(addr string) error { - var l net.Listener - var err error + var l net.Listener + var err error - //if the path begins with a "/", assume it's a unix address - if addr[0] == '/' { - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } + //if the path begins with a "/", assume it's a unix address + if addr[0] == '/' { + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } - //save the listener so it can be closed - s.l = l + //save the listener so it can be closed + s.l = l - if err != nil { - s.Logger.Println("FCGI listen error", err.Error()) - return err - } - return fcgi.Serve(s.l, s) + if err != nil { + s.Logger.Println("FCGI listen error", err.Error()) + return err + } + return fcgi.Serve(s.l, s) } diff --git a/helpers.go b/helpers.go index 93c8d5ad..e77d87f2 100644 --- a/helpers.go +++ b/helpers.go @@ -1,59 +1,59 @@ package web import ( - "bytes" - "encoding/base64" - "errors" - "net/http" - "net/url" - "os" - "regexp" - "strings" - "time" + "bytes" + "encoding/base64" + "errors" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "time" ) // internal utility methods func webTime(t time.Time) string { - ftime := t.Format(time.RFC1123) - if strings.HasSuffix(ftime, "UTC") { - ftime = ftime[0:len(ftime)-3] + "GMT" - } - return ftime + ftime := t.Format(time.RFC1123) + if strings.HasSuffix(ftime, "UTC") { + ftime = ftime[0:len(ftime)-3] + "GMT" + } + return ftime } func dirExists(dir string) bool { - d, e := os.Stat(dir) - switch { - case e != nil: - return false - case !d.IsDir(): - return false - } + d, e := os.Stat(dir) + switch { + case e != nil: + return false + case !d.IsDir(): + return false + } - return true + return true } func fileExists(dir string) bool { - info, err := os.Stat(dir) - if err != nil { - return false - } + info, err := os.Stat(dir) + if err != nil { + return false + } - return !info.IsDir() + return !info.IsDir() } // Urlencode is a helper method that converts a map into URL-encoded form data. // It is a useful when constructing HTTP POST requests. func Urlencode(data map[string]string) string { - var buf bytes.Buffer - for k, v := range data { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(v)) - buf.WriteByte('&') - } - s := buf.String() - return s[0 : len(s)-1] + var buf bytes.Buffer + for k, v := range data { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(v)) + buf.WriteByte('&') + } + s := buf.String() + return s[0 : len(s)-1] } var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`) @@ -62,50 +62,50 @@ var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`) // It's used to return clean, URL-friendly strings that can be // used in routing. func Slug(s string, sep string) string { - if s == "" { - return "" - } - slug := slugRegex.ReplaceAllString(s, sep) - if slug == "" { - return "" - } - quoted := regexp.QuoteMeta(sep) - sepRegex := regexp.MustCompile("(" + quoted + "){2,}") - slug = sepRegex.ReplaceAllString(slug, sep) - sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$") - slug = sepEnds.ReplaceAllString(slug, "") - return strings.ToLower(slug) + if s == "" { + return "" + } + slug := slugRegex.ReplaceAllString(s, sep) + if slug == "" { + return "" + } + quoted := regexp.QuoteMeta(sep) + sepRegex := regexp.MustCompile("(" + quoted + "){2,}") + slug = sepRegex.ReplaceAllString(slug, sep) + sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$") + slug = sepEnds.ReplaceAllString(slug, "") + return strings.ToLower(slug) } // NewCookie is a helper method that returns a new http.Cookie object. // Duration is specified in seconds. If the duration is zero, the cookie is permanent. // This can be used in conjunction with ctx.SetCookie. func NewCookie(name string, value string, age int64) *http.Cookie { - var utctime time.Time - if age == 0 { - // 2^31 - 1 seconds (roughly 2038) - utctime = time.Unix(2147483647, 0) - } else { - utctime = time.Unix(time.Now().Unix()+age, 0) - } - return &http.Cookie{Name: name, Value: value, Expires: utctime} + var utctime time.Time + if age == 0 { + // 2^31 - 1 seconds (roughly 2038) + utctime = time.Unix(2147483647, 0) + } else { + utctime = time.Unix(time.Now().Unix()+age, 0) + } + return &http.Cookie{Name: name, Value: value, Expires: utctime} } // GetBasicAuth is a helper method of *Context that returns the decoded // user and password from the *Context's authorization header func (ctx *Context) GetBasicAuth() (string, string, error) { - authHeader := ctx.Request.Header["Authorization"][0] - authString := strings.Split(string(authHeader), " ") - if authString[0] != "Basic" { - return "", "", errors.New("Not Basic Authentication") - } - decodedAuth, err := base64.StdEncoding.DecodeString(authString[1]) - if err != nil { - return "", "", err - } - authSlice := strings.Split(string(decodedAuth), ":") - if len(authSlice) != 2 { - return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1]) - } - return authSlice[0], authSlice[1], nil + authHeader := ctx.Request.Header["Authorization"][0] + authString := strings.Split(string(authHeader), " ") + if authString[0] != "Basic" { + return "", "", errors.New("Not Basic Authentication") + } + decodedAuth, err := base64.StdEncoding.DecodeString(authString[1]) + if err != nil { + return "", "", err + } + authSlice := strings.Split(string(decodedAuth), ":") + if len(authSlice) != 2 { + return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1]) + } + return authSlice[0], authSlice[1], nil } diff --git a/scgi.go b/scgi.go index 45d3aec5..7d0dc535 100644 --- a/scgi.go +++ b/scgi.go @@ -1,180 +1,180 @@ package web import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/http/cgi" - "strconv" - "strings" + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cgi" + "strconv" + "strings" ) type scgiBody struct { - reader io.Reader - conn io.ReadWriteCloser - closed bool + reader io.Reader + conn io.ReadWriteCloser + closed bool } func (b *scgiBody) Read(p []byte) (n int, err error) { - if b.closed { - return 0, errors.New("SCGI read after close") - } - return b.reader.Read(p) + if b.closed { + return 0, errors.New("SCGI read after close") + } + return b.reader.Read(p) } func (b *scgiBody) Close() error { - b.closed = true - return b.conn.Close() + b.closed = true + return b.conn.Close() } type scgiConn struct { - fd io.ReadWriteCloser - req *http.Request - headers http.Header - wroteHeaders bool + fd io.ReadWriteCloser + req *http.Request + headers http.Header + wroteHeaders bool } func (conn *scgiConn) WriteHeader(status int) { - if !conn.wroteHeaders { - conn.wroteHeaders = true + if !conn.wroteHeaders { + conn.wroteHeaders = true - var buf bytes.Buffer - text := statusText[status] + var buf bytes.Buffer + text := statusText[status] - fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text) + fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text) - for k, v := range conn.headers { - for _, i := range v { - buf.WriteString(k + ": " + i + "\r\n") - } - } + for k, v := range conn.headers { + for _, i := range v { + buf.WriteString(k + ": " + i + "\r\n") + } + } - buf.WriteString("\r\n") - conn.fd.Write(buf.Bytes()) - } + buf.WriteString("\r\n") + conn.fd.Write(buf.Bytes()) + } } func (conn *scgiConn) Header() http.Header { - return conn.headers + return conn.headers } func (conn *scgiConn) Write(data []byte) (n int, err error) { - if !conn.wroteHeaders { - conn.WriteHeader(200) - } + if !conn.wroteHeaders { + conn.WriteHeader(200) + } - if conn.req.Method == "HEAD" { - return 0, errors.New("Body Not Allowed") - } + if conn.req.Method == "HEAD" { + return 0, errors.New("Body Not Allowed") + } - return conn.fd.Write(data) + return conn.fd.Write(data) } func (conn *scgiConn) Close() { conn.fd.Close() } func (conn *scgiConn) finishRequest() error { - var buf bytes.Buffer - if !conn.wroteHeaders { - conn.wroteHeaders = true - - for k, v := range conn.headers { - for _, i := range v { - buf.WriteString(k + ": " + i + "\r\n") - } - } - - buf.WriteString("\r\n") - conn.fd.Write(buf.Bytes()) - } - return nil + var buf bytes.Buffer + if !conn.wroteHeaders { + conn.wroteHeaders = true + + for k, v := range conn.headers { + for _, i := range v { + buf.WriteString(k + ": " + i + "\r\n") + } + } + + buf.WriteString("\r\n") + conn.fd.Write(buf.Bytes()) + } + return nil } func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) { - reader := bufio.NewReader(fd) - line, err := reader.ReadString(':') - if err != nil { - s.Logger.Println("Error during SCGI read: ", err.Error()) - } - length, _ := strconv.Atoi(line[0 : len(line)-1]) - if length > 16384 { - s.Logger.Println("Error: max header size is 16k") - } - headerData := make([]byte, length) - _, err = reader.Read(headerData) - if err != nil { - return nil, err - } - - b, err := reader.ReadByte() - if err != nil { - return nil, err - } - // discard the trailing comma - if b != ',' { - return nil, errors.New("SCGI protocol error: missing comma") - } - headerList := bytes.Split(headerData, []byte{0}) - headers := map[string]string{} - for i := 0; i < len(headerList)-1; i += 2 { - headers[string(headerList[i])] = string(headerList[i+1]) - } - httpReq, err := cgi.RequestFromMap(headers) - if err != nil { - return nil, err - } - if httpReq.ContentLength > 0 { - httpReq.Body = &scgiBody{ - reader: io.LimitReader(reader, httpReq.ContentLength), - conn: fd, - } - } else { - httpReq.Body = &scgiBody{reader: reader, conn: fd} - } - return httpReq, nil + reader := bufio.NewReader(fd) + line, err := reader.ReadString(':') + if err != nil { + s.Logger.Println("Error during SCGI read: ", err.Error()) + } + length, _ := strconv.Atoi(line[0 : len(line)-1]) + if length > 16384 { + s.Logger.Println("Error: max header size is 16k") + } + headerData := make([]byte, length) + _, err = reader.Read(headerData) + if err != nil { + return nil, err + } + + b, err := reader.ReadByte() + if err != nil { + return nil, err + } + // discard the trailing comma + if b != ',' { + return nil, errors.New("SCGI protocol error: missing comma") + } + headerList := bytes.Split(headerData, []byte{0}) + headers := map[string]string{} + for i := 0; i < len(headerList)-1; i += 2 { + headers[string(headerList[i])] = string(headerList[i+1]) + } + httpReq, err := cgi.RequestFromMap(headers) + if err != nil { + return nil, err + } + if httpReq.ContentLength > 0 { + httpReq.Body = &scgiBody{ + reader: io.LimitReader(reader, httpReq.ContentLength), + conn: fd, + } + } else { + httpReq.Body = &scgiBody{reader: reader, conn: fd} + } + return httpReq, nil } func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) { - req, err := s.readScgiRequest(fd) - if err != nil { - s.Logger.Println("SCGI error: %q", err.Error()) - } - sc := scgiConn{fd, req, make(map[string][]string), false} - s.routeHandler(req, &sc) - sc.finishRequest() - fd.Close() + req, err := s.readScgiRequest(fd) + if err != nil { + s.Logger.Println("SCGI error: %q", err.Error()) + } + sc := scgiConn{fd, req, make(map[string][]string), false} + s.routeHandler(req, &sc) + sc.finishRequest() + fd.Close() } func (s *Server) listenAndServeScgi(addr string) error { - var l net.Listener - var err error - - //if the path begins with a "/", assume it's a unix address - if strings.HasPrefix(addr, "/") { - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } - - //save the listener so it can be closed - s.l = l - - if err != nil { - s.Logger.Println("SCGI listen error", err.Error()) - return err - } - - for { - fd, err := l.Accept() - if err != nil { - s.Logger.Println("SCGI accept error", err.Error()) - return err - } - go s.handleScgiRequest(fd) - } - return nil + var l net.Listener + var err error + + //if the path begins with a "/", assume it's a unix address + if strings.HasPrefix(addr, "/") { + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + + //save the listener so it can be closed + s.l = l + + if err != nil { + s.Logger.Println("SCGI listen error", err.Error()) + return err + } + + for { + fd, err := l.Accept() + if err != nil { + s.Logger.Println("SCGI accept error", err.Error()) + return err + } + go s.handleScgiRequest(fd) + } + return nil } diff --git a/server.go b/server.go index 93dd9ba1..e84c674b 100644 --- a/server.go +++ b/server.go @@ -21,223 +21,223 @@ import ( // ServerConfig is configuration for server objects. type ServerConfig struct { - StaticDir string - Addr string - Port int - CookieSecret string - RecoverPanic bool - Profiler bool + StaticDir string + Addr string + Port int + CookieSecret string + RecoverPanic bool + Profiler bool } // Server represents a web.go server. type Server struct { - Config *ServerConfig - routes []route - Logger *log.Logger - Env map[string]interface{} - //save the listener so it can be closed - l net.Listener + Config *ServerConfig + routes []route + Logger *log.Logger + Env map[string]interface{} + //save the listener so it can be closed + l net.Listener } func NewServer() *Server { - return &Server{ - Config: Config, - Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime), - Env: map[string]interface{}{}, - } + return &Server{ + Config: Config, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime), + Env: map[string]interface{}{}, + } } func (s *Server) initServer() { - if s.Config == nil { - s.Config = &ServerConfig{} - } + if s.Config == nil { + s.Config = &ServerConfig{} + } - if s.Logger == nil { - s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) - } + if s.Logger == nil { + s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) + } } type route struct { - r string - cr *regexp.Regexp - method string - handler reflect.Value - httpHandler http.Handler + r string + cr *regexp.Regexp + method string + handler reflect.Value + httpHandler http.Handler } func (s *Server) addRoute(r string, method string, handler interface{}) { - cr, err := regexp.Compile(r) - if err != nil { - s.Logger.Printf("Error in route regex %q\n", r) - return - } - - switch handler.(type) { - case http.Handler: - s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)}) - case reflect.Value: - fv := handler.(reflect.Value) - s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) - default: - fv := reflect.ValueOf(handler) - s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) - } + cr, err := regexp.Compile(r) + if err != nil { + s.Logger.Printf("Error in route regex %q\n", r) + return + } + + switch handler.(type) { + case http.Handler: + s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)}) + case reflect.Value: + fv := handler.(reflect.Value) + s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) + default: + fv := reflect.ValueOf(handler) + s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) + } } // ServeHTTP is the interface method for Go's http server package func (s *Server) ServeHTTP(c http.ResponseWriter, req *http.Request) { - s.Process(c, req) + s.Process(c, req) } // Process invokes the routing system for server s func (s *Server) Process(c http.ResponseWriter, req *http.Request) { - route := s.routeHandler(req, c) - if route != nil { - route.httpHandler.ServeHTTP(c, req) - } + route := s.routeHandler(req, c) + if route != nil { + route.httpHandler.ServeHTTP(c, req) + } } // Get adds a handler for the 'GET' http method for server s. func (s *Server) Get(route string, handler interface{}) { - s.addRoute(route, "GET", handler) + s.addRoute(route, "GET", handler) } // Post adds a handler for the 'POST' http method for server s. func (s *Server) Post(route string, handler interface{}) { - s.addRoute(route, "POST", handler) + s.addRoute(route, "POST", handler) } // Put adds a handler for the 'PUT' http method for server s. func (s *Server) Put(route string, handler interface{}) { - s.addRoute(route, "PUT", handler) + s.addRoute(route, "PUT", handler) } // Delete adds a handler for the 'DELETE' http method for server s. func (s *Server) Delete(route string, handler interface{}) { - s.addRoute(route, "DELETE", handler) + s.addRoute(route, "DELETE", handler) } // Match adds a handler for an arbitrary http method for server s. func (s *Server) Match(method string, route string, handler interface{}) { - s.addRoute(route, method, handler) + s.addRoute(route, method, handler) } //Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func (s *Server) Handler(route string, method string, httpHandler http.Handler) { - s.addRoute(route, method, httpHandler) + s.addRoute(route, method, httpHandler) } //Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func (s *Server) Websocket(route string, httpHandler websocket.Handler) { - s.addRoute(route, "GET", httpHandler) + s.addRoute(route, "GET", httpHandler) } // Run starts the web application and serves HTTP requests for s func (s *Server) Run(addr string) { - s.initServer() - - mux := http.NewServeMux() - if s.Config.Profiler { - mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) - mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) - mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) - mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) - } - mux.Handle("/", s) - - l, err := net.Listen("tcp", addr) - if err != nil { - log.Fatal("ListenAndServe:", err) - } - - s.Logger.Printf("web.go serving %s\n", l.Addr()) - - s.l = l - err = http.Serve(s.l, mux) - s.l.Close() + s.initServer() + + mux := http.NewServeMux() + if s.Config.Profiler { + mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) + mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + } + mux.Handle("/", s) + + l, err := net.Listen("tcp", addr) + if err != nil { + log.Fatal("ListenAndServe:", err) + } + + s.Logger.Printf("web.go serving %s\n", l.Addr()) + + s.l = l + err = http.Serve(s.l, mux) + s.l.Close() } // RunFcgi starts the web application and serves FastCGI requests for s. func (s *Server) RunFcgi(addr string) { - s.initServer() - s.Logger.Printf("web.go serving fcgi %s\n", addr) - s.listenAndServeFcgi(addr) + s.initServer() + s.Logger.Printf("web.go serving fcgi %s\n", addr) + s.listenAndServeFcgi(addr) } // RunScgi starts the web application and serves SCGI requests for s. func (s *Server) RunScgi(addr string) { - s.initServer() - s.Logger.Printf("web.go serving scgi %s\n", addr) - s.listenAndServeScgi(addr) + s.initServer() + s.Logger.Printf("web.go serving scgi %s\n", addr) + s.listenAndServeScgi(addr) } // RunTLS starts the web application and serves HTTPS requests for s. func (s *Server) RunTLS(addr string, config *tls.Config) error { - s.initServer() - mux := http.NewServeMux() - mux.Handle("/", s) - - l, err := tls.Listen("tcp", addr, config) - if err != nil { - log.Fatal("Listen:", err) - return err - } - s.Logger.Printf("web.go serving %s\n", l.Addr()) - - s.l = l - return http.Serve(s.l, mux) + s.initServer() + mux := http.NewServeMux() + mux.Handle("/", s) + + l, err := tls.Listen("tcp", addr, config) + if err != nil { + log.Fatal("Listen:", err) + return err + } + s.Logger.Printf("web.go serving %s\n", l.Addr()) + + s.l = l + return http.Serve(s.l, mux) } // Close stops server s. func (s *Server) Close() { - if s.l != nil { - s.l.Close() - } + if s.l != nil { + s.l.Close() + } } // safelyCall invokes `function` in recover block func (s *Server) safelyCall(function reflect.Value, args []reflect.Value) (resp []reflect.Value, e interface{}) { - defer func() { - if err := recover(); err != nil { - if !s.Config.RecoverPanic { - // go back to panic - panic(err) - } else { - e = err - resp = nil - s.Logger.Println("Handler crashed with error", err) - for i := 1; ; i += 1 { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - s.Logger.Println(file, line) - } - } - } - }() - return function.Call(args), nil + defer func() { + if err := recover(); err != nil { + if !s.Config.RecoverPanic { + // go back to panic + panic(err) + } else { + e = err + resp = nil + s.Logger.Println("Handler crashed with error", err) + for i := 1; ; i += 1 { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + s.Logger.Println(file, line) + } + } + } + }() + return function.Call(args), nil } // requiresContext determines whether 'handlerType' contains // an argument to 'web.Ctx' as its first argument func requiresContext(handlerType reflect.Type) bool { - //if the method doesn't take arguments, no - if handlerType.NumIn() == 0 { - return false - } - - //if the first argument is not a pointer, no - a0 := handlerType.In(0) - if a0.Kind() != reflect.Ptr { - return false - } - //if the first argument is a context, yes - if a0.Elem() == contextType { - return true - } - - return false + //if the method doesn't take arguments, no + if handlerType.NumIn() == 0 { + return false + } + + //if the first argument is not a pointer, no + a0 := handlerType.In(0) + if a0.Kind() != reflect.Ptr { + return false + } + //if the first argument is a context, yes + if a0.Elem() == contextType { + return true + } + + return false } // tryServingFile attempts to serve a static file, and returns @@ -247,50 +247,50 @@ func requiresContext(handlerType reflect.Type) bool { // 2) The 'static' directory in the parent directory of the executable. // 3) The 'static' directory in the current working directory func (s *Server) tryServingFile(name string, req *http.Request, w http.ResponseWriter) bool { - //try to serve a static file - if s.Config.StaticDir != "" { - staticFile := path.Join(s.Config.StaticDir, name) - if fileExists(staticFile) { - http.ServeFile(w, req, staticFile) - return true - } - } else { - for _, staticDir := range defaultStaticDirs { - staticFile := path.Join(staticDir, name) - if fileExists(staticFile) { - http.ServeFile(w, req, staticFile) - return true - } - } - } - return false + //try to serve a static file + if s.Config.StaticDir != "" { + staticFile := path.Join(s.Config.StaticDir, name) + if fileExists(staticFile) { + http.ServeFile(w, req, staticFile) + return true + } + } else { + for _, staticDir := range defaultStaticDirs { + staticFile := path.Join(staticDir, name) + if fileExists(staticFile) { + http.ServeFile(w, req, staticFile) + return true + } + } + } + return false } func (s *Server) logRequest(ctx Context, sTime time.Time) { - //log the request - var logEntry bytes.Buffer - req := ctx.Request - requestPath := req.URL.Path + //log the request + var logEntry bytes.Buffer + req := ctx.Request + requestPath := req.URL.Path - duration := time.Now().Sub(sTime) - var client string + duration := time.Now().Sub(sTime) + var client string - // We suppose RemoteAddr is of the form Ip:Port as specified in the Request - // documentation at http://golang.org/pkg/net/http/#Request - pos := strings.LastIndex(req.RemoteAddr, ":") - if pos > 0 { - client = req.RemoteAddr[0:pos] - } else { - client = req.RemoteAddr - } + // We suppose RemoteAddr is of the form Ip:Port as specified in the Request + // documentation at http://golang.org/pkg/net/http/#Request + pos := strings.LastIndex(req.RemoteAddr, ":") + if pos > 0 { + client = req.RemoteAddr[0:pos] + } else { + client = req.RemoteAddr + } - fmt.Fprintf(&logEntry, "%s - \033[32;1m %s %s\033[0m - %v", client, req.Method, requestPath, duration) + fmt.Fprintf(&logEntry, "%s - \033[32;1m %s %s\033[0m - %v", client, req.Method, requestPath, duration) - if len(ctx.Params) > 0 { - fmt.Fprintf(&logEntry, " - \033[37;1mParams: %v\033[0m\n", ctx.Params) - } + if len(ctx.Params) > 0 { + fmt.Fprintf(&logEntry, " - \033[37;1mParams: %v\033[0m\n", ctx.Params) + } - ctx.Server.Logger.Print(logEntry.String()) + ctx.Server.Logger.Print(logEntry.String()) } @@ -301,105 +301,105 @@ func (s *Server) logRequest(ctx Context, sTime time.Time) { // route. The caller is then responsible for calling the httpHandler associated // with the returned route. func (s *Server) routeHandler(req *http.Request, w http.ResponseWriter) (unused *route) { - requestPath := req.URL.Path - ctx := Context{req, map[string]string{}, s, w} - - //set some default headers - ctx.SetHeader("Server", "web.go", true) - tm := time.Now().UTC() - - //ignore errors from ParseForm because it's usually harmless. - req.ParseForm() - if len(req.Form) > 0 { - for k, v := range req.Form { - ctx.Params[k] = v[0] - } - } - - defer s.logRequest(ctx, tm) - - ctx.SetHeader("Date", webTime(tm), true) - - if req.Method == "GET" || req.Method == "HEAD" { - if s.tryServingFile(requestPath, req, w) { - return - } - } - - //Set the default content-type - ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) - - for i := 0; i < len(s.routes); i++ { - route := s.routes[i] - cr := route.cr - //if the methods don't match, skip this handler (except HEAD can be used in place of GET) - if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") { - continue - } - - if !cr.MatchString(requestPath) { - continue - } - match := cr.FindStringSubmatch(requestPath) - - if len(match[0]) != len(requestPath) { - continue - } - - if route.httpHandler != nil { - unused = &route - // We can not handle custom http handlers here, give back to the caller. - return - } - - var args []reflect.Value - handlerType := route.handler.Type() - if requiresContext(handlerType) { - args = append(args, reflect.ValueOf(&ctx)) - } - for _, arg := range match[1:] { - args = append(args, reflect.ValueOf(arg)) - } - - ret, err := s.safelyCall(route.handler, args) - if err != nil { - //there was an error or panic while calling the handler - ctx.Abort(500, "Server Error") - } - if len(ret) == 0 { - return - } - - sval := ret[0] - - var content []byte - - if sval.Kind() == reflect.String { - content = []byte(sval.String()) - } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 { - content = sval.Interface().([]byte) - } - ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true) - _, err = ctx.ResponseWriter.Write(content) - if err != nil { - ctx.Server.Logger.Println("Error during write: ", err) - } - return - } - - // try serving index.html or index.htm - if req.Method == "GET" || req.Method == "HEAD" { - if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) { - return - } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) { - return - } - } - ctx.Abort(404, "Page not found") - return + requestPath := req.URL.Path + ctx := Context{req, map[string]string{}, s, w} + + //set some default headers + ctx.SetHeader("Server", "web.go", true) + tm := time.Now().UTC() + + //ignore errors from ParseForm because it's usually harmless. + req.ParseForm() + if len(req.Form) > 0 { + for k, v := range req.Form { + ctx.Params[k] = v[0] + } + } + + defer s.logRequest(ctx, tm) + + ctx.SetHeader("Date", webTime(tm), true) + + if req.Method == "GET" || req.Method == "HEAD" { + if s.tryServingFile(requestPath, req, w) { + return + } + } + + //Set the default content-type + ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) + + for i := 0; i < len(s.routes); i++ { + route := s.routes[i] + cr := route.cr + //if the methods don't match, skip this handler (except HEAD can be used in place of GET) + if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") { + continue + } + + if !cr.MatchString(requestPath) { + continue + } + match := cr.FindStringSubmatch(requestPath) + + if len(match[0]) != len(requestPath) { + continue + } + + if route.httpHandler != nil { + unused = &route + // We can not handle custom http handlers here, give back to the caller. + return + } + + var args []reflect.Value + handlerType := route.handler.Type() + if requiresContext(handlerType) { + args = append(args, reflect.ValueOf(&ctx)) + } + for _, arg := range match[1:] { + args = append(args, reflect.ValueOf(arg)) + } + + ret, err := s.safelyCall(route.handler, args) + if err != nil { + //there was an error or panic while calling the handler + ctx.Abort(500, "Server Error") + } + if len(ret) == 0 { + return + } + + sval := ret[0] + + var content []byte + + if sval.Kind() == reflect.String { + content = []byte(sval.String()) + } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 { + content = sval.Interface().([]byte) + } + ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true) + _, err = ctx.ResponseWriter.Write(content) + if err != nil { + ctx.Server.Logger.Println("Error during write: ", err) + } + return + } + + // try serving index.html or index.htm + if req.Method == "GET" || req.Method == "HEAD" { + if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) { + return + } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) { + return + } + } + ctx.Abort(404, "Page not found") + return } // SetLogger sets the logger for server s func (s *Server) SetLogger(logger *log.Logger) { - s.Logger = logger + s.Logger = logger } diff --git a/status.go b/status.go index 83053cce..e0d995f9 100644 --- a/status.go +++ b/status.go @@ -7,48 +7,48 @@ package web import "net/http" var statusText = map[int]string{ - http.StatusContinue: "Continue", - http.StatusSwitchingProtocols: "Switching Protocols", + http.StatusContinue: "Continue", + http.StatusSwitchingProtocols: "Switching Protocols", - http.StatusOK: "OK", - http.StatusCreated: "Created", - http.StatusAccepted: "Accepted", - http.StatusNonAuthoritativeInfo: "Non-Authoritative Information", - http.StatusNoContent: "No Content", - http.StatusResetContent: "Reset Content", - http.StatusPartialContent: "Partial Content", + http.StatusOK: "OK", + http.StatusCreated: "Created", + http.StatusAccepted: "Accepted", + http.StatusNonAuthoritativeInfo: "Non-Authoritative Information", + http.StatusNoContent: "No Content", + http.StatusResetContent: "Reset Content", + http.StatusPartialContent: "Partial Content", - http.StatusMultipleChoices: "Multiple Choices", - http.StatusMovedPermanently: "Moved Permanently", - http.StatusFound: "Found", - http.StatusSeeOther: "See Other", - http.StatusNotModified: "Not Modified", - http.StatusUseProxy: "Use Proxy", - http.StatusTemporaryRedirect: "Temporary Redirect", + http.StatusMultipleChoices: "Multiple Choices", + http.StatusMovedPermanently: "Moved Permanently", + http.StatusFound: "Found", + http.StatusSeeOther: "See Other", + http.StatusNotModified: "Not Modified", + http.StatusUseProxy: "Use Proxy", + http.StatusTemporaryRedirect: "Temporary Redirect", - http.StatusBadRequest: "Bad Request", - http.StatusUnauthorized: "Unauthorized", - http.StatusPaymentRequired: "Payment Required", - http.StatusForbidden: "Forbidden", - http.StatusNotFound: "Not Found", - http.StatusMethodNotAllowed: "Method Not Allowed", - http.StatusNotAcceptable: "Not Acceptable", - http.StatusProxyAuthRequired: "Proxy Authentication Required", - http.StatusRequestTimeout: "Request Timeout", - http.StatusConflict: "Conflict", - http.StatusGone: "Gone", - http.StatusLengthRequired: "Length Required", - http.StatusPreconditionFailed: "Precondition Failed", - http.StatusRequestEntityTooLarge: "Request Entity Too Large", - http.StatusRequestURITooLong: "Request URI Too Long", - http.StatusUnsupportedMediaType: "Unsupported Media Type", - http.StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", - http.StatusExpectationFailed: "Expectation Failed", + http.StatusBadRequest: "Bad Request", + http.StatusUnauthorized: "Unauthorized", + http.StatusPaymentRequired: "Payment Required", + http.StatusForbidden: "Forbidden", + http.StatusNotFound: "Not Found", + http.StatusMethodNotAllowed: "Method Not Allowed", + http.StatusNotAcceptable: "Not Acceptable", + http.StatusProxyAuthRequired: "Proxy Authentication Required", + http.StatusRequestTimeout: "Request Timeout", + http.StatusConflict: "Conflict", + http.StatusGone: "Gone", + http.StatusLengthRequired: "Length Required", + http.StatusPreconditionFailed: "Precondition Failed", + http.StatusRequestEntityTooLarge: "Request Entity Too Large", + http.StatusRequestURITooLong: "Request URI Too Long", + http.StatusUnsupportedMediaType: "Unsupported Media Type", + http.StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", + http.StatusExpectationFailed: "Expectation Failed", - http.StatusInternalServerError: "Internal Server Error", - http.StatusNotImplemented: "Not Implemented", - http.StatusBadGateway: "Bad Gateway", - http.StatusServiceUnavailable: "Service Unavailable", - http.StatusGatewayTimeout: "Gateway Timeout", - http.StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + http.StatusInternalServerError: "Internal Server Error", + http.StatusNotImplemented: "Not Implemented", + http.StatusBadGateway: "Bad Gateway", + http.StatusServiceUnavailable: "Service Unavailable", + http.StatusGatewayTimeout: "Gateway Timeout", + http.StatusHTTPVersionNotSupported: "HTTP Version Not Supported", } diff --git a/web.go b/web.go index acad9955..1c597a47 100644 --- a/web.go +++ b/web.go @@ -3,23 +3,23 @@ package web import ( - "bytes" - "golang.org/x/net/websocket" - "crypto/hmac" - "crypto/sha1" - "crypto/tls" - "encoding/base64" - "fmt" - "io/ioutil" - "log" - "mime" - "net/http" - "os" - "path" - "reflect" - "strconv" - "strings" - "time" + "bytes" + "crypto/hmac" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "fmt" + "golang.org/x/net/websocket" + "io/ioutil" + "log" + "mime" + "net/http" + "os" + "path" + "reflect" + "strconv" + "strings" + "time" ) // A Context object is created for every incoming HTTP request, and is @@ -27,15 +27,15 @@ import ( // about the request, including the http.Request object, the GET and POST params, // and acts as a Writer for the response. type Context struct { - Request *http.Request - Params map[string]string - Server *Server - http.ResponseWriter + Request *http.Request + Params map[string]string + Server *Server + http.ResponseWriter } // WriteString writes string data into the response object. func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) + ctx.ResponseWriter.Write([]byte(content)) } // Abort is a helper method that sends an HTTP header and an optional @@ -43,36 +43,36 @@ func (ctx *Context) WriteString(content string) { // Once it has been called, any return value from the handler will // not be written to the response. func (ctx *Context) Abort(status int, body string) { - ctx.ResponseWriter.WriteHeader(status) - ctx.ResponseWriter.Write([]byte(body)) + ctx.ResponseWriter.WriteHeader(status) + ctx.ResponseWriter.Write([]byte(body)) } // Redirect is a helper method for 3xx redirects. func (ctx *Context) Redirect(status int, url_ string) { - ctx.ResponseWriter.Header().Set("Location", url_) - ctx.ResponseWriter.WriteHeader(status) - ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_)) + ctx.ResponseWriter.Header().Set("Location", url_) + ctx.ResponseWriter.WriteHeader(status) + ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_)) } // Notmodified writes a 304 HTTP response func (ctx *Context) NotModified() { - ctx.ResponseWriter.WriteHeader(304) + ctx.ResponseWriter.WriteHeader(304) } // NotFound writes a 404 HTTP response func (ctx *Context) NotFound(message string) { - ctx.ResponseWriter.WriteHeader(404) - ctx.ResponseWriter.Write([]byte(message)) + ctx.ResponseWriter.WriteHeader(404) + ctx.ResponseWriter.Write([]byte(message)) } //Unauthorized writes a 401 HTTP response func (ctx *Context) Unauthorized() { - ctx.ResponseWriter.WriteHeader(401) + ctx.ResponseWriter.WriteHeader(401) } //Forbidden writes a 403 HTTP response func (ctx *Context) Forbidden() { - ctx.ResponseWriter.WriteHeader(403) + ctx.ResponseWriter.WriteHeader(403) } // ContentType sets the Content-Type header for an HTTP response. @@ -81,92 +81,92 @@ func (ctx *Context) Forbidden() { // verbatim. The return value is the content type as it was // set, or an empty string if none was found. func (ctx *Context) ContentType(val string) string { - var ctype string - if strings.ContainsRune(val, '/') { - ctype = val - } else { - if !strings.HasPrefix(val, ".") { - val = "." + val - } - ctype = mime.TypeByExtension(val) - } - if ctype != "" { - ctx.Header().Set("Content-Type", ctype) - } - return ctype + var ctype string + if strings.ContainsRune(val, '/') { + ctype = val + } else { + if !strings.HasPrefix(val, ".") { + val = "." + val + } + ctype = mime.TypeByExtension(val) + } + if ctype != "" { + ctx.Header().Set("Content-Type", ctype) + } + return ctype } // SetHeader sets a response header. If `unique` is true, the current value // of that header will be overwritten . If false, it will be appended. func (ctx *Context) SetHeader(hdr string, val string, unique bool) { - if unique { - ctx.Header().Set(hdr, val) - } else { - ctx.Header().Add(hdr, val) - } + if unique { + ctx.Header().Set(hdr, val) + } else { + ctx.Header().Add(hdr, val) + } } // SetCookie adds a cookie header to the response. func (ctx *Context) SetCookie(cookie *http.Cookie) { - ctx.SetHeader("Set-Cookie", cookie.String(), false) + ctx.SetHeader("Set-Cookie", cookie.String(), false) } func getCookieSig(key string, val []byte, timestamp string) string { - hm := hmac.New(sha1.New, []byte(key)) + hm := hmac.New(sha1.New, []byte(key)) - hm.Write(val) - hm.Write([]byte(timestamp)) + hm.Write(val) + hm.Write([]byte(timestamp)) - return fmt.Sprintf("%02x", hm.Sum(nil)) + return fmt.Sprintf("%02x", hm.Sum(nil)) } func (ctx *Context) SetSecureCookie(name string, val string, age int64) { - //base64 encode the val - if len(ctx.Server.Config.CookieSecret) == 0 { - ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.") - return - } - var buf bytes.Buffer - encoder := base64.NewEncoder(base64.StdEncoding, &buf) - encoder.Write([]byte(val)) - encoder.Close() - vs := buf.String() - vb := buf.Bytes() - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - sig := getCookieSig(ctx.Server.Config.CookieSecret, vb, timestamp) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - ctx.SetCookie(NewCookie(name, cookie, age)) + //base64 encode the val + if len(ctx.Server.Config.CookieSecret) == 0 { + ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.") + return + } + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + encoder.Write([]byte(val)) + encoder.Close() + vs := buf.String() + vb := buf.Bytes() + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + sig := getCookieSig(ctx.Server.Config.CookieSecret, vb, timestamp) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.SetCookie(NewCookie(name, cookie, age)) } func (ctx *Context) GetSecureCookie(name string) (string, bool) { - for _, cookie := range ctx.Request.Cookies() { - if cookie.Name != name { - continue - } + for _, cookie := range ctx.Request.Cookies() { + if cookie.Name != name { + continue + } - parts := strings.SplitN(cookie.Value, "|", 3) + parts := strings.SplitN(cookie.Value, "|", 3) - val := parts[0] - timestamp := parts[1] - sig := parts[2] + val := parts[0] + timestamp := parts[1] + sig := parts[2] - if getCookieSig(ctx.Server.Config.CookieSecret, []byte(val), timestamp) != sig { - return "", false - } + if getCookieSig(ctx.Server.Config.CookieSecret, []byte(val), timestamp) != sig { + return "", false + } - ts, _ := strconv.ParseInt(timestamp, 0, 64) + ts, _ := strconv.ParseInt(timestamp, 0, 64) - if time.Now().Unix()-31*86400 > ts { - return "", false - } + if time.Now().Unix()-31*86400 > ts { + return "", false + } - buf := bytes.NewBufferString(val) - encoder := base64.NewDecoder(base64.StdEncoding, buf) + buf := bytes.NewBufferString(val) + encoder := base64.NewDecoder(base64.StdEncoding, buf) - res, _ := ioutil.ReadAll(encoder) - return string(res), true - } - return "", false + res, _ := ioutil.ReadAll(encoder) + return string(res), true + } + return "", false } // small optimization: cache the context type instead of repeteadly calling reflect.Typeof @@ -175,96 +175,96 @@ var contextType reflect.Type var defaultStaticDirs []string func init() { - contextType = reflect.TypeOf(Context{}) - //find the location of the exe file - wd, _ := os.Getwd() - arg0 := path.Clean(os.Args[0]) - var exeFile string - if strings.HasPrefix(arg0, "/") { - exeFile = arg0 - } else { - //TODO for robustness, search each directory in $PATH - exeFile = path.Join(wd, arg0) - } - parent, _ := path.Split(exeFile) - defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static")) - defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static")) - return + contextType = reflect.TypeOf(Context{}) + //find the location of the exe file + wd, _ := os.Getwd() + arg0 := path.Clean(os.Args[0]) + var exeFile string + if strings.HasPrefix(arg0, "/") { + exeFile = arg0 + } else { + //TODO for robustness, search each directory in $PATH + exeFile = path.Join(wd, arg0) + } + parent, _ := path.Split(exeFile) + defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static")) + defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static")) + return } // Process invokes the main server's routing system. func Process(c http.ResponseWriter, req *http.Request) { - mainServer.Process(c, req) + mainServer.Process(c, req) } // Run starts the web application and serves HTTP requests for the main server. func Run(addr string) { - mainServer.Run(addr) + mainServer.Run(addr) } // RunTLS starts the web application and serves HTTPS requests for the main server. func RunTLS(addr string, config *tls.Config) { - mainServer.RunTLS(addr, config) + mainServer.RunTLS(addr, config) } // RunScgi starts the web application and serves SCGI requests for the main server. func RunScgi(addr string) { - mainServer.RunScgi(addr) + mainServer.RunScgi(addr) } // RunFcgi starts the web application and serves FastCGI requests for the main server. func RunFcgi(addr string) { - mainServer.RunFcgi(addr) + mainServer.RunFcgi(addr) } // Close stops the main server. func Close() { - mainServer.Close() + mainServer.Close() } // Get adds a handler for the 'GET' http method in the main server. func Get(route string, handler interface{}) { - mainServer.Get(route, handler) + mainServer.Get(route, handler) } // Post adds a handler for the 'POST' http method in the main server. func Post(route string, handler interface{}) { - mainServer.addRoute(route, "POST", handler) + mainServer.addRoute(route, "POST", handler) } // Put adds a handler for the 'PUT' http method in the main server. func Put(route string, handler interface{}) { - mainServer.addRoute(route, "PUT", handler) + mainServer.addRoute(route, "PUT", handler) } // Delete adds a handler for the 'DELETE' http method in the main server. func Delete(route string, handler interface{}) { - mainServer.addRoute(route, "DELETE", handler) + mainServer.addRoute(route, "DELETE", handler) } // Match adds a handler for an arbitrary http method in the main server. func Match(method string, route string, handler interface{}) { - mainServer.addRoute(route, method, handler) + mainServer.addRoute(route, method, handler) } //Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func Handler(route string, method string, httpHandler http.Handler) { - mainServer.Handler(route, method, httpHandler) + mainServer.Handler(route, method, httpHandler) } //Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func Websocket(route string, httpHandler websocket.Handler) { - mainServer.Websocket(route, httpHandler) + mainServer.Websocket(route, httpHandler) } // SetLogger sets the logger for the main server. func SetLogger(logger *log.Logger) { - mainServer.Logger = logger + mainServer.Logger = logger } // Config is the configuration of the main server. var Config = &ServerConfig{ - RecoverPanic: true, + RecoverPanic: true, } var mainServer = NewServer() diff --git a/web_test.go b/web_test.go index 9f6f1b65..2c5f069a 100644 --- a/web_test.go +++ b/web_test.go @@ -1,580 +1,580 @@ package web import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "net/http" - "net/url" - "runtime" - "strconv" - "strings" - "testing" + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "testing" ) func init() { - runtime.GOMAXPROCS(runtime.NumCPU()) + runtime.GOMAXPROCS(runtime.NumCPU()) } // ioBuffer is a helper that implements io.ReadWriteCloser, // which is helpful in imitating a net.Conn type ioBuffer struct { - input *bytes.Buffer - output *bytes.Buffer - closed bool + input *bytes.Buffer + output *bytes.Buffer + closed bool } func (buf *ioBuffer) Write(p []uint8) (n int, err error) { - if buf.closed { - return 0, errors.New("Write after Close on ioBuffer") - } - return buf.output.Write(p) + if buf.closed { + return 0, errors.New("Write after Close on ioBuffer") + } + return buf.output.Write(p) } func (buf *ioBuffer) Read(p []byte) (n int, err error) { - if buf.closed { - return 0, errors.New("Read after Close on ioBuffer") - } - return buf.input.Read(p) + if buf.closed { + return 0, errors.New("Read after Close on ioBuffer") + } + return buf.input.Read(p) } //noop func (buf *ioBuffer) Close() error { - buf.closed = true - return nil + buf.closed = true + return nil } type testResponse struct { - statusCode int - status string - body string - headers map[string][]string - cookies map[string]string + statusCode int + status string + body string + headers map[string][]string + cookies map[string]string } func buildTestResponse(buf *bytes.Buffer) *testResponse { - response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)} - s := buf.String() - contents := strings.SplitN(s, "\r\n\r\n", 2) + response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)} + s := buf.String() + contents := strings.SplitN(s, "\r\n\r\n", 2) - header := contents[0] + header := contents[0] - if len(contents) > 1 { - response.body = contents[1] - } + if len(contents) > 1 { + response.body = contents[1] + } - headers := strings.Split(header, "\r\n") + headers := strings.Split(header, "\r\n") - statusParts := strings.SplitN(headers[0], " ", 3) - response.statusCode, _ = strconv.Atoi(statusParts[1]) + statusParts := strings.SplitN(headers[0], " ", 3) + response.statusCode, _ = strconv.Atoi(statusParts[1]) - for _, h := range headers[1:] { - split := strings.SplitN(h, ":", 2) - name := strings.TrimSpace(split[0]) - value := strings.TrimSpace(split[1]) - if _, ok := response.headers[name]; !ok { - response.headers[name] = []string{} - } + for _, h := range headers[1:] { + split := strings.SplitN(h, ":", 2) + name := strings.TrimSpace(split[0]) + value := strings.TrimSpace(split[1]) + if _, ok := response.headers[name]; !ok { + response.headers[name] = []string{} + } - newheaders := make([]string, len(response.headers[name])+1) - copy(newheaders, response.headers[name]) - newheaders[len(newheaders)-1] = value - response.headers[name] = newheaders + newheaders := make([]string, len(response.headers[name])+1) + copy(newheaders, response.headers[name]) + newheaders[len(newheaders)-1] = value + response.headers[name] = newheaders - //if the header is a cookie, set it - if name == "Set-Cookie" { - i := strings.Index(value, ";") - cookie := value[0:i] - cookieParts := strings.SplitN(cookie, "=", 2) - response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1]) - } - } + //if the header is a cookie, set it + if name == "Set-Cookie" { + i := strings.Index(value, ";") + cookie := value[0:i] + cookieParts := strings.SplitN(cookie, "=", 2) + response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1]) + } + } - return &response + return &response } func getTestResponse(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *testResponse { - req := buildTestRequest(method, path, body, headers, cookies) - var buf bytes.Buffer + req := buildTestRequest(method, path, body, headers, cookies) + var buf bytes.Buffer - tcpb := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb} - mainServer.Process(&c, req) - return buildTestResponse(&buf) + tcpb := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb} + mainServer.Process(&c, req) + return buildTestResponse(&buf) } func testGet(path string, headers map[string]string) *testResponse { - var header http.Header - for k, v := range headers { - header.Set(k, v) - } - return getTestResponse("GET", path, "", header, nil) + var header http.Header + for k, v := range headers { + header.Set(k, v) + } + return getTestResponse("GET", path, "", header, nil) } type Test struct { - method string - path string - headers map[string][]string - body string - expectedStatus int - expectedBody string + method string + path string + headers map[string][]string + body string + expectedStatus int + expectedBody string } //initialize the routes func init() { - mainServer.SetLogger(log.New(ioutil.Discard, "", 0)) - Get("/", func() string { return "index" }) - Get("/panic", func() { panic(0) }) - Get("/echo/(.*)", func(s string) string { return s }) - Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d }) - Post("/post/echo/(.*)", func(s string) string { return s }) - Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] }) - - Get("/error/code/(.*)", func(ctx *Context, code string) string { - n, _ := strconv.Atoi(code) - message := statusText[n] - ctx.Abort(n, message) - return "" - }) - - Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) - - Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) - Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) - - Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) - Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) - - Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string { - n, _ := strconv.Atoi(code) - ctx.Abort(n, message) - return "" - }) - - Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") }) - - Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string { - ctx.SetSecureCookie(name, val, 60) - return "" - }) - - Get("/securecookie/get/(.+)", func(ctx *Context, name string) string { - val, ok := ctx.GetSecureCookie(name) - if !ok { - return "" - } - return val - }) - Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] }) - Get("/fullparams", func(ctx *Context) string { - return strings.Join(ctx.Request.Form["a"], ",") - }) - - Get("/json", func(ctx *Context) string { - ctx.ContentType("json") - data, _ := json.Marshal(ctx.Params) - return string(data) - }) - - Get("/jsonbytes", func(ctx *Context) []byte { - ctx.ContentType("json") - data, _ := json.Marshal(ctx.Params) - return data - }) - - Post("/parsejson", func(ctx *Context) string { - var tmp = struct { - A string - B string - }{} - json.NewDecoder(ctx.Request.Body).Decode(&tmp) - return tmp.A + " " + tmp.B - }) - - Match("OPTIONS", "/options", func(ctx *Context) { - ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true) - ctx.SetHeader("Access-Control-Max-Age", "1000", true) - ctx.WriteHeader(200) - }) - - Get("/dupeheader", func(ctx *Context) string { - ctx.SetHeader("Server", "myserver", true) - return "" - }) - - Get("/authorization", func(ctx *Context) string { - user, pass, err := ctx.GetBasicAuth() - if err != nil { - return "fail" - } - return user + pass - }) + mainServer.SetLogger(log.New(ioutil.Discard, "", 0)) + Get("/", func() string { return "index" }) + Get("/panic", func() { panic(0) }) + Get("/echo/(.*)", func(s string) string { return s }) + Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d }) + Post("/post/echo/(.*)", func(s string) string { return s }) + Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] }) + + Get("/error/code/(.*)", func(ctx *Context, code string) string { + n, _ := strconv.Atoi(code) + message := statusText[n] + ctx.Abort(n, message) + return "" + }) + + Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) + + Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) + Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) + + Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) + Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) + + Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string { + n, _ := strconv.Atoi(code) + ctx.Abort(n, message) + return "" + }) + + Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") }) + + Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string { + ctx.SetSecureCookie(name, val, 60) + return "" + }) + + Get("/securecookie/get/(.+)", func(ctx *Context, name string) string { + val, ok := ctx.GetSecureCookie(name) + if !ok { + return "" + } + return val + }) + Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] }) + Get("/fullparams", func(ctx *Context) string { + return strings.Join(ctx.Request.Form["a"], ",") + }) + + Get("/json", func(ctx *Context) string { + ctx.ContentType("json") + data, _ := json.Marshal(ctx.Params) + return string(data) + }) + + Get("/jsonbytes", func(ctx *Context) []byte { + ctx.ContentType("json") + data, _ := json.Marshal(ctx.Params) + return data + }) + + Post("/parsejson", func(ctx *Context) string { + var tmp = struct { + A string + B string + }{} + json.NewDecoder(ctx.Request.Body).Decode(&tmp) + return tmp.A + " " + tmp.B + }) + + Match("OPTIONS", "/options", func(ctx *Context) { + ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true) + ctx.SetHeader("Access-Control-Max-Age", "1000", true) + ctx.WriteHeader(200) + }) + + Get("/dupeheader", func(ctx *Context) string { + ctx.SetHeader("Server", "myserver", true) + return "" + }) + + Get("/authorization", func(ctx *Context) string { + user, pass, err := ctx.GetBasicAuth() + if err != nil { + return "fail" + } + return user + pass + }) } var tests = []Test{ - {"GET", "/", nil, "", 200, "index"}, - {"GET", "/echo/hello", nil, "", 200, "hello"}, - {"GET", "/echo/hello", nil, "", 200, "hello"}, - {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"}, - {"POST", "/post/echo/hello", nil, "", 200, "hello"}, - {"POST", "/post/echo/hello", nil, "", 200, "hello"}, - {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"}, - {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"}, - {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"}, - //long url - {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, - {"GET", "/writetest", nil, "", 200, "hello"}, - {"GET", "/error/unauthorized", nil, "", 401, ""}, - {"POST", "/error/unauthorized", nil, "", 401, ""}, - {"GET", "/error/forbidden", nil, "", 403, ""}, - {"POST", "/error/forbidden", nil, "", 403, ""}, - {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, - {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, - {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, - {"GET", "/error/code/500", nil, "", 500, statusText[500]}, - {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, - {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, - {"GET", "/getparam?b=abcd", nil, "", 200, ""}, - {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"}, - {"GET", "/panic", nil, "", 500, "Server Error"}, - {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, - {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, - {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, - //{"GET", "/testenv", "", 200, "hello world"}, - {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, + {"GET", "/", nil, "", 200, "index"}, + {"GET", "/echo/hello", nil, "", 200, "hello"}, + {"GET", "/echo/hello", nil, "", 200, "hello"}, + {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"}, + {"POST", "/post/echo/hello", nil, "", 200, "hello"}, + {"POST", "/post/echo/hello", nil, "", 200, "hello"}, + {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"}, + {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"}, + {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"}, + //long url + {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, + {"GET", "/writetest", nil, "", 200, "hello"}, + {"GET", "/error/unauthorized", nil, "", 401, ""}, + {"POST", "/error/unauthorized", nil, "", 401, ""}, + {"GET", "/error/forbidden", nil, "", 403, ""}, + {"POST", "/error/forbidden", nil, "", 403, ""}, + {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, + {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, + {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, + {"GET", "/error/code/500", nil, "", 500, statusText[500]}, + {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, + {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, + {"GET", "/getparam?b=abcd", nil, "", 200, ""}, + {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"}, + {"GET", "/panic", nil, "", 500, "Server Error"}, + {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, + {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, + {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, + //{"GET", "/testenv", "", 200, "hello world"}, + {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, } func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request { - host := "127.0.0.1" - port := "80" - rawurl := "http://" + host + ":" + port + path - url_, _ := url.Parse(rawurl) - proto := "HTTP/1.1" - - if headers == nil { - headers = map[string][]string{} - } - - headers["User-Agent"] = []string{"web.go test"} - if method == "POST" { - headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))} - if headers["Content-Type"] == nil { - headers["Content-Type"] = []string{"text/plain"} - } - } - - req := http.Request{Method: method, - URL: url_, - Proto: proto, - Host: host, - Header: http.Header(headers), - Body: ioutil.NopCloser(bytes.NewBufferString(body)), - } - - for _, cookie := range cookies { - req.AddCookie(cookie) - } - return &req + host := "127.0.0.1" + port := "80" + rawurl := "http://" + host + ":" + port + path + url_, _ := url.Parse(rawurl) + proto := "HTTP/1.1" + + if headers == nil { + headers = map[string][]string{} + } + + headers["User-Agent"] = []string{"web.go test"} + if method == "POST" { + headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))} + if headers["Content-Type"] == nil { + headers["Content-Type"] = []string{"text/plain"} + } + } + + req := http.Request{Method: method, + URL: url_, + Proto: proto, + Host: host, + Header: http.Header(headers), + Body: ioutil.NopCloser(bytes.NewBufferString(body)), + } + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + return &req } func TestRouting(t *testing.T) { - for _, test := range tests { - resp := getTestResponse(test.method, test.path, test.body, test.headers, nil) - - if resp.statusCode != test.expectedStatus { - t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode) - } - if resp.body != test.expectedBody { - t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body) - } - if cl, ok := resp.headers["Content-Length"]; ok { - clExp, _ := strconv.Atoi(cl[0]) - clAct := len(resp.body) - if clExp != clAct { - t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct) - } - } - } + for _, test := range tests { + resp := getTestResponse(test.method, test.path, test.body, test.headers, nil) + + if resp.statusCode != test.expectedStatus { + t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode) + } + if resp.body != test.expectedBody { + t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body) + } + if cl, ok := resp.headers["Content-Length"]; ok { + clExp, _ := strconv.Atoi(cl[0]) + clAct := len(resp.body) + if clExp != clAct { + t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct) + } + } + } } func TestHead(t *testing.T) { - for _, test := range tests { - - if test.method != "GET" { - continue - } - getresp := getTestResponse("GET", test.path, test.body, test.headers, nil) - headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil) - - if getresp.statusCode != headresp.statusCode { - t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) - } - if len(headresp.body) != 0 { - t.Fatalf("head request arrived with a body") - } - - var cl []string - var getcl, headcl int - var hascl1, hascl2 bool - - if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { - getcl, _ = strconv.Atoi(cl[0]) - } - - if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { - headcl, _ = strconv.Atoi(cl[0]) - } - - if hascl1 != hascl2 { - t.Fatalf("head and get: one has content-length, one doesn't") - } - - if hascl1 == true && getcl != headcl { - t.Fatalf("head and get content-length differ") - } - } + for _, test := range tests { + + if test.method != "GET" { + continue + } + getresp := getTestResponse("GET", test.path, test.body, test.headers, nil) + headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil) + + if getresp.statusCode != headresp.statusCode { + t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) + } + if len(headresp.body) != 0 { + t.Fatalf("head request arrived with a body") + } + + var cl []string + var getcl, headcl int + var hascl1, hascl2 bool + + if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { + getcl, _ = strconv.Atoi(cl[0]) + } + + if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { + headcl, _ = strconv.Atoi(cl[0]) + } + + if hascl1 != hascl2 { + t.Fatalf("head and get: one has content-length, one doesn't") + } + + if hascl1 == true && getcl != headcl { + t.Fatalf("head and get content-length differ") + } + } } func buildTestScgiRequest(method string, path string, body string, headers map[string][]string) *bytes.Buffer { - var headerBuf bytes.Buffer - scgiHeaders := make(map[string]string) - - headerBuf.WriteString("CONTENT_LENGTH") - headerBuf.WriteByte(0) - headerBuf.WriteString(fmt.Sprintf("%d", len(body))) - headerBuf.WriteByte(0) - - scgiHeaders["REQUEST_METHOD"] = method - scgiHeaders["HTTP_HOST"] = "127.0.0.1" - scgiHeaders["REQUEST_URI"] = path - scgiHeaders["SERVER_PORT"] = "80" - scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1" - scgiHeaders["USER_AGENT"] = "web.go test framework" - - for k, v := range headers { - //Skip content-length - if k == "Content-Length" { - continue - } - key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1)) - scgiHeaders[key] = v[0] - } - for k, v := range scgiHeaders { - headerBuf.WriteString(k) - headerBuf.WriteByte(0) - headerBuf.WriteString(v) - headerBuf.WriteByte(0) - } - headerData := headerBuf.Bytes() - - var buf bytes.Buffer - //extra 1 is for the comma at the end - dlen := len(headerData) - fmt.Fprintf(&buf, "%d:", dlen) - buf.Write(headerData) - buf.WriteByte(',') - buf.WriteString(body) - return &buf + var headerBuf bytes.Buffer + scgiHeaders := make(map[string]string) + + headerBuf.WriteString("CONTENT_LENGTH") + headerBuf.WriteByte(0) + headerBuf.WriteString(fmt.Sprintf("%d", len(body))) + headerBuf.WriteByte(0) + + scgiHeaders["REQUEST_METHOD"] = method + scgiHeaders["HTTP_HOST"] = "127.0.0.1" + scgiHeaders["REQUEST_URI"] = path + scgiHeaders["SERVER_PORT"] = "80" + scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1" + scgiHeaders["USER_AGENT"] = "web.go test framework" + + for k, v := range headers { + //Skip content-length + if k == "Content-Length" { + continue + } + key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1)) + scgiHeaders[key] = v[0] + } + for k, v := range scgiHeaders { + headerBuf.WriteString(k) + headerBuf.WriteByte(0) + headerBuf.WriteString(v) + headerBuf.WriteByte(0) + } + headerData := headerBuf.Bytes() + + var buf bytes.Buffer + //extra 1 is for the comma at the end + dlen := len(headerData) + fmt.Fprintf(&buf, "%d:", dlen) + buf.Write(headerData) + buf.WriteByte(',') + buf.WriteString(body) + return &buf } func TestScgi(t *testing.T) { - for _, test := range tests { - req := buildTestScgiRequest(test.method, test.path, test.body, test.headers) - var output bytes.Buffer - nb := ioBuffer{input: req, output: &output} - mainServer.handleScgiRequest(&nb) - resp := buildTestResponse(&output) - - if resp.statusCode != test.expectedStatus { - t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode) - } - - if resp.body != test.expectedBody { - t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body) - } - } + for _, test := range tests { + req := buildTestScgiRequest(test.method, test.path, test.body, test.headers) + var output bytes.Buffer + nb := ioBuffer{input: req, output: &output} + mainServer.handleScgiRequest(&nb) + resp := buildTestResponse(&output) + + if resp.statusCode != test.expectedStatus { + t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode) + } + + if resp.body != test.expectedBody { + t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body) + } + } } func TestScgiHead(t *testing.T) { - for _, test := range tests { - - if test.method != "GET" { - continue - } - - req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string)) - var output bytes.Buffer - nb := ioBuffer{input: req, output: &output} - mainServer.handleScgiRequest(&nb) - getresp := buildTestResponse(&output) - - req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string)) - var output2 bytes.Buffer - nb = ioBuffer{input: req, output: &output2} - mainServer.handleScgiRequest(&nb) - headresp := buildTestResponse(&output2) - - if getresp.statusCode != headresp.statusCode { - t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) - } - if len(headresp.body) != 0 { - t.Fatalf("head request arrived with a body") - } - - var cl []string - var getcl, headcl int - var hascl1, hascl2 bool - - if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { - getcl, _ = strconv.Atoi(cl[0]) - } - - if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { - headcl, _ = strconv.Atoi(cl[0]) - } - - if hascl1 != hascl2 { - t.Fatalf("head and get: one has content-length, one doesn't") - } - - if hascl1 == true && getcl != headcl { - t.Fatalf("head and get content-length differ") - } - } + for _, test := range tests { + + if test.method != "GET" { + continue + } + + req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string)) + var output bytes.Buffer + nb := ioBuffer{input: req, output: &output} + mainServer.handleScgiRequest(&nb) + getresp := buildTestResponse(&output) + + req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string)) + var output2 bytes.Buffer + nb = ioBuffer{input: req, output: &output2} + mainServer.handleScgiRequest(&nb) + headresp := buildTestResponse(&output2) + + if getresp.statusCode != headresp.statusCode { + t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) + } + if len(headresp.body) != 0 { + t.Fatalf("head request arrived with a body") + } + + var cl []string + var getcl, headcl int + var hascl1, hascl2 bool + + if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { + getcl, _ = strconv.Atoi(cl[0]) + } + + if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { + headcl, _ = strconv.Atoi(cl[0]) + } + + if hascl1 != hascl2 { + t.Fatalf("head and get: one has content-length, one doesn't") + } + + if hascl1 == true && getcl != headcl { + t.Fatalf("head and get content-length differ") + } + } } func TestReadScgiRequest(t *testing.T) { - headers := map[string][]string{"User-Agent": {"web.go"}} - req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers) - var s Server - httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil}) - if err != nil { - t.Fatalf("Error while reading SCGI request: ", err.Error()) - } - if httpReq.ContentLength != 12 { - t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength) - } - var body bytes.Buffer - io.Copy(&body, httpReq.Body) - if body.String() != "Hello world!" { - t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String()) - } + headers := map[string][]string{"User-Agent": {"web.go"}} + req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers) + var s Server + httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil}) + if err != nil { + t.Fatalf("Error while reading SCGI request: ", err.Error()) + } + if httpReq.ContentLength != 12 { + t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength) + } + var body bytes.Buffer + io.Copy(&body, httpReq.Body) + if body.String() != "Hello world!" { + t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String()) + } } func makeCookie(vals map[string]string) []*http.Cookie { - var cookies []*http.Cookie - for k, v := range vals { - c := &http.Cookie{ - Name: k, - Value: v, - } - cookies = append(cookies, c) - } - return cookies + var cookies []*http.Cookie + for k, v := range vals { + c := &http.Cookie{ + Name: k, + Value: v, + } + cookies = append(cookies, c) + } + return cookies } func TestSecureCookie(t *testing.T) { - mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" - resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) - sval, ok := resp1.cookies["a"] - if !ok { - t.Fatalf("Failed to get cookie ") - } - cookies := makeCookie(map[string]string{"a": sval}) - - resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies) - - if resp2.body != "1" { - t.Fatalf("SecureCookie test failed") - } + mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) + sval, ok := resp1.cookies["a"] + if !ok { + t.Fatalf("Failed to get cookie ") + } + cookies := makeCookie(map[string]string{"a": sval}) + + resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies) + + if resp2.body != "1" { + t.Fatalf("SecureCookie test failed") + } } func TestEarlyClose(t *testing.T) { - var server1 Server - server1.Close() + var server1 Server + server1.Close() } func TestOptions(t *testing.T) { - resp := getTestResponse("OPTIONS", "/options", "", nil, nil) - if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" { - t.Fatalf("TestOptions - Access-Control-Allow-Methods failed") - } - if resp.headers["Access-Control-Max-Age"][0] != "1000" { - t.Fatalf("TestOptions - Access-Control-Max-Age failed") - } + resp := getTestResponse("OPTIONS", "/options", "", nil, nil) + if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" { + t.Fatalf("TestOptions - Access-Control-Allow-Methods failed") + } + if resp.headers["Access-Control-Max-Age"][0] != "1000" { + t.Fatalf("TestOptions - Access-Control-Max-Age failed") + } } func TestSlug(t *testing.T) { - tests := [][]string{ - {"", ""}, - {"a", "a"}, - {"a/b", "a-b"}, - {"a b", "a-b"}, - {"a////b", "a-b"}, - {" a////b ", "a-b"}, - {" Manowar / Friends ", "manowar-friends"}, - } - - for _, test := range tests { - v := Slug(test[0], "-") - if v != test[1] { - t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v) - } - } + tests := [][]string{ + {"", ""}, + {"a", "a"}, + {"a/b", "a-b"}, + {"a b", "a-b"}, + {"a////b", "a-b"}, + {" a////b ", "a-b"}, + {" Manowar / Friends ", "manowar-friends"}, + } + + for _, test := range tests { + v := Slug(test[0], "-") + if v != test[1] { + t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v) + } + } } // tests that we don't duplicate headers func TestDuplicateHeader(t *testing.T) { - resp := testGet("/dupeheader", nil) - if len(resp.headers["Server"]) > 1 { - t.Fatalf("Expected only one header, got %#v", resp.headers["Server"]) - } - if resp.headers["Server"][0] != "myserver" { - t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0]) - } + resp := testGet("/dupeheader", nil) + if len(resp.headers["Server"]) > 1 { + t.Fatalf("Expected only one header, got %#v", resp.headers["Server"]) + } + if resp.headers["Server"][0] != "myserver" { + t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0]) + } } func BuildBasicAuthCredentials(user string, pass string) string { - s := user + ":" + pass - return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) + s := user + ":" + pass + return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) } func BenchmarkProcessGet(b *testing.B) { - s := NewServer() - s.SetLogger(log.New(ioutil.Discard, "", 0)) - s.Get("/echo/(.*)", func(s string) string { - return s - }) - req := buildTestRequest("GET", "/echo/hi", "", nil, nil) - var buf bytes.Buffer - iob := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - s.Process(&c, req) - } + s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) + s.Get("/echo/(.*)", func(s string) string { + return s + }) + req := buildTestRequest("GET", "/echo/hi", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Process(&c, req) + } } func BenchmarkProcessPost(b *testing.B) { - s := NewServer() - s.SetLogger(log.New(ioutil.Discard, "", 0)) - s.Post("/echo/(.*)", func(s string) string { - return s - }) - req := buildTestRequest("POST", "/echo/hi", "", nil, nil) - var buf bytes.Buffer - iob := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - s.Process(&c, req) - } + s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) + s.Post("/echo/(.*)", func(s string) string { + return s + }) + req := buildTestRequest("POST", "/echo/hi", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Process(&c, req) + } } From acfbad68f14cee20d03f005d6f11265ffa4ca685 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 08:59:37 -0700 Subject: [PATCH 11/33] Use Go's http.StatusText for HTTP status messages Also remove `status.go`, which contained all the status messages copied from the Go http package. --- scgi.go | 2 +- status.go | 54 ----------------------------------------------------- web_test.go | 4 ++-- 3 files changed, 3 insertions(+), 57 deletions(-) delete mode 100644 status.go diff --git a/scgi.go b/scgi.go index 7d0dc535..2c2d36bd 100644 --- a/scgi.go +++ b/scgi.go @@ -43,7 +43,7 @@ func (conn *scgiConn) WriteHeader(status int) { conn.wroteHeaders = true var buf bytes.Buffer - text := statusText[status] + text := http.StatusText(status) fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text) diff --git a/status.go b/status.go deleted file mode 100644 index e0d995f9..00000000 --- a/status.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package web - -import "net/http" - -var statusText = map[int]string{ - http.StatusContinue: "Continue", - http.StatusSwitchingProtocols: "Switching Protocols", - - http.StatusOK: "OK", - http.StatusCreated: "Created", - http.StatusAccepted: "Accepted", - http.StatusNonAuthoritativeInfo: "Non-Authoritative Information", - http.StatusNoContent: "No Content", - http.StatusResetContent: "Reset Content", - http.StatusPartialContent: "Partial Content", - - http.StatusMultipleChoices: "Multiple Choices", - http.StatusMovedPermanently: "Moved Permanently", - http.StatusFound: "Found", - http.StatusSeeOther: "See Other", - http.StatusNotModified: "Not Modified", - http.StatusUseProxy: "Use Proxy", - http.StatusTemporaryRedirect: "Temporary Redirect", - - http.StatusBadRequest: "Bad Request", - http.StatusUnauthorized: "Unauthorized", - http.StatusPaymentRequired: "Payment Required", - http.StatusForbidden: "Forbidden", - http.StatusNotFound: "Not Found", - http.StatusMethodNotAllowed: "Method Not Allowed", - http.StatusNotAcceptable: "Not Acceptable", - http.StatusProxyAuthRequired: "Proxy Authentication Required", - http.StatusRequestTimeout: "Request Timeout", - http.StatusConflict: "Conflict", - http.StatusGone: "Gone", - http.StatusLengthRequired: "Length Required", - http.StatusPreconditionFailed: "Precondition Failed", - http.StatusRequestEntityTooLarge: "Request Entity Too Large", - http.StatusRequestURITooLong: "Request URI Too Long", - http.StatusUnsupportedMediaType: "Unsupported Media Type", - http.StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", - http.StatusExpectationFailed: "Expectation Failed", - - http.StatusInternalServerError: "Internal Server Error", - http.StatusNotImplemented: "Not Implemented", - http.StatusBadGateway: "Bad Gateway", - http.StatusServiceUnavailable: "Service Unavailable", - http.StatusGatewayTimeout: "Gateway Timeout", - http.StatusHTTPVersionNotSupported: "HTTP Version Not Supported", -} diff --git a/web_test.go b/web_test.go index 2c5f069a..44107fd9 100644 --- a/web_test.go +++ b/web_test.go @@ -138,7 +138,7 @@ func init() { Get("/error/code/(.*)", func(ctx *Context, code string) string { n, _ := strconv.Atoi(code) - message := statusText[n] + message := http.StatusText(n) ctx.Abort(n, message) return "" }) @@ -237,7 +237,7 @@ var tests = []Test{ {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, - {"GET", "/error/code/500", nil, "", 500, statusText[500]}, + {"GET", "/error/code/500", nil, "", 500, http.StatusText(500)}, {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, {"GET", "/getparam?b=abcd", nil, "", 200, ""}, From 062368c4c9a20a717333a16051db3ff91a34b0b4 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 11:18:49 -0700 Subject: [PATCH 12/33] Fix formatting of web_test.go There was an extra space before the call to http.StatusText. A check has been added to Drone to catch these issues. --- web_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web_test.go b/web_test.go index 44107fd9..8d8d1bd4 100644 --- a/web_test.go +++ b/web_test.go @@ -237,7 +237,7 @@ var tests = []Test{ {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, - {"GET", "/error/code/500", nil, "", 500, http.StatusText(500)}, + {"GET", "/error/code/500", nil, "", 500, http.StatusText(500)}, {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, {"GET", "/getparam?b=abcd", nil, "", 200, ""}, From 128d5dd585a2170e5128c8362e00de881caea5ef Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 11:27:23 -0700 Subject: [PATCH 13/33] Add the 'ColorOutput' server config option When this option is set to true, the log output will contain color escape sequences. Set it to false to disable color escape sequences. Inspired by https://github.com/xyproto/web/commit/88b1a319e4f07943b0e4d3ac032862cd1aaf15c2 Resolves #153 --- server.go | 13 +++++++++++-- web.go | 1 + web_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index e84c674b..09dafdb3 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,7 @@ type ServerConfig struct { CookieSecret string RecoverPanic bool Profiler bool + ColorOutput bool } // Server represents a web.go server. @@ -284,10 +285,18 @@ func (s *Server) logRequest(ctx Context, sTime time.Time) { client = req.RemoteAddr } - fmt.Fprintf(&logEntry, "%s - \033[32;1m %s %s\033[0m - %v", client, req.Method, requestPath, duration) + if s.Config.ColorOutput { + fmt.Fprintf(&logEntry, "%s - \x1b[32;1m%s %s\x1b[0m - %v", client, req.Method, requestPath, duration) + } else { + fmt.Fprintf(&logEntry, "%s - %s %s - %v", client, req.Method, requestPath, duration) + } if len(ctx.Params) > 0 { - fmt.Fprintf(&logEntry, " - \033[37;1mParams: %v\033[0m\n", ctx.Params) + if s.Config.ColorOutput { + fmt.Fprintf(&logEntry, " - \x1b[37;1mParams: %v\x1b[0m\n", ctx.Params) + } else { + fmt.Fprintf(&logEntry, " - Params: %v\n", ctx.Params) + } } ctx.Server.Logger.Print(logEntry.String()) diff --git a/web.go b/web.go index 1c597a47..16c735e9 100644 --- a/web.go +++ b/web.go @@ -265,6 +265,7 @@ func SetLogger(logger *log.Logger) { // Config is the configuration of the main server. var Config = &ServerConfig{ RecoverPanic: true, + ColorOutput: true, } var mainServer = NewServer() diff --git a/web_test.go b/web_test.go index 8d8d1bd4..c2d0727c 100644 --- a/web_test.go +++ b/web_test.go @@ -540,6 +540,45 @@ func TestDuplicateHeader(t *testing.T) { } } +// test that output contains ASCII color codes by default +func TestColorOutputDefault(t *testing.T) { + s := NewServer() + var logOutput bytes.Buffer + logger := log.New(&logOutput, "", 0) + s.Logger = logger + s.Get("/test", func() string { + return "test" + }) + req := buildTestRequest("GET", "/test", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + s.Process(&c, req) + if !strings.Contains(logOutput.String(), "\x1b") { + t.Fatalf("The default log output does not seem to be colored") + } +} + +// test that output contains ASCII color codes by default +func TestNoColorOutput(t *testing.T) { + s := NewServer() + s.Config.ColorOutput = false + var logOutput bytes.Buffer + logger := log.New(&logOutput, "", 0) + s.Logger = logger + s.Get("/test", func() string { + return "test" + }) + req := buildTestRequest("GET", "/test", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + s.Process(&c, req) + if strings.Contains(logOutput.String(), "\x1b") { + t.Fatalf("The log contains color escape codes") + } +} + func BuildBasicAuthCredentials(user string, pass string) string { s := user + ":" + pass return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) From 8df0eaef3a5767d2690a0fec741786226b5c13b6 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:21:54 -0700 Subject: [PATCH 14/33] Add .travis.yml file Because Drone doesn't seem to be able to set pull request build statuses, I'd like to give Travis a try. --- .travis.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..4f2ee4d9 --- /dev/null +++ b/.travis.yml @@ -0,0 +1 @@ +language: go From f7bc145ffb32239bbdf80e27954c70324e38a292 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:27:43 -0700 Subject: [PATCH 15/33] Override test script for Travis By default Travis runs `go test ./...` as the test script, which fails on the examples. Change it to `go test -short`. --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 4f2ee4d9..0ee1ce74 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1 +1,3 @@ language: go + +script: go test -short From 461beeadb26c180179850f73dbfedf54d77a6db3 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:30:59 -0700 Subject: [PATCH 16/33] Override install script for travis Update the `install` script to a more simple `go get`. The Travis default `go get ./...` has problems with the examples directory. --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 0ee1ce74..50fd2a21 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,5 @@ language: go +install: go get + script: go test -short From 19e018c8945a2eaab35b458f175da76509441a23 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:35:22 -0700 Subject: [PATCH 17/33] Update build badge The CI has been moved from Drone to Travis. --- Readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Readme.md b/Readme.md index 93ce684d..2be27536 100644 --- a/Readme.md +++ b/Readme.md @@ -1,4 +1,4 @@ -[![Build Status](https://drone.io/github.com/hoisie/web/status.png)](https://drone.io/github.com/hoisie/web/latest) +[![Build Status](https://travis-ci.org/hoisie/web.svg?branch=master)](https://travis-ci.org/hoisie/web) # web.go From 14bd2ffe0a4df6658273eb6a61548d8d6ce8f5bc Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:48:25 -0700 Subject: [PATCH 18/33] Add gofmt check for Travis This will help ensure that code is properly formatted. --- .travis.gofmt.sh | 10 ++++++++++ .travis.yml | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100755 .travis.gofmt.sh diff --git a/.travis.gofmt.sh b/.travis.gofmt.sh new file mode 100755 index 00000000..bca2bf2f --- /dev/null +++ b/.travis.gofmt.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +cd "$(dirname $0)" + +BADLY_FORMATTED="$(go fmt ./...)" + +if [[ -n $BADLY_FORMATTED ]]; then + echo "The following files are badly formatted: $BADLY_FORMATTED" + exit 1 +fi diff --git a/.travis.yml b/.travis.yml index 50fd2a21..1f90d7fd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,4 +2,6 @@ language: go install: go get -script: go test -short +script: + - ./.travis.gofmt.sh + - go test -short From 5a1d2269626044d0e9d98d9a989b4cd0e087d858 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 12:03:49 -0700 Subject: [PATCH 19/33] Add check for 'Authorization' header in GetBasicAuth Previously, if the `Authorization` header was not provided, the method would crash. Add a check for the presence a header, and a test case as well. Resolves #174 --- helpers.go | 7 +++++-- web_test.go | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/helpers.go b/helpers.go index e77d87f2..a87e93f4 100644 --- a/helpers.go +++ b/helpers.go @@ -91,9 +91,12 @@ func NewCookie(name string, value string, age int64) *http.Cookie { return &http.Cookie{Name: name, Value: value, Expires: utctime} } -// GetBasicAuth is a helper method of *Context that returns the decoded -// user and password from the *Context's authorization header +// GetBasicAuth returns the decoded user and password from the context's +// 'Authorization' header. func (ctx *Context) GetBasicAuth() (string, string, error) { + if len(ctx.Request.Header["Authorization"]) == 0 { + return "", "", errors.New("No Authorization header provided") + } authHeader := ctx.Request.Header["Authorization"][0] authString := strings.Split(string(authHeader), " ") if authString[0] != "Basic" { diff --git a/web_test.go b/web_test.go index c2d0727c..ee296b46 100644 --- a/web_test.go +++ b/web_test.go @@ -248,6 +248,7 @@ var tests = []Test{ {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, //{"GET", "/testenv", "", 200, "hello world"}, {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, + {"GET", "/authorization", nil, "", 200, "fail"}, } func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request { From bbb68ed3db57539d60bf62efddf37a9d0f84a45e Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 14:29:39 -0700 Subject: [PATCH 20/33] Add check for secure cookie structure A secure cookie has three parts separated by the pipe ("|") character. Before trying to parse it, ensure there are actually three parts. This is a potential fix for #163 --- web.go | 3 +++ web_test.go | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/web.go b/web.go index 16c735e9..8f464570 100644 --- a/web.go +++ b/web.go @@ -145,6 +145,9 @@ func (ctx *Context) GetSecureCookie(name string) (string, bool) { } parts := strings.SplitN(cookie.Value, "|", 3) + if len(parts) != 3 { + return "", false + } val := parts[0] timestamp := parts[1] diff --git a/web_test.go b/web_test.go index ee296b46..7861871b 100644 --- a/web_test.go +++ b/web_test.go @@ -496,6 +496,17 @@ func TestSecureCookie(t *testing.T) { } } +func TestEmptySecureCookie(t *testing.T) { + mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + cookies := makeCookie(map[string]string{"empty": ""}) + + resp2 := getTestResponse("GET", "/securecookie/get/empty", "", nil, cookies) + + if resp2.body != "" { + t.Fatalf("Expected an empty secure cookie") + } +} + func TestEarlyClose(t *testing.T) { var server1 Server server1.Close() From 05245864d04b88eaef648e1b053bef9d9c01861f Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 14:36:07 -0700 Subject: [PATCH 21/33] Add trailing newlines to many of the 'hello world' examples This makes them more curl-friendly. --- examples/hello.go | 2 +- examples/logger.go | 2 +- examples/multiserver.go | 4 ++-- examples/tls.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/hello.go b/examples/hello.go index 29bf8d10..26ecb0b0 100644 --- a/examples/hello.go +++ b/examples/hello.go @@ -4,7 +4,7 @@ import ( "github.com/hoisie/web" ) -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { web.Get("/(.*)", hello) diff --git a/examples/logger.go b/examples/logger.go index e4b134a3..462ad610 100644 --- a/examples/logger.go +++ b/examples/logger.go @@ -6,7 +6,7 @@ import ( "os" ) -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { f, err := os.Create("server.log") diff --git a/examples/multiserver.go b/examples/multiserver.go index 4fc479ed..2523cb90 100644 --- a/examples/multiserver.go +++ b/examples/multiserver.go @@ -4,9 +4,9 @@ import ( "github.com/hoisie/web" ) -func hello1(val string) string { return "hello1 " + val } +func hello1(val string) string { return "hello1 " + val + "\n" } -func hello2(val string) string { return "hello2 " + val } +func hello2(val string) string { return "hello2 " + val + "\n" } func main() { var server1 web.Server diff --git a/examples/tls.go b/examples/tls.go index dc189eac..951418fd 100644 --- a/examples/tls.go +++ b/examples/tls.go @@ -47,7 +47,7 @@ gWrxykqyLToIiAuL+pvC3Jv8IOPIiVFsY032rOqcwSGdVUyhTsG28+7KnR6744tM -----END CERTIFICATE----- ` -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { config := tls.Config{ From 6e587a8844f2b8b3e521a6083d1193021a260f32 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 20:49:07 -0700 Subject: [PATCH 22/33] Rename 'web.Handler' to 'web.Handle' Most other functions in `web` are in the imperative tense. Rename `Handler` to `Handle` for consistency. Perform a similar rename for the 'Server' type. Also, simplify the comment a bit. --- server.go | 4 ++-- web.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server.go b/server.go index 09dafdb3..384dbc73 100644 --- a/server.go +++ b/server.go @@ -123,8 +123,8 @@ func (s *Server) Match(method string, route string, handler interface{}) { s.addRoute(route, method, handler) } -//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. -func (s *Server) Handler(route string, method string, httpHandler http.Handler) { +// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI. +func (s *Server) Handle(route string, method string, httpHandler http.Handler) { s.addRoute(route, method, httpHandler) } diff --git a/web.go b/web.go index 8f464570..454472d8 100644 --- a/web.go +++ b/web.go @@ -250,9 +250,9 @@ func Match(method string, route string, handler interface{}) { mainServer.addRoute(route, method, handler) } -//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. -func Handler(route string, method string, httpHandler http.Handler) { - mainServer.Handler(route, method, httpHandler) +// Add a custom HTTP handler for a path. This will have no effect when running as FCGI or SCGI. +func Handle(route string, method string, httpHandler http.Handler) { + mainServer.Handle(route, method, httpHandler) } //Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI. From c21e884130c4417c9339d509d8dfc1997b5a9787 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 30 Jul 2016 20:52:46 -0700 Subject: [PATCH 23/33] Stop setting a 'Content-Type' header for custom HTTP handlers Previously, when using web.Handle, the `Content-Type` HTTP header was set by default to `text/html; charset=utf-8`. This does not play nicely well with Go's FileHandler. If a `Content-Type` header is set, Go's FileHandler will not overwrite it. This breaks serving static assets with a FileHandler. This resolves #158 --- server.go | 6 +++--- web.go | 3 ++- web_test.go | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/server.go b/server.go index 384dbc73..9869f46b 100644 --- a/server.go +++ b/server.go @@ -335,9 +335,6 @@ func (s *Server) routeHandler(req *http.Request, w http.ResponseWriter) (unused } } - //Set the default content-type - ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) - for i := 0; i < len(s.routes); i++ { route := s.routes[i] cr := route.cr @@ -361,6 +358,9 @@ func (s *Server) routeHandler(req *http.Request, w http.ResponseWriter) (unused return } + // set the default content-type + ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) + var args []reflect.Value handlerType := route.handler.Type() if requiresContext(handlerType) { diff --git a/web.go b/web.go index 454472d8..5533cd77 100644 --- a/web.go +++ b/web.go @@ -43,6 +43,7 @@ func (ctx *Context) WriteString(content string) { // Once it has been called, any return value from the handler will // not be written to the response. func (ctx *Context) Abort(status int, body string) { + ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) ctx.ResponseWriter.WriteHeader(status) ctx.ResponseWriter.Write([]byte(body)) } @@ -250,7 +251,7 @@ func Match(method string, route string, handler interface{}) { mainServer.addRoute(route, method, handler) } -// Add a custom HTTP handler for a path. This will have no effect when running as FCGI or SCGI. +// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI. func Handle(route string, method string, httpHandler http.Handler) { mainServer.Handle(route, method, httpHandler) } diff --git a/web_test.go b/web_test.go index 7861871b..2c44ad7d 100644 --- a/web_test.go +++ b/web_test.go @@ -591,6 +591,24 @@ func TestNoColorOutput(t *testing.T) { } } +type TestHandler struct{} + +func (t *TestHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) { +} + +// When a custom HTTP handler is used, the Content-Type header should not be set to a default. +// Go's FileHandler does not replace the Content-Type header if it is already set. +func TestCustomHandlerContentType(t *testing.T) { + s := NewServer() + s.Handle("/testHandler", "GET", &TestHandler{}) + req := buildTestRequest("GET", "/testHandler", "", nil, nil) + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: nil} + s.Process(&c, req) + if c.headers["Content-Type"] != nil { + t.Fatalf("A default Content-Type should not be present when using a custom HTTP handler") + } +} + func BuildBasicAuthCredentials(user string, pass string) string { s := user + ":" + pass return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) From 86285b6b1fbb195e233c61641cccc4873a1c03f0 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 1 Aug 2016 18:26:31 -0700 Subject: [PATCH 24/33] Simplify gofmt check in Travis Instead of providing a custom script, the check can be replaced with a handy one-liner: `diff -u <(echo -n) <(gofmt -d -s .)` --- .travis.gofmt.sh | 10 ---------- .travis.yml | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) delete mode 100755 .travis.gofmt.sh diff --git a/.travis.gofmt.sh b/.travis.gofmt.sh deleted file mode 100755 index bca2bf2f..00000000 --- a/.travis.gofmt.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -cd "$(dirname $0)" - -BADLY_FORMATTED="$(go fmt ./...)" - -if [[ -n $BADLY_FORMATTED ]]; then - echo "The following files are badly formatted: $BADLY_FORMATTED" - exit 1 -fi diff --git a/.travis.yml b/.travis.yml index 1f90d7fd..1a48aba3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,5 +3,5 @@ language: go install: go get script: - - ./.travis.gofmt.sh + - diff -u <(echo -n) <(gofmt -d -s .) - go test -short From 3c7dcfd7f2d75673f948a2fe3dd529f99774752d Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 1 Aug 2016 18:38:03 -0700 Subject: [PATCH 25/33] Specify go versions for Travis This should help maintain backwards compatibility with older go versions. --- .travis.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.travis.yml b/.travis.yml index 1a48aba3..b5ebb9ce 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,11 @@ language: go +go: + - 1.3 + - 1.4 + - 1.5 + - 1.6 + install: go get script: From 2f9512f85248b40f3c0db8cf09567e883ff2865f Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 1 Aug 2016 18:16:19 -0700 Subject: [PATCH 26/33] Gracefully handle malformed SCGI requests Previously, malformed SCGI requests would cause a panic when they were processed. Gracefully handle malformed SCGI requests. Also, add some additional error checking when parsing the length of the request, and clean up the code related to logging SCGI errors. Fixes #166 --- scgi.go | 15 +++++++++------ web_test.go | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/scgi.go b/scgi.go index 2c2d36bd..eea2b4fe 100644 --- a/scgi.go +++ b/scgi.go @@ -97,18 +97,20 @@ func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) { reader := bufio.NewReader(fd) line, err := reader.ReadString(':') if err != nil { - s.Logger.Println("Error during SCGI read: ", err.Error()) + return nil, err + } + length, err := strconv.Atoi(line[0 : len(line)-1]) + if err != nil { + return nil, err } - length, _ := strconv.Atoi(line[0 : len(line)-1]) if length > 16384 { - s.Logger.Println("Error: max header size is 16k") + return nil, errors.New("Max header size is 16k") } headerData := make([]byte, length) _, err = reader.Read(headerData) if err != nil { return nil, err } - b, err := reader.ReadByte() if err != nil { return nil, err @@ -138,14 +140,15 @@ func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) { } func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) { + defer fd.Close() req, err := s.readScgiRequest(fd) if err != nil { - s.Logger.Println("SCGI error: %q", err.Error()) + s.Logger.Println("Error reading SCGI request: %q", err.Error()) + return } sc := scgiConn{fd, req, make(map[string][]string), false} s.routeHandler(req, &sc) sc.finishRequest() - fd.Close() } func (s *Server) listenAndServeScgi(addr string) error { diff --git a/web_test.go b/web_test.go index 2c44ad7d..f901a7fd 100644 --- a/web_test.go +++ b/web_test.go @@ -591,6 +591,29 @@ func TestNoColorOutput(t *testing.T) { } } +// a malformed SCGI request should be discarded and not cause a panic +func TestMaformedScgiRequest(t *testing.T) { + var headerBuf bytes.Buffer + + headerBuf.WriteString("CONTENT_LENGTH") + headerBuf.WriteByte(0) + headerBuf.WriteString("0") + headerBuf.WriteByte(0) + headerData := headerBuf.Bytes() + + var buf bytes.Buffer + fmt.Fprintf(&buf, "%d:", len(headerData)) + buf.Write(headerData) + buf.WriteByte(',') + + var output bytes.Buffer + nb := ioBuffer{input: &buf, output: &output} + mainServer.handleScgiRequest(&nb) + if !nb.closed { + t.Fatalf("The connection should have been closed") + } +} + type TestHandler struct{} func (t *TestHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) { From acb3d146a1d11bb9b91bccdfbeb395c0f735050f Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 1 Aug 2016 20:58:07 -0700 Subject: [PATCH 27/33] Discard log output during TestCustomHandlerContentType The request log wasn't being discarded. --- web_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/web_test.go b/web_test.go index f901a7fd..b4f83a4d 100644 --- a/web_test.go +++ b/web_test.go @@ -623,6 +623,7 @@ func (t *TestHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) { // Go's FileHandler does not replace the Content-Type header if it is already set. func TestCustomHandlerContentType(t *testing.T) { s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) s.Handle("/testHandler", "GET", &TestHandler{}) req := buildTestRequest("GET", "/testHandler", "", nil, nil) c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: nil} From c1d5d893c4f1f7bffba532ea674a11ecd74c8f9c Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Thu, 4 Aug 2016 17:46:46 -0700 Subject: [PATCH 28/33] Add a helper method for returning HTTP 400 This will cause HTTP 400 Bad Request to be returned. Resolves #180 --- web.go | 17 +++++++++++------ web_test.go | 5 +++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/web.go b/web.go index 5533cd77..d4926360 100644 --- a/web.go +++ b/web.go @@ -55,17 +55,16 @@ func (ctx *Context) Redirect(status int, url_ string) { ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_)) } +//BadRequest writes a 400 HTTP response +func (ctx *Context) BadRequest() { + ctx.ResponseWriter.WriteHeader(400) +} + // Notmodified writes a 304 HTTP response func (ctx *Context) NotModified() { ctx.ResponseWriter.WriteHeader(304) } -// NotFound writes a 404 HTTP response -func (ctx *Context) NotFound(message string) { - ctx.ResponseWriter.WriteHeader(404) - ctx.ResponseWriter.Write([]byte(message)) -} - //Unauthorized writes a 401 HTTP response func (ctx *Context) Unauthorized() { ctx.ResponseWriter.WriteHeader(401) @@ -76,6 +75,12 @@ func (ctx *Context) Forbidden() { ctx.ResponseWriter.WriteHeader(403) } +// NotFound writes a 404 HTTP response +func (ctx *Context) NotFound(message string) { + ctx.ResponseWriter.WriteHeader(404) + ctx.ResponseWriter.Write([]byte(message)) +} + // ContentType sets the Content-Type header for an HTTP response. // For example, ctx.ContentType("json") sets the content-type to "application/json" // If the supplied value contains a slash (/) it is set as the Content-Type diff --git a/web_test.go b/web_test.go index b4f83a4d..1dba5830 100644 --- a/web_test.go +++ b/web_test.go @@ -145,6 +145,9 @@ func init() { Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) + Get("/error/badrequest", func(ctx *Context) { ctx.BadRequest() }) + Post("/error/badrequest", func(ctx *Context) { ctx.BadRequest() }) + Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) @@ -230,6 +233,8 @@ var tests = []Test{ //long url {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, {"GET", "/writetest", nil, "", 200, "hello"}, + {"GET", "/error/badrequest", nil, "", 400, ""}, + {"POST", "/error/badrequest", nil, "", 400, ""}, {"GET", "/error/unauthorized", nil, "", 401, ""}, {"POST", "/error/unauthorized", nil, "", 401, ""}, {"GET", "/error/forbidden", nil, "", 403, ""}, From 22000d109c0416fa17f621574f7a5c57b95d06af Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sat, 6 Aug 2016 10:36:55 -0700 Subject: [PATCH 29/33] Clean up the logic for color logging Previously, the escape sequences for terminal colors were included directly in the log string. This made it difficult to understand what was being logged. Add some wrapper methods to clean up the abstraction of the color output logic. Also, avoid using color logging if web.go isn't running in a terminal (e.g the output is being piped). --- server.go | 33 ++++++++++++++++++++------------- ttycolors.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 13 deletions(-) create mode 100644 ttycolors.go diff --git a/server.go b/server.go index 9869f46b..ae6b8a9f 100644 --- a/server.go +++ b/server.go @@ -269,7 +269,6 @@ func (s *Server) tryServingFile(name string, req *http.Request, w http.ResponseW func (s *Server) logRequest(ctx Context, sTime time.Time) { //log the request - var logEntry bytes.Buffer req := ctx.Request requestPath := req.URL.Path @@ -285,22 +284,30 @@ func (s *Server) logRequest(ctx Context, sTime time.Time) { client = req.RemoteAddr } - if s.Config.ColorOutput { - fmt.Fprintf(&logEntry, "%s - \x1b[32;1m%s %s\x1b[0m - %v", client, req.Method, requestPath, duration) - } else { - fmt.Fprintf(&logEntry, "%s - %s %s - %v", client, req.Method, requestPath, duration) - } - + var logEntry bytes.Buffer + logEntry.WriteString(client) + logEntry.WriteString(" - " + s.ttyGreen(req.Method+" "+requestPath)) + logEntry.WriteString(" - " + duration.String()) if len(ctx.Params) > 0 { - if s.Config.ColorOutput { - fmt.Fprintf(&logEntry, " - \x1b[37;1mParams: %v\x1b[0m\n", ctx.Params) - } else { - fmt.Fprintf(&logEntry, " - Params: %v\n", ctx.Params) - } + logEntry.WriteString(" - " + s.ttyWhite(fmt.Sprintf("Params: %v\n", ctx.Params))) } - ctx.Server.Logger.Print(logEntry.String()) +} + +func (s *Server) ttyGreen(msg string) string { + return s.ttyColor(msg, ttyCodes.green) +} + +func (s *Server) ttyWhite(msg string) string { + return s.ttyColor(msg, ttyCodes.white) +} +func (s *Server) ttyColor(msg string, colorCode string) string { + if s.Config.ColorOutput { + return colorCode + msg + ttyCodes.reset + } else { + return msg + } } // the main route handler in web.go diff --git a/ttycolors.go b/ttycolors.go new file mode 100644 index 00000000..fe63c1af --- /dev/null +++ b/ttycolors.go @@ -0,0 +1,30 @@ +package web + +import ( + "golang.org/x/crypto/ssh/terminal" + "syscall" +) + +var ttyCodes struct { + green string + white string + reset string +} + +func init() { + ttyCodes.green = ttyBold("32") + ttyCodes.white = ttyBold("37") + ttyCodes.reset = ttyEscape("0") +} + +func ttyBold(code string) string { + return ttyEscape("1;" + code) +} + +func ttyEscape(code string) string { + if terminal.IsTerminal(syscall.Stdout) { + return "\x1b[" + code + "m" + } else { + return "" + } +} From 955fd9413489f37867778a0ce588cce61fa132d9 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Sun, 7 Aug 2016 22:54:33 -0700 Subject: [PATCH 30/33] Simplify base64 encoding logic in SetSecureCookie Use `EncodeToString` instead of setting up a separate buffer. This code was likely written before `EncodeToString` existed. --- web.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/web.go b/web.go index d4926360..a2471960 100644 --- a/web.go +++ b/web.go @@ -127,20 +127,14 @@ func getCookieSig(key string, val []byte, timestamp string) string { } func (ctx *Context) SetSecureCookie(name string, val string, age int64) { - //base64 encode the val if len(ctx.Server.Config.CookieSecret) == 0 { ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.") return } - var buf bytes.Buffer - encoder := base64.NewEncoder(base64.StdEncoding, &buf) - encoder.Write([]byte(val)) - encoder.Close() - vs := buf.String() - vb := buf.Bytes() + encoded := base64.StdEncoding.EncodeToString([]byte(val)) timestamp := strconv.FormatInt(time.Now().Unix(), 10) - sig := getCookieSig(ctx.Server.Config.CookieSecret, vb, timestamp) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") + sig := getCookieSig(ctx.Server.Config.CookieSecret, []byte(encoded), timestamp) + cookie := strings.Join([]string{encoded, timestamp, sig}, "|") ctx.SetCookie(NewCookie(name, cookie, age)) } From fe499299ddfe95c4d4350654e28513a46a5928d9 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 8 Aug 2016 23:15:39 -0700 Subject: [PATCH 31/33] Switch secure cookie implementation Previously, secure cookies in web.go were only cryptographically signed. This prevented them from being tampered with. However, the contents of the cookies were still transmitted in plain text to the client. Instead of only signing the contents of the cookie, encrypt the contents as well. This prevents any kind of information leakage. Secure cookies are now encrypted with AES counter mode with a 32 bit key. The contents are still signed using HMAC. Both the encryption key and the signature key are generated using pbkdf2 using the CookieSecret config option as the password source. The ciphertext, initialization vector, and signature are now transmitted to the client. Although the API is the same, cookies previously stored will not be readable. Unfortunately there is no smooth upgrade process. An example of using secure cookies has been added as well. Fixes #160 --- examples/secure_cookie.go | 52 +++++++++++++++++ secure_cookie.go | 115 ++++++++++++++++++++++++++++++++++++++ server.go | 8 +++ web.go | 63 --------------------- web_test.go | 1 + 5 files changed, 176 insertions(+), 63 deletions(-) create mode 100644 examples/secure_cookie.go create mode 100644 secure_cookie.go diff --git a/examples/secure_cookie.go b/examples/secure_cookie.go new file mode 100644 index 00000000..4da59bb7 --- /dev/null +++ b/examples/secure_cookie.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + "github.com/hoisie/web" + "html" +) + +var cookieName = "cookie" + +var notice = ` +
%v
+` +var form = ` + +
+ + +
+ + + + +` + +func index(ctx *web.Context) string { + cookie, ok := ctx.GetSecureCookie(cookieName) + var top string + if !ok { + top = fmt.Sprintf(notice, "The cookie has not been set") + } else { + var val = html.EscapeString(cookie) + top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") + } + return top + form +} + +func update(ctx *web.Context) { + if ctx.Params["submit"] == "Delete" { + ctx.SetCookie(web.NewCookie(cookieName, "", -1)) + } else { + ctx.SetSecureCookie(cookieName, ctx.Params["cookie"], 0) + } + ctx.Redirect(301, "/") +} + +func main() { + web.Config.CookieSecret = "a long secure cookie secret" + web.Get("/", index) + web.Post("/update", update) + web.Run("0.0.0.0:9999") +} diff --git a/secure_cookie.go b/secure_cookie.go new file mode 100644 index 00000000..a9cdba13 --- /dev/null +++ b/secure_cookie.go @@ -0,0 +1,115 @@ +package web + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha512" + "encoding/base64" + "errors" + "golang.org/x/crypto/pbkdf2" + "io" + "strings" +) + +const ( + pbkdf2Iterations = 64000 + keySize = 32 +) + +var ( + ErrMissingCookieSecret = errors.New("Secret Key for secure cookies has not been set. Assign one to web.Config.CookieSecret.") + ErrInvalidKey = errors.New("The keys for secure cookies have not been initialized. Ensure that a Run* method is being called") +) + +func (ctx *Context) SetSecureCookie(name string, val string, age int64) error { + serverConfig := ctx.Server.Config + if len(serverConfig.CookieSecret) == 0 { + return ErrMissingCookieSecret + } + + if len(serverConfig.encKey) == 0 || len(serverConfig.signKey) == 0 { + return ErrInvalidKey + } + ciphertext, err := encrypt([]byte(val), serverConfig.encKey) + if err != nil { + return err + } + sig := sign(ciphertext, serverConfig.signKey) + data := base64.StdEncoding.EncodeToString(ciphertext) + "|" + base64.StdEncoding.EncodeToString(sig) + ctx.SetCookie(NewCookie(name, data, age)) + return nil +} + +func (ctx *Context) GetSecureCookie(name string) (string, bool) { + for _, cookie := range ctx.Request.Cookies() { + if cookie.Name != name { + continue + } + + parts := strings.SplitN(cookie.Value, "|", 2) + if len(parts) != 2 { + return "", false + } + + ciphertext, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return "", false + } + sig, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return "", false + } + expectedSig := sign([]byte(ciphertext), ctx.Server.Config.signKey) + if !bytes.Equal(expectedSig, sig) { + return "", false + } + plaintext, err := decrypt(ciphertext, ctx.Server.Config.encKey) + if err != nil { + return "", false + } + return string(plaintext), true + } + return "", false +} + +func genKey(password string, salt string) []byte { + return pbkdf2.Key([]byte(password), []byte(salt), pbkdf2Iterations, keySize, sha512.New) +} + +func encrypt(plaintext []byte, key []byte) ([]byte, error) { + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(plaintext)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCTR(aesCipher, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext) + return ciphertext, nil +} + +func decrypt(ciphertext []byte, key []byte) ([]byte, error) { + if len(ciphertext) <= aes.BlockSize { + return nil, errors.New("Invalid cipher text") + } + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + plaintext := make([]byte, len(ciphertext)-aes.BlockSize) + stream := cipher.NewCTR(aesCipher, ciphertext[:aes.BlockSize]) + stream.XORKeyStream(plaintext, ciphertext[aes.BlockSize:]) + return plaintext, nil +} + +func sign(data []byte, key []byte) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(data) + return mac.Sum(nil) +} diff --git a/server.go b/server.go index ae6b8a9f..98b9b972 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,8 @@ type ServerConfig struct { RecoverPanic bool Profiler bool ColorOutput bool + encKey []byte + signKey []byte } // Server represents a web.go server. @@ -56,6 +58,12 @@ func (s *Server) initServer() { if s.Logger == nil { s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) } + + if len(s.Config.CookieSecret) > 0 { + s.Logger.Println("Generating cookie encryption keys") + s.Config.encKey = genKey(s.Config.CookieSecret, "encryption key salt") + s.Config.signKey = genKey(s.Config.CookieSecret, "signature key salt") + } } type route struct { diff --git a/web.go b/web.go index a2471960..8d36130d 100644 --- a/web.go +++ b/web.go @@ -3,23 +3,15 @@ package web import ( - "bytes" - "crypto/hmac" - "crypto/sha1" "crypto/tls" - "encoding/base64" - "fmt" "golang.org/x/net/websocket" - "io/ioutil" "log" "mime" "net/http" "os" "path" "reflect" - "strconv" "strings" - "time" ) // A Context object is created for every incoming HTTP request, and is @@ -117,61 +109,6 @@ func (ctx *Context) SetCookie(cookie *http.Cookie) { ctx.SetHeader("Set-Cookie", cookie.String(), false) } -func getCookieSig(key string, val []byte, timestamp string) string { - hm := hmac.New(sha1.New, []byte(key)) - - hm.Write(val) - hm.Write([]byte(timestamp)) - - return fmt.Sprintf("%02x", hm.Sum(nil)) -} - -func (ctx *Context) SetSecureCookie(name string, val string, age int64) { - if len(ctx.Server.Config.CookieSecret) == 0 { - ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.") - return - } - encoded := base64.StdEncoding.EncodeToString([]byte(val)) - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - sig := getCookieSig(ctx.Server.Config.CookieSecret, []byte(encoded), timestamp) - cookie := strings.Join([]string{encoded, timestamp, sig}, "|") - ctx.SetCookie(NewCookie(name, cookie, age)) -} - -func (ctx *Context) GetSecureCookie(name string) (string, bool) { - for _, cookie := range ctx.Request.Cookies() { - if cookie.Name != name { - continue - } - - parts := strings.SplitN(cookie.Value, "|", 3) - if len(parts) != 3 { - return "", false - } - - val := parts[0] - timestamp := parts[1] - sig := parts[2] - - if getCookieSig(ctx.Server.Config.CookieSecret, []byte(val), timestamp) != sig { - return "", false - } - - ts, _ := strconv.ParseInt(timestamp, 0, 64) - - if time.Now().Unix()-31*86400 > ts { - return "", false - } - - buf := bytes.NewBufferString(val) - encoder := base64.NewDecoder(base64.StdEncoding, buf) - - res, _ := ioutil.ReadAll(encoder) - return string(res), true - } - return "", false -} - // small optimization: cache the context type instead of repeteadly calling reflect.Typeof var contextType reflect.Type diff --git a/web_test.go b/web_test.go index 1dba5830..6468b796 100644 --- a/web_test.go +++ b/web_test.go @@ -487,6 +487,7 @@ func makeCookie(vals map[string]string) []*http.Cookie { func TestSecureCookie(t *testing.T) { mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + mainServer.initServer() resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) sval, ok := resp1.cookies["a"] if !ok { From 27fd338e8fba5200b61de4016752af34c80fd1f6 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Mon, 8 Aug 2016 23:21:26 -0700 Subject: [PATCH 32/33] Remove superfluous newlines in secure_cookie.go Whitespace only change. I forgot to add this file when amending the commit. --- secure_cookie.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/secure_cookie.go b/secure_cookie.go index a9cdba13..67ca6738 100644 --- a/secure_cookie.go +++ b/secure_cookie.go @@ -29,7 +29,6 @@ func (ctx *Context) SetSecureCookie(name string, val string, age int64) error { if len(serverConfig.CookieSecret) == 0 { return ErrMissingCookieSecret } - if len(serverConfig.encKey) == 0 || len(serverConfig.signKey) == 0 { return ErrInvalidKey } @@ -48,12 +47,10 @@ func (ctx *Context) GetSecureCookie(name string) (string, bool) { if cookie.Name != name { continue } - parts := strings.SplitN(cookie.Value, "|", 2) if len(parts) != 2 { return "", false } - ciphertext, err := base64.StdEncoding.DecodeString(parts[0]) if err != nil { return "", false From 83519dd4099705ea1b238f178aed92f4d1b3ee65 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Tue, 9 Aug 2016 00:02:05 -0700 Subject: [PATCH 33/33] Move encKey and signKey from ServerConfig to Server Belonging to the `Server` struct seems more appropriate than `ServerConfig`. The `ServerConfig` is mainly for user-defined configuration, and the keys are generated during runtime. --- secure_cookie.go | 14 +++++++------- server.go | 10 +++++----- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/secure_cookie.go b/secure_cookie.go index 67ca6738..dfd5d196 100644 --- a/secure_cookie.go +++ b/secure_cookie.go @@ -25,18 +25,18 @@ var ( ) func (ctx *Context) SetSecureCookie(name string, val string, age int64) error { - serverConfig := ctx.Server.Config - if len(serverConfig.CookieSecret) == 0 { + server := ctx.Server + if len(server.Config.CookieSecret) == 0 { return ErrMissingCookieSecret } - if len(serverConfig.encKey) == 0 || len(serverConfig.signKey) == 0 { + if len(server.encKey) == 0 || len(server.signKey) == 0 { return ErrInvalidKey } - ciphertext, err := encrypt([]byte(val), serverConfig.encKey) + ciphertext, err := encrypt([]byte(val), server.encKey) if err != nil { return err } - sig := sign(ciphertext, serverConfig.signKey) + sig := sign(ciphertext, server.signKey) data := base64.StdEncoding.EncodeToString(ciphertext) + "|" + base64.StdEncoding.EncodeToString(sig) ctx.SetCookie(NewCookie(name, data, age)) return nil @@ -59,11 +59,11 @@ func (ctx *Context) GetSecureCookie(name string) (string, bool) { if err != nil { return "", false } - expectedSig := sign([]byte(ciphertext), ctx.Server.Config.signKey) + expectedSig := sign([]byte(ciphertext), ctx.Server.signKey) if !bytes.Equal(expectedSig, sig) { return "", false } - plaintext, err := decrypt(ciphertext, ctx.Server.Config.encKey) + plaintext, err := decrypt(ciphertext, ctx.Server.encKey) if err != nil { return "", false } diff --git a/server.go b/server.go index 98b9b972..0e97a4dd 100644 --- a/server.go +++ b/server.go @@ -28,8 +28,6 @@ type ServerConfig struct { RecoverPanic bool Profiler bool ColorOutput bool - encKey []byte - signKey []byte } // Server represents a web.go server. @@ -39,7 +37,9 @@ type Server struct { Logger *log.Logger Env map[string]interface{} //save the listener so it can be closed - l net.Listener + l net.Listener + encKey []byte + signKey []byte } func NewServer() *Server { @@ -61,8 +61,8 @@ func (s *Server) initServer() { if len(s.Config.CookieSecret) > 0 { s.Logger.Println("Generating cookie encryption keys") - s.Config.encKey = genKey(s.Config.CookieSecret, "encryption key salt") - s.Config.signKey = genKey(s.Config.CookieSecret, "signature key salt") + s.encKey = genKey(s.Config.CookieSecret, "encryption key salt") + s.signKey = genKey(s.Config.CookieSecret, "signature key salt") } }