diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..b5ebb9ce --- /dev/null +++ b/.travis.yml @@ -0,0 +1,13 @@ +language: go + +go: + - 1.3 + - 1.4 + - 1.5 + - 1.6 + +install: go get + +script: + - diff -u <(echo -n) <(gofmt -d -s .) + - go test -short diff --git a/Makefile b/Makefile deleted file mode 100644 index 25b06108..00000000 --- a/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -GOFMT=gofmt -s -tabs=false -tabwidth=4 - -GOFILES=$(wildcard *.go **/*.go) - -format: - ${GOFMT} -w ${GOFILES} diff --git a/Readme.md b/Readme.md index 9750f4df..2be27536 100644 --- a/Readme.md +++ b/Readme.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/hoisie/web.svg?branch=master)](https://travis-ci.org/hoisie/web) + # web.go web.go is the simplest way to write web applications in the Go programming language. It's ideal for writing simple, performant backend web services. @@ -77,12 +79,12 @@ In this example, if you visit `http://localhost:9999/?a=1&b=2`, you'll see the f ## Documentation -API docs are hosted at http://webgo.io +API docs are hosted at https://hoisie.github.io/web/ If you use web.go, I'd greatly appreciate a quick message about what you're building with it. This will help me get a sense of usage patterns, and helps me focus development efforts on features that people will actually use. ## About -web.go was written by [Michael Hoisie](http://hoisie.com). +web.go was written by Michael Hoisie diff --git a/examples/arcchallenge.go b/examples/arcchallenge.go index 1de600c2..ea9feb4c 100644 --- a/examples/arcchallenge.go +++ b/examples/arcchallenge.go @@ -1,10 +1,10 @@ package main import ( - "fmt" - "github.com/hoisie/web" - "math/rand" - "time" + "fmt" + "github.com/hoisie/web" + "math/rand" + "time" ) var form = `
` @@ -12,22 +12,22 @@ var form = `
Click Here` - }) - web.Get("/final", func(ctx *web.Context) string { - uid, _ := ctx.GetSecureCookie("user") - return "You said " + users[uid] - }) - web.Run("0.0.0.0:9999") + rand.Seed(time.Now().UnixNano()) + web.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + web.Get("/", func(ctx *web.Context) string { + ctx.Redirect(302, "/said") + return "" + }) + web.Get("/said", func() string { return form }) + web.Post("/say", func(ctx *web.Context) string { + uid := fmt.Sprintf("%d\n", rand.Int63()) + ctx.SetSecureCookie("user", uid, 3600) + users[uid] = ctx.Params["said"] + return `Click Here` + }) + web.Get("/final", func(ctx *web.Context) string { + uid, _ := ctx.GetSecureCookie("user") + return "You said " + users[uid] + }) + web.Run("0.0.0.0:9999") } diff --git a/examples/cookie.go b/examples/cookie.go index e686c767..a4bb8a3d 100644 --- a/examples/cookie.go +++ b/examples/cookie.go @@ -1,9 +1,9 @@ package main import ( - "fmt" - "github.com/hoisie/web" - "html" + "fmt" + "github.com/hoisie/web" + "html" ) var cookieName = "cookie" @@ -24,28 +24,28 @@ var form = ` ` func index(ctx *web.Context) string { - cookie, _ := ctx.Request.Cookie(cookieName) - var top string - if cookie == nil { - top = fmt.Sprintf(notice, "The cookie has not been set") - } else { - var val = html.EscapeString(cookie.Value) - top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") - } - return top + form + cookie, _ := ctx.Request.Cookie(cookieName) + var top string + if cookie == nil { + top = fmt.Sprintf(notice, "The cookie has not been set") + } else { + var val = html.EscapeString(cookie.Value) + top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") + } + return top + form } func update(ctx *web.Context) { - if ctx.Params["submit"] == "Delete" { - ctx.SetCookie(web.NewCookie(cookieName, "", -1)) - } else { - ctx.SetCookie(web.NewCookie(cookieName, ctx.Params["cookie"], 0)) - } - ctx.Redirect(301, "/") + if ctx.Params["submit"] == "Delete" { + ctx.SetCookie(web.NewCookie(cookieName, "", -1)) + } else { + ctx.SetCookie(web.NewCookie(cookieName, ctx.Params["cookie"], 0)) + } + ctx.Redirect(301, "/") } func main() { - web.Get("/", index) - web.Post("/update", update) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/update", update) + web.Run("0.0.0.0:9999") } diff --git a/examples/hello.go b/examples/hello.go index d502391a..26ecb0b0 100644 --- a/examples/hello.go +++ b/examples/hello.go @@ -1,12 +1,12 @@ package main import ( - "github.com/hoisie/web" + "github.com/hoisie/web" ) -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { - web.Get("/(.*)", hello) - web.Run("0.0.0.0:9999") + web.Get("/(.*)", hello) + web.Run("0.0.0.0:9999") } diff --git a/examples/logger.go b/examples/logger.go index f684261a..462ad610 100644 --- a/examples/logger.go +++ b/examples/logger.go @@ -1,21 +1,21 @@ package main import ( - "github.com/hoisie/web" - "log" - "os" + "github.com/hoisie/web" + "log" + "os" ) -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { - f, err := os.Create("server.log") - if err != nil { - println(err.Error()) - return - } - logger := log.New(f, "", log.Ldate|log.Ltime) - web.Get("/(.*)", hello) - web.SetLogger(logger) - web.Run("0.0.0.0:9999") + f, err := os.Create("server.log") + if err != nil { + println(err.Error()) + return + } + logger := log.New(f, "", log.Ldate|log.Ltime) + web.Get("/(.*)", hello) + web.SetLogger(logger) + web.Run("0.0.0.0:9999") } diff --git a/examples/multipart.go b/examples/multipart.go index d89e514e..68df57e8 100644 --- a/examples/multipart.go +++ b/examples/multipart.go @@ -1,17 +1,17 @@ package main import ( - "bytes" - "crypto/md5" - "fmt" - "github.com/hoisie/web" - "io" + "bytes" + "crypto/md5" + "fmt" + "github.com/hoisie/web" + "io" ) func Md5(r io.Reader) string { - hash := md5.New() - io.Copy(hash, r) - return fmt.Sprintf("%x", hash.Sum(nil)) + hash := md5.New() + io.Copy(hash, r) + return fmt.Sprintf("%x", hash.Sum(nil)) } var page = ` @@ -38,25 +38,25 @@ var page = ` func index() string { return page } func multipart(ctx *web.Context) string { - ctx.Request.ParseMultipartForm(10 * 1024 * 1024) - form := ctx.Request.MultipartForm - var output bytes.Buffer - output.WriteString("

input1: " + form.Value["input1"][0] + "

") - output.WriteString("

input2: " + form.Value["input2"][0] + "

") - - fileHeader := form.File["file"][0] - filename := fileHeader.Filename - file, err := fileHeader.Open() - if err != nil { - return err.Error() - } - - output.WriteString("

file: " + filename + " " + Md5(file) + "

") - return output.String() + ctx.Request.ParseMultipartForm(10 * 1024 * 1024) + form := ctx.Request.MultipartForm + var output bytes.Buffer + output.WriteString("

input1: " + form.Value["input1"][0] + "

") + output.WriteString("

input2: " + form.Value["input2"][0] + "

") + + fileHeader := form.File["file"][0] + filename := fileHeader.Filename + file, err := fileHeader.Open() + if err != nil { + return err.Error() + } + + output.WriteString("

file: " + filename + " " + Md5(file) + "

") + return output.String() } func main() { - web.Get("/", index) - web.Post("/multipart", multipart) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/multipart", multipart) + web.Run("0.0.0.0:9999") } diff --git a/examples/multiserver.go b/examples/multiserver.go index 354b563e..2523cb90 100644 --- a/examples/multiserver.go +++ b/examples/multiserver.go @@ -1,20 +1,20 @@ package main import ( - "github.com/hoisie/web" + "github.com/hoisie/web" ) -func hello1(val string) string { return "hello1 " + val } +func hello1(val string) string { return "hello1 " + val + "\n" } -func hello2(val string) string { return "hello2 " + val } +func hello2(val string) string { return "hello2 " + val + "\n" } func main() { - var server1 web.Server - var server2 web.Server + var server1 web.Server + var server2 web.Server - server1.Get("/(.*)", hello1) - go server1.Run("0.0.0.0:9999") - server2.Get("/(.*)", hello2) - go server2.Run("0.0.0.0:8999") - <-make(chan int) + server1.Get("/(.*)", hello1) + go server1.Run("0.0.0.0:9999") + server2.Get("/(.*)", hello2) + go server2.Run("0.0.0.0:8999") + <-make(chan int) } diff --git a/examples/params.go b/examples/params.go index dc396a1b..ac70654d 100644 --- a/examples/params.go +++ b/examples/params.go @@ -1,8 +1,8 @@ package main import ( - "fmt" - "github.com/hoisie/web" + "fmt" + "github.com/hoisie/web" ) var page = ` @@ -33,11 +33,11 @@ var page = ` func index() string { return page } func process(ctx *web.Context) string { - return fmt.Sprintf("%v\n", ctx.Params) + return fmt.Sprintf("%v\n", ctx.Params) } func main() { - web.Get("/", index) - web.Post("/process", process) - web.Run("0.0.0.0:9999") + web.Get("/", index) + web.Post("/process", process) + web.Run("0.0.0.0:9999") } diff --git a/examples/secure_cookie.go b/examples/secure_cookie.go new file mode 100644 index 00000000..4da59bb7 --- /dev/null +++ b/examples/secure_cookie.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + "github.com/hoisie/web" + "html" +) + +var cookieName = "cookie" + +var notice = ` +
%v
+` +var form = ` + +
+ + +
+ + + +
+` + +func index(ctx *web.Context) string { + cookie, ok := ctx.GetSecureCookie(cookieName) + var top string + if !ok { + top = fmt.Sprintf(notice, "The cookie has not been set") + } else { + var val = html.EscapeString(cookie) + top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.") + } + return top + form +} + +func update(ctx *web.Context) { + if ctx.Params["submit"] == "Delete" { + ctx.SetCookie(web.NewCookie(cookieName, "", -1)) + } else { + ctx.SetSecureCookie(cookieName, ctx.Params["cookie"], 0) + } + ctx.Redirect(301, "/") +} + +func main() { + web.Config.CookieSecret = "a long secure cookie secret" + web.Get("/", index) + web.Post("/update", update) + web.Run("0.0.0.0:9999") +} diff --git a/examples/streaming.go b/examples/streaming.go index e17ab372..2a5b948f 100644 --- a/examples/streaming.go +++ b/examples/streaming.go @@ -1,24 +1,24 @@ package main import ( - "github.com/hoisie/web" - "net/http" - "strconv" - "time" + "github.com/hoisie/web" + "net/http" + "strconv" + "time" ) func hello(ctx *web.Context, num string) { - flusher, _ := ctx.ResponseWriter.(http.Flusher) - flusher.Flush() - n, _ := strconv.ParseInt(num, 10, 64) - for i := int64(0); i < n; i++ { - ctx.WriteString("
hello world
") - flusher.Flush() - time.Sleep(1e9) - } + flusher, _ := ctx.ResponseWriter.(http.Flusher) + flusher.Flush() + n, _ := strconv.ParseInt(num, 10, 64) + for i := int64(0); i < n; i++ { + ctx.WriteString("
hello world
") + flusher.Flush() + time.Sleep(1e9) + } } func main() { - web.Get("/([0-9]+)", hello) - web.Run("0.0.0.0:9999") + web.Get("/([0-9]+)", hello) + web.Run("0.0.0.0:9999") } diff --git a/examples/tls.go b/examples/tls.go index 7517fe61..951418fd 100644 --- a/examples/tls.go +++ b/examples/tls.go @@ -1,8 +1,8 @@ package main import ( - "crypto/tls" - "github.com/hoisie/web" + "crypto/tls" + "github.com/hoisie/web" ) // an arbitrary self-signed certificate, generated with @@ -47,22 +47,22 @@ gWrxykqyLToIiAuL+pvC3Jv8IOPIiVFsY032rOqcwSGdVUyhTsG28+7KnR6744tM -----END CERTIFICATE----- ` -func hello(val string) string { return "hello " + val } +func hello(val string) string { return "hello " + val + "\n" } func main() { - config := tls.Config{ - Time: nil, - } + config := tls.Config{ + Time: nil, + } - config.Certificates = make([]tls.Certificate, 1) - var err error - config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey)) - if err != nil { - println(err.Error()) - return - } + config.Certificates = make([]tls.Certificate, 1) + var err error + config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey)) + if err != nil { + println(err.Error()) + return + } - // you must access the server with an HTTP address, i.e https://localhost:9999/world - web.Get("/(.*)", hello) - web.RunTLS("0.0.0.0:9999", &config) + // you must access the server with an HTTPS address, i.e https://localhost:9999/world + web.Get("/(.*)", hello) + web.RunTLS("0.0.0.0:9999", &config) } diff --git a/fcgi.go b/fcgi.go index fa380dc9..017b8024 100644 --- a/fcgi.go +++ b/fcgi.go @@ -1,27 +1,27 @@ package web import ( - "net" - "net/http/fcgi" + "net" + "net/http/fcgi" ) func (s *Server) listenAndServeFcgi(addr string) error { - var l net.Listener - var err error + var l net.Listener + var err error - //if the path begins with a "/", assume it's a unix address - if addr[0] == '/' { - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } + //if the path begins with a "/", assume it's a unix address + if addr[0] == '/' { + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } - //save the listener so it can be closed - s.l = l + //save the listener so it can be closed + s.l = l - if err != nil { - s.Logger.Println("FCGI listen error", err.Error()) - return err - } - return fcgi.Serve(s.l, s) + if err != nil { + s.Logger.Println("FCGI listen error", err.Error()) + return err + } + return fcgi.Serve(s.l, s) } diff --git a/helpers.go b/helpers.go index 93c8d5ad..a87e93f4 100644 --- a/helpers.go +++ b/helpers.go @@ -1,59 +1,59 @@ package web import ( - "bytes" - "encoding/base64" - "errors" - "net/http" - "net/url" - "os" - "regexp" - "strings" - "time" + "bytes" + "encoding/base64" + "errors" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "time" ) // internal utility methods func webTime(t time.Time) string { - ftime := t.Format(time.RFC1123) - if strings.HasSuffix(ftime, "UTC") { - ftime = ftime[0:len(ftime)-3] + "GMT" - } - return ftime + ftime := t.Format(time.RFC1123) + if strings.HasSuffix(ftime, "UTC") { + ftime = ftime[0:len(ftime)-3] + "GMT" + } + return ftime } func dirExists(dir string) bool { - d, e := os.Stat(dir) - switch { - case e != nil: - return false - case !d.IsDir(): - return false - } + d, e := os.Stat(dir) + switch { + case e != nil: + return false + case !d.IsDir(): + return false + } - return true + return true } func fileExists(dir string) bool { - info, err := os.Stat(dir) - if err != nil { - return false - } + info, err := os.Stat(dir) + if err != nil { + return false + } - return !info.IsDir() + return !info.IsDir() } // Urlencode is a helper method that converts a map into URL-encoded form data. // It is a useful when constructing HTTP POST requests. func Urlencode(data map[string]string) string { - var buf bytes.Buffer - for k, v := range data { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(v)) - buf.WriteByte('&') - } - s := buf.String() - return s[0 : len(s)-1] + var buf bytes.Buffer + for k, v := range data { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(v)) + buf.WriteByte('&') + } + s := buf.String() + return s[0 : len(s)-1] } var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`) @@ -62,50 +62,53 @@ var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`) // It's used to return clean, URL-friendly strings that can be // used in routing. func Slug(s string, sep string) string { - if s == "" { - return "" - } - slug := slugRegex.ReplaceAllString(s, sep) - if slug == "" { - return "" - } - quoted := regexp.QuoteMeta(sep) - sepRegex := regexp.MustCompile("(" + quoted + "){2,}") - slug = sepRegex.ReplaceAllString(slug, sep) - sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$") - slug = sepEnds.ReplaceAllString(slug, "") - return strings.ToLower(slug) + if s == "" { + return "" + } + slug := slugRegex.ReplaceAllString(s, sep) + if slug == "" { + return "" + } + quoted := regexp.QuoteMeta(sep) + sepRegex := regexp.MustCompile("(" + quoted + "){2,}") + slug = sepRegex.ReplaceAllString(slug, sep) + sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$") + slug = sepEnds.ReplaceAllString(slug, "") + return strings.ToLower(slug) } // NewCookie is a helper method that returns a new http.Cookie object. // Duration is specified in seconds. If the duration is zero, the cookie is permanent. // This can be used in conjunction with ctx.SetCookie. func NewCookie(name string, value string, age int64) *http.Cookie { - var utctime time.Time - if age == 0 { - // 2^31 - 1 seconds (roughly 2038) - utctime = time.Unix(2147483647, 0) - } else { - utctime = time.Unix(time.Now().Unix()+age, 0) - } - return &http.Cookie{Name: name, Value: value, Expires: utctime} + var utctime time.Time + if age == 0 { + // 2^31 - 1 seconds (roughly 2038) + utctime = time.Unix(2147483647, 0) + } else { + utctime = time.Unix(time.Now().Unix()+age, 0) + } + return &http.Cookie{Name: name, Value: value, Expires: utctime} } -// GetBasicAuth is a helper method of *Context that returns the decoded -// user and password from the *Context's authorization header +// GetBasicAuth returns the decoded user and password from the context's +// 'Authorization' header. func (ctx *Context) GetBasicAuth() (string, string, error) { - authHeader := ctx.Request.Header["Authorization"][0] - authString := strings.Split(string(authHeader), " ") - if authString[0] != "Basic" { - return "", "", errors.New("Not Basic Authentication") - } - decodedAuth, err := base64.StdEncoding.DecodeString(authString[1]) - if err != nil { - return "", "", err - } - authSlice := strings.Split(string(decodedAuth), ":") - if len(authSlice) != 2 { - return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1]) - } - return authSlice[0], authSlice[1], nil + if len(ctx.Request.Header["Authorization"]) == 0 { + return "", "", errors.New("No Authorization header provided") + } + authHeader := ctx.Request.Header["Authorization"][0] + authString := strings.Split(string(authHeader), " ") + if authString[0] != "Basic" { + return "", "", errors.New("Not Basic Authentication") + } + decodedAuth, err := base64.StdEncoding.DecodeString(authString[1]) + if err != nil { + return "", "", err + } + authSlice := strings.Split(string(decodedAuth), ":") + if len(authSlice) != 2 { + return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1]) + } + return authSlice[0], authSlice[1], nil } diff --git a/scgi.go b/scgi.go index 45d3aec5..eea2b4fe 100644 --- a/scgi.go +++ b/scgi.go @@ -1,180 +1,183 @@ package web import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/http/cgi" - "strconv" - "strings" + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cgi" + "strconv" + "strings" ) type scgiBody struct { - reader io.Reader - conn io.ReadWriteCloser - closed bool + reader io.Reader + conn io.ReadWriteCloser + closed bool } func (b *scgiBody) Read(p []byte) (n int, err error) { - if b.closed { - return 0, errors.New("SCGI read after close") - } - return b.reader.Read(p) + if b.closed { + return 0, errors.New("SCGI read after close") + } + return b.reader.Read(p) } func (b *scgiBody) Close() error { - b.closed = true - return b.conn.Close() + b.closed = true + return b.conn.Close() } type scgiConn struct { - fd io.ReadWriteCloser - req *http.Request - headers http.Header - wroteHeaders bool + fd io.ReadWriteCloser + req *http.Request + headers http.Header + wroteHeaders bool } func (conn *scgiConn) WriteHeader(status int) { - if !conn.wroteHeaders { - conn.wroteHeaders = true + if !conn.wroteHeaders { + conn.wroteHeaders = true - var buf bytes.Buffer - text := statusText[status] + var buf bytes.Buffer + text := http.StatusText(status) - fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text) + fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text) - for k, v := range conn.headers { - for _, i := range v { - buf.WriteString(k + ": " + i + "\r\n") - } - } + for k, v := range conn.headers { + for _, i := range v { + buf.WriteString(k + ": " + i + "\r\n") + } + } - buf.WriteString("\r\n") - conn.fd.Write(buf.Bytes()) - } + buf.WriteString("\r\n") + conn.fd.Write(buf.Bytes()) + } } func (conn *scgiConn) Header() http.Header { - return conn.headers + return conn.headers } func (conn *scgiConn) Write(data []byte) (n int, err error) { - if !conn.wroteHeaders { - conn.WriteHeader(200) - } + if !conn.wroteHeaders { + conn.WriteHeader(200) + } - if conn.req.Method == "HEAD" { - return 0, errors.New("Body Not Allowed") - } + if conn.req.Method == "HEAD" { + return 0, errors.New("Body Not Allowed") + } - return conn.fd.Write(data) + return conn.fd.Write(data) } func (conn *scgiConn) Close() { conn.fd.Close() } func (conn *scgiConn) finishRequest() error { - var buf bytes.Buffer - if !conn.wroteHeaders { - conn.wroteHeaders = true - - for k, v := range conn.headers { - for _, i := range v { - buf.WriteString(k + ": " + i + "\r\n") - } - } - - buf.WriteString("\r\n") - conn.fd.Write(buf.Bytes()) - } - return nil + var buf bytes.Buffer + if !conn.wroteHeaders { + conn.wroteHeaders = true + + for k, v := range conn.headers { + for _, i := range v { + buf.WriteString(k + ": " + i + "\r\n") + } + } + + buf.WriteString("\r\n") + conn.fd.Write(buf.Bytes()) + } + return nil } func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) { - reader := bufio.NewReader(fd) - line, err := reader.ReadString(':') - if err != nil { - s.Logger.Println("Error during SCGI read: ", err.Error()) - } - length, _ := strconv.Atoi(line[0 : len(line)-1]) - if length > 16384 { - s.Logger.Println("Error: max header size is 16k") - } - headerData := make([]byte, length) - _, err = reader.Read(headerData) - if err != nil { - return nil, err - } - - b, err := reader.ReadByte() - if err != nil { - return nil, err - } - // discard the trailing comma - if b != ',' { - return nil, errors.New("SCGI protocol error: missing comma") - } - headerList := bytes.Split(headerData, []byte{0}) - headers := map[string]string{} - for i := 0; i < len(headerList)-1; i += 2 { - headers[string(headerList[i])] = string(headerList[i+1]) - } - httpReq, err := cgi.RequestFromMap(headers) - if err != nil { - return nil, err - } - if httpReq.ContentLength > 0 { - httpReq.Body = &scgiBody{ - reader: io.LimitReader(reader, httpReq.ContentLength), - conn: fd, - } - } else { - httpReq.Body = &scgiBody{reader: reader, conn: fd} - } - return httpReq, nil + reader := bufio.NewReader(fd) + line, err := reader.ReadString(':') + if err != nil { + return nil, err + } + length, err := strconv.Atoi(line[0 : len(line)-1]) + if err != nil { + return nil, err + } + if length > 16384 { + return nil, errors.New("Max header size is 16k") + } + headerData := make([]byte, length) + _, err = reader.Read(headerData) + if err != nil { + return nil, err + } + b, err := reader.ReadByte() + if err != nil { + return nil, err + } + // discard the trailing comma + if b != ',' { + return nil, errors.New("SCGI protocol error: missing comma") + } + headerList := bytes.Split(headerData, []byte{0}) + headers := map[string]string{} + for i := 0; i < len(headerList)-1; i += 2 { + headers[string(headerList[i])] = string(headerList[i+1]) + } + httpReq, err := cgi.RequestFromMap(headers) + if err != nil { + return nil, err + } + if httpReq.ContentLength > 0 { + httpReq.Body = &scgiBody{ + reader: io.LimitReader(reader, httpReq.ContentLength), + conn: fd, + } + } else { + httpReq.Body = &scgiBody{reader: reader, conn: fd} + } + return httpReq, nil } func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) { - req, err := s.readScgiRequest(fd) - if err != nil { - s.Logger.Println("SCGI error: %q", err.Error()) - } - sc := scgiConn{fd, req, make(map[string][]string), false} - s.routeHandler(req, &sc) - sc.finishRequest() - fd.Close() + defer fd.Close() + req, err := s.readScgiRequest(fd) + if err != nil { + s.Logger.Println("Error reading SCGI request: %q", err.Error()) + return + } + sc := scgiConn{fd, req, make(map[string][]string), false} + s.routeHandler(req, &sc) + sc.finishRequest() } func (s *Server) listenAndServeScgi(addr string) error { - var l net.Listener - var err error - - //if the path begins with a "/", assume it's a unix address - if strings.HasPrefix(addr, "/") { - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } - - //save the listener so it can be closed - s.l = l - - if err != nil { - s.Logger.Println("SCGI listen error", err.Error()) - return err - } - - for { - fd, err := l.Accept() - if err != nil { - s.Logger.Println("SCGI accept error", err.Error()) - return err - } - go s.handleScgiRequest(fd) - } - return nil + var l net.Listener + var err error + + //if the path begins with a "/", assume it's a unix address + if strings.HasPrefix(addr, "/") { + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + + //save the listener so it can be closed + s.l = l + + if err != nil { + s.Logger.Println("SCGI listen error", err.Error()) + return err + } + + for { + fd, err := l.Accept() + if err != nil { + s.Logger.Println("SCGI accept error", err.Error()) + return err + } + go s.handleScgiRequest(fd) + } + return nil } diff --git a/secure_cookie.go b/secure_cookie.go new file mode 100644 index 00000000..dfd5d196 --- /dev/null +++ b/secure_cookie.go @@ -0,0 +1,112 @@ +package web + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha512" + "encoding/base64" + "errors" + "golang.org/x/crypto/pbkdf2" + "io" + "strings" +) + +const ( + pbkdf2Iterations = 64000 + keySize = 32 +) + +var ( + ErrMissingCookieSecret = errors.New("Secret Key for secure cookies has not been set. Assign one to web.Config.CookieSecret.") + ErrInvalidKey = errors.New("The keys for secure cookies have not been initialized. Ensure that a Run* method is being called") +) + +func (ctx *Context) SetSecureCookie(name string, val string, age int64) error { + server := ctx.Server + if len(server.Config.CookieSecret) == 0 { + return ErrMissingCookieSecret + } + if len(server.encKey) == 0 || len(server.signKey) == 0 { + return ErrInvalidKey + } + ciphertext, err := encrypt([]byte(val), server.encKey) + if err != nil { + return err + } + sig := sign(ciphertext, server.signKey) + data := base64.StdEncoding.EncodeToString(ciphertext) + "|" + base64.StdEncoding.EncodeToString(sig) + ctx.SetCookie(NewCookie(name, data, age)) + return nil +} + +func (ctx *Context) GetSecureCookie(name string) (string, bool) { + for _, cookie := range ctx.Request.Cookies() { + if cookie.Name != name { + continue + } + parts := strings.SplitN(cookie.Value, "|", 2) + if len(parts) != 2 { + return "", false + } + ciphertext, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return "", false + } + sig, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return "", false + } + expectedSig := sign([]byte(ciphertext), ctx.Server.signKey) + if !bytes.Equal(expectedSig, sig) { + return "", false + } + plaintext, err := decrypt(ciphertext, ctx.Server.encKey) + if err != nil { + return "", false + } + return string(plaintext), true + } + return "", false +} + +func genKey(password string, salt string) []byte { + return pbkdf2.Key([]byte(password), []byte(salt), pbkdf2Iterations, keySize, sha512.New) +} + +func encrypt(plaintext []byte, key []byte) ([]byte, error) { + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(plaintext)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCTR(aesCipher, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext) + return ciphertext, nil +} + +func decrypt(ciphertext []byte, key []byte) ([]byte, error) { + if len(ciphertext) <= aes.BlockSize { + return nil, errors.New("Invalid cipher text") + } + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + plaintext := make([]byte, len(ciphertext)-aes.BlockSize) + stream := cipher.NewCTR(aesCipher, ciphertext[:aes.BlockSize]) + stream.XORKeyStream(plaintext, ciphertext[aes.BlockSize:]) + return plaintext, nil +} + +func sign(data []byte, key []byte) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(data) + return mac.Sum(nil) +} diff --git a/server.go b/server.go index 184f01c2..0e97a4dd 100644 --- a/server.go +++ b/server.go @@ -1,240 +1,252 @@ package web import ( - "bytes" - "code.google.com/p/go.net/websocket" - "crypto/tls" - "fmt" - "log" - "net" - "net/http" - "net/http/pprof" - "os" - "path" - "reflect" - "regexp" - "runtime" - "strconv" - "strings" - "time" + "bytes" + "crypto/tls" + "fmt" + "golang.org/x/net/websocket" + "log" + "net" + "net/http" + "net/http/pprof" + "os" + "path" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "time" ) // ServerConfig is configuration for server objects. type ServerConfig struct { - StaticDir string - Addr string - Port int - CookieSecret string - RecoverPanic bool - Profiler bool + StaticDir string + Addr string + Port int + CookieSecret string + RecoverPanic bool + Profiler bool + ColorOutput bool } // Server represents a web.go server. type Server struct { - Config *ServerConfig - routes []route - Logger *log.Logger - Env map[string]interface{} - //save the listener so it can be closed - l net.Listener + Config *ServerConfig + routes []route + Logger *log.Logger + Env map[string]interface{} + //save the listener so it can be closed + l net.Listener + encKey []byte + signKey []byte } func NewServer() *Server { - return &Server{ - Config: Config, - Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime), - Env: map[string]interface{}{}, - } + return &Server{ + Config: Config, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime), + Env: map[string]interface{}{}, + } } func (s *Server) initServer() { - if s.Config == nil { - s.Config = &ServerConfig{} - } - - if s.Logger == nil { - s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) - } + if s.Config == nil { + s.Config = &ServerConfig{} + } + + if s.Logger == nil { + s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) + } + + if len(s.Config.CookieSecret) > 0 { + s.Logger.Println("Generating cookie encryption keys") + s.encKey = genKey(s.Config.CookieSecret, "encryption key salt") + s.signKey = genKey(s.Config.CookieSecret, "signature key salt") + } } type route struct { - r string - cr *regexp.Regexp - method string - handler reflect.Value - httpHandler http.Handler + r string + cr *regexp.Regexp + method string + handler reflect.Value + httpHandler http.Handler } func (s *Server) addRoute(r string, method string, handler interface{}) { - cr, err := regexp.Compile(r) - if err != nil { - s.Logger.Printf("Error in route regex %q\n", r) - return - } - - switch handler.(type) { - case http.Handler: - s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)}) - case reflect.Value: - fv := handler.(reflect.Value) - s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) - default: - fv := reflect.ValueOf(handler) - s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) - } + cr, err := regexp.Compile(r) + if err != nil { + s.Logger.Printf("Error in route regex %q\n", r) + return + } + + switch handler.(type) { + case http.Handler: + s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)}) + case reflect.Value: + fv := handler.(reflect.Value) + s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) + default: + fv := reflect.ValueOf(handler) + s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv}) + } } // ServeHTTP is the interface method for Go's http server package func (s *Server) ServeHTTP(c http.ResponseWriter, req *http.Request) { - s.Process(c, req) + s.Process(c, req) } // Process invokes the routing system for server s func (s *Server) Process(c http.ResponseWriter, req *http.Request) { - route := s.routeHandler(req, c) - if route != nil { - route.httpHandler.ServeHTTP(c, req) - } + route := s.routeHandler(req, c) + if route != nil { + route.httpHandler.ServeHTTP(c, req) + } } // Get adds a handler for the 'GET' http method for server s. func (s *Server) Get(route string, handler interface{}) { - s.addRoute(route, "GET", handler) + s.addRoute(route, "GET", handler) } // Post adds a handler for the 'POST' http method for server s. func (s *Server) Post(route string, handler interface{}) { - s.addRoute(route, "POST", handler) + s.addRoute(route, "POST", handler) } // Put adds a handler for the 'PUT' http method for server s. func (s *Server) Put(route string, handler interface{}) { - s.addRoute(route, "PUT", handler) + s.addRoute(route, "PUT", handler) } // Delete adds a handler for the 'DELETE' http method for server s. func (s *Server) Delete(route string, handler interface{}) { - s.addRoute(route, "DELETE", handler) + s.addRoute(route, "DELETE", handler) } // Match adds a handler for an arbitrary http method for server s. func (s *Server) Match(method string, route string, handler interface{}) { - s.addRoute(route, method, handler) + s.addRoute(route, method, handler) } -//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. -func (s *Server) Handler(route string, method string, httpHandler http.Handler) { - s.addRoute(route, method, httpHandler) +// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI. +func (s *Server) Handle(route string, method string, httpHandler http.Handler) { + s.addRoute(route, method, httpHandler) } //Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func (s *Server) Websocket(route string, httpHandler websocket.Handler) { - s.addRoute(route, "GET", httpHandler) + s.addRoute(route, "GET", httpHandler) } // Run starts the web application and serves HTTP requests for s func (s *Server) Run(addr string) { - s.initServer() - - mux := http.NewServeMux() - if s.Config.Profiler { - mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) - mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) - mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) - mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) - } - mux.Handle("/", s) - - s.Logger.Printf("web.go serving %s\n", addr) - - l, err := net.Listen("tcp", addr) - if err != nil { - log.Fatal("ListenAndServe:", err) - } - s.l = l - err = http.Serve(s.l, mux) - s.l.Close() + s.initServer() + + mux := http.NewServeMux() + if s.Config.Profiler { + mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) + mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + } + mux.Handle("/", s) + + l, err := net.Listen("tcp", addr) + if err != nil { + log.Fatal("ListenAndServe:", err) + } + + s.Logger.Printf("web.go serving %s\n", l.Addr()) + + s.l = l + err = http.Serve(s.l, mux) + s.l.Close() } // RunFcgi starts the web application and serves FastCGI requests for s. func (s *Server) RunFcgi(addr string) { - s.initServer() - s.Logger.Printf("web.go serving fcgi %s\n", addr) - s.listenAndServeFcgi(addr) + s.initServer() + s.Logger.Printf("web.go serving fcgi %s\n", addr) + s.listenAndServeFcgi(addr) } // RunScgi starts the web application and serves SCGI requests for s. func (s *Server) RunScgi(addr string) { - s.initServer() - s.Logger.Printf("web.go serving scgi %s\n", addr) - s.listenAndServeScgi(addr) + s.initServer() + s.Logger.Printf("web.go serving scgi %s\n", addr) + s.listenAndServeScgi(addr) } // RunTLS starts the web application and serves HTTPS requests for s. func (s *Server) RunTLS(addr string, config *tls.Config) error { - s.initServer() - mux := http.NewServeMux() - mux.Handle("/", s) - l, err := tls.Listen("tcp", addr, config) - if err != nil { - log.Fatal("Listen:", err) - return err - } - - s.l = l - return http.Serve(s.l, mux) + s.initServer() + mux := http.NewServeMux() + mux.Handle("/", s) + + l, err := tls.Listen("tcp", addr, config) + if err != nil { + log.Fatal("Listen:", err) + return err + } + s.Logger.Printf("web.go serving %s\n", l.Addr()) + + s.l = l + return http.Serve(s.l, mux) } // Close stops server s. func (s *Server) Close() { - if s.l != nil { - s.l.Close() - } + if s.l != nil { + s.l.Close() + } } // safelyCall invokes `function` in recover block func (s *Server) safelyCall(function reflect.Value, args []reflect.Value) (resp []reflect.Value, e interface{}) { - defer func() { - if err := recover(); err != nil { - if !s.Config.RecoverPanic { - // go back to panic - panic(err) - } else { - e = err - resp = nil - s.Logger.Println("Handler crashed with error", err) - for i := 1; ; i += 1 { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - s.Logger.Println(file, line) - } - } - } - }() - return function.Call(args), nil + defer func() { + if err := recover(); err != nil { + if !s.Config.RecoverPanic { + // go back to panic + panic(err) + } else { + e = err + resp = nil + s.Logger.Println("Handler crashed with error", err) + for i := 1; ; i += 1 { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + s.Logger.Println(file, line) + } + } + } + }() + return function.Call(args), nil } // requiresContext determines whether 'handlerType' contains // an argument to 'web.Ctx' as its first argument func requiresContext(handlerType reflect.Type) bool { - //if the method doesn't take arguments, no - if handlerType.NumIn() == 0 { - return false - } - - //if the first argument is not a pointer, no - a0 := handlerType.In(0) - if a0.Kind() != reflect.Ptr { - return false - } - //if the first argument is a context, yes - if a0.Elem() == contextType { - return true - } - - return false + //if the method doesn't take arguments, no + if handlerType.NumIn() == 0 { + return false + } + + //if the first argument is not a pointer, no + a0 := handlerType.In(0) + if a0.Kind() != reflect.Ptr { + return false + } + //if the first argument is a context, yes + if a0.Elem() == contextType { + return true + } + + return false } // tryServingFile attempts to serve a static file, and returns @@ -244,51 +256,66 @@ func requiresContext(handlerType reflect.Type) bool { // 2) The 'static' directory in the parent directory of the executable. // 3) The 'static' directory in the current working directory func (s *Server) tryServingFile(name string, req *http.Request, w http.ResponseWriter) bool { - //try to serve a static file - if s.Config.StaticDir != "" { - staticFile := path.Join(s.Config.StaticDir, name) - if fileExists(staticFile) { - http.ServeFile(w, req, staticFile) - return true - } - } else { - for _, staticDir := range defaultStaticDirs { - staticFile := path.Join(staticDir, name) - if fileExists(staticFile) { - http.ServeFile(w, req, staticFile) - return true - } - } - } - return false + //try to serve a static file + if s.Config.StaticDir != "" { + staticFile := path.Join(s.Config.StaticDir, name) + if fileExists(staticFile) { + http.ServeFile(w, req, staticFile) + return true + } + } else { + for _, staticDir := range defaultStaticDirs { + staticFile := path.Join(staticDir, name) + if fileExists(staticFile) { + http.ServeFile(w, req, staticFile) + return true + } + } + } + return false } func (s *Server) logRequest(ctx Context, sTime time.Time) { - //log the request - var logEntry bytes.Buffer - req := ctx.Request - requestPath := req.URL.Path - - duration := time.Now().Sub(sTime) - var client string - - // We suppose RemoteAddr is of the form Ip:Port as specified in the Request - // documentation at http://golang.org/pkg/net/http/#Request - pos := strings.LastIndex(req.RemoteAddr, ":") - if pos > 0 { - client = req.RemoteAddr[0:pos] - } else { - client = req.RemoteAddr - } - - fmt.Fprintf(&logEntry, "%s - \033[32;1m %s %s\033[0m - %v", client, req.Method, requestPath, duration) + //log the request + req := ctx.Request + requestPath := req.URL.Path + + duration := time.Now().Sub(sTime) + var client string + + // We suppose RemoteAddr is of the form Ip:Port as specified in the Request + // documentation at http://golang.org/pkg/net/http/#Request + pos := strings.LastIndex(req.RemoteAddr, ":") + if pos > 0 { + client = req.RemoteAddr[0:pos] + } else { + client = req.RemoteAddr + } + + var logEntry bytes.Buffer + logEntry.WriteString(client) + logEntry.WriteString(" - " + s.ttyGreen(req.Method+" "+requestPath)) + logEntry.WriteString(" - " + duration.String()) + if len(ctx.Params) > 0 { + logEntry.WriteString(" - " + s.ttyWhite(fmt.Sprintf("Params: %v\n", ctx.Params))) + } + ctx.Server.Logger.Print(logEntry.String()) +} - if len(ctx.Params) > 0 { - fmt.Fprintf(&logEntry, " - \033[37;1mParams: %v\033[0m\n", ctx.Params) - } +func (s *Server) ttyGreen(msg string) string { + return s.ttyColor(msg, ttyCodes.green) +} - ctx.Server.Logger.Print(logEntry.String()) +func (s *Server) ttyWhite(msg string) string { + return s.ttyColor(msg, ttyCodes.white) +} +func (s *Server) ttyColor(msg string, colorCode string) string { + if s.Config.ColorOutput { + return colorCode + msg + ttyCodes.reset + } else { + return msg + } } // the main route handler in web.go @@ -298,105 +325,105 @@ func (s *Server) logRequest(ctx Context, sTime time.Time) { // route. The caller is then responsible for calling the httpHandler associated // with the returned route. func (s *Server) routeHandler(req *http.Request, w http.ResponseWriter) (unused *route) { - requestPath := req.URL.Path - ctx := Context{req, map[string]string{}, s, w} - - //set some default headers - ctx.SetHeader("Server", "web.go", true) - tm := time.Now().UTC() - - //ignore errors from ParseForm because it's usually harmless. - req.ParseForm() - if len(req.Form) > 0 { - for k, v := range req.Form { - ctx.Params[k] = v[0] - } - } - - defer s.logRequest(ctx, tm) - - ctx.SetHeader("Date", webTime(tm), true) - - if req.Method == "GET" || req.Method == "HEAD" { - if s.tryServingFile(requestPath, req, w) { - return - } - } - - //Set the default content-type - ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) - - for i := 0; i < len(s.routes); i++ { - route := s.routes[i] - cr := route.cr - //if the methods don't match, skip this handler (except HEAD can be used in place of GET) - if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") { - continue - } - - if !cr.MatchString(requestPath) { - continue - } - match := cr.FindStringSubmatch(requestPath) - - if len(match[0]) != len(requestPath) { - continue - } - - if route.httpHandler != nil { - unused = &route - // We can not handle custom http handlers here, give back to the caller. - return - } - - var args []reflect.Value - handlerType := route.handler.Type() - if requiresContext(handlerType) { - args = append(args, reflect.ValueOf(&ctx)) - } - for _, arg := range match[1:] { - args = append(args, reflect.ValueOf(arg)) - } - - ret, err := s.safelyCall(route.handler, args) - if err != nil { - //there was an error or panic while calling the handler - ctx.Abort(500, "Server Error") - } - if len(ret) == 0 { - return - } - - sval := ret[0] - - var content []byte - - if sval.Kind() == reflect.String { - content = []byte(sval.String()) - } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 { - content = sval.Interface().([]byte) - } - ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true) - _, err = ctx.ResponseWriter.Write(content) - if err != nil { - ctx.Server.Logger.Println("Error during write: ", err) - } - return - } - - // try serving index.html or index.htm - if req.Method == "GET" || req.Method == "HEAD" { - if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) { - return - } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) { - return - } - } - ctx.Abort(404, "Page not found") - return + requestPath := req.URL.Path + ctx := Context{req, map[string]string{}, s, w} + + //set some default headers + ctx.SetHeader("Server", "web.go", true) + tm := time.Now().UTC() + + //ignore errors from ParseForm because it's usually harmless. + req.ParseForm() + if len(req.Form) > 0 { + for k, v := range req.Form { + ctx.Params[k] = v[0] + } + } + + defer s.logRequest(ctx, tm) + + ctx.SetHeader("Date", webTime(tm), true) + + if req.Method == "GET" || req.Method == "HEAD" { + if s.tryServingFile(requestPath, req, w) { + return + } + } + + for i := 0; i < len(s.routes); i++ { + route := s.routes[i] + cr := route.cr + //if the methods don't match, skip this handler (except HEAD can be used in place of GET) + if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") { + continue + } + + if !cr.MatchString(requestPath) { + continue + } + match := cr.FindStringSubmatch(requestPath) + + if len(match[0]) != len(requestPath) { + continue + } + + if route.httpHandler != nil { + unused = &route + // We can not handle custom http handlers here, give back to the caller. + return + } + + // set the default content-type + ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) + + var args []reflect.Value + handlerType := route.handler.Type() + if requiresContext(handlerType) { + args = append(args, reflect.ValueOf(&ctx)) + } + for _, arg := range match[1:] { + args = append(args, reflect.ValueOf(arg)) + } + + ret, err := s.safelyCall(route.handler, args) + if err != nil { + //there was an error or panic while calling the handler + ctx.Abort(500, "Server Error") + } + if len(ret) == 0 { + return + } + + sval := ret[0] + + var content []byte + + if sval.Kind() == reflect.String { + content = []byte(sval.String()) + } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 { + content = sval.Interface().([]byte) + } + ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true) + _, err = ctx.ResponseWriter.Write(content) + if err != nil { + ctx.Server.Logger.Println("Error during write: ", err) + } + return + } + + // try serving index.html or index.htm + if req.Method == "GET" || req.Method == "HEAD" { + if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) { + return + } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) { + return + } + } + ctx.Abort(404, "Page not found") + return } // SetLogger sets the logger for server s func (s *Server) SetLogger(logger *log.Logger) { - s.Logger = logger + s.Logger = logger } diff --git a/status.go b/status.go deleted file mode 100644 index 83053cce..00000000 --- a/status.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package web - -import "net/http" - -var statusText = map[int]string{ - http.StatusContinue: "Continue", - http.StatusSwitchingProtocols: "Switching Protocols", - - http.StatusOK: "OK", - http.StatusCreated: "Created", - http.StatusAccepted: "Accepted", - http.StatusNonAuthoritativeInfo: "Non-Authoritative Information", - http.StatusNoContent: "No Content", - http.StatusResetContent: "Reset Content", - http.StatusPartialContent: "Partial Content", - - http.StatusMultipleChoices: "Multiple Choices", - http.StatusMovedPermanently: "Moved Permanently", - http.StatusFound: "Found", - http.StatusSeeOther: "See Other", - http.StatusNotModified: "Not Modified", - http.StatusUseProxy: "Use Proxy", - http.StatusTemporaryRedirect: "Temporary Redirect", - - http.StatusBadRequest: "Bad Request", - http.StatusUnauthorized: "Unauthorized", - http.StatusPaymentRequired: "Payment Required", - http.StatusForbidden: "Forbidden", - http.StatusNotFound: "Not Found", - http.StatusMethodNotAllowed: "Method Not Allowed", - http.StatusNotAcceptable: "Not Acceptable", - http.StatusProxyAuthRequired: "Proxy Authentication Required", - http.StatusRequestTimeout: "Request Timeout", - http.StatusConflict: "Conflict", - http.StatusGone: "Gone", - http.StatusLengthRequired: "Length Required", - http.StatusPreconditionFailed: "Precondition Failed", - http.StatusRequestEntityTooLarge: "Request Entity Too Large", - http.StatusRequestURITooLong: "Request URI Too Long", - http.StatusUnsupportedMediaType: "Unsupported Media Type", - http.StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", - http.StatusExpectationFailed: "Expectation Failed", - - http.StatusInternalServerError: "Internal Server Error", - http.StatusNotImplemented: "Not Implemented", - http.StatusBadGateway: "Bad Gateway", - http.StatusServiceUnavailable: "Service Unavailable", - http.StatusGatewayTimeout: "Gateway Timeout", - http.StatusHTTPVersionNotSupported: "HTTP Version Not Supported", -} diff --git a/ttycolors.go b/ttycolors.go new file mode 100644 index 00000000..fe63c1af --- /dev/null +++ b/ttycolors.go @@ -0,0 +1,30 @@ +package web + +import ( + "golang.org/x/crypto/ssh/terminal" + "syscall" +) + +var ttyCodes struct { + green string + white string + reset string +} + +func init() { + ttyCodes.green = ttyBold("32") + ttyCodes.white = ttyBold("37") + ttyCodes.reset = ttyEscape("0") +} + +func ttyBold(code string) string { + return ttyEscape("1;" + code) +} + +func ttyEscape(code string) string { + if terminal.IsTerminal(syscall.Stdout) { + return "\x1b[" + code + "m" + } else { + return "" + } +} diff --git a/web.go b/web.go index 41bc8ec4..8d36130d 100644 --- a/web.go +++ b/web.go @@ -3,23 +3,15 @@ package web import ( - "bytes" - "code.google.com/p/go.net/websocket" - "crypto/hmac" - "crypto/sha1" - "crypto/tls" - "encoding/base64" - "fmt" - "io/ioutil" - "log" - "mime" - "net/http" - "os" - "path" - "reflect" - "strconv" - "strings" - "time" + "crypto/tls" + "golang.org/x/net/websocket" + "log" + "mime" + "net/http" + "os" + "path" + "reflect" + "strings" ) // A Context object is created for every incoming HTTP request, and is @@ -27,15 +19,15 @@ import ( // about the request, including the http.Request object, the GET and POST params, // and acts as a Writer for the response. type Context struct { - Request *http.Request - Params map[string]string - Server *Server - http.ResponseWriter + Request *http.Request + Params map[string]string + Server *Server + http.ResponseWriter } // WriteString writes string data into the response object. func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) + ctx.ResponseWriter.Write([]byte(content)) } // Abort is a helper method that sends an HTTP header and an optional @@ -43,36 +35,42 @@ func (ctx *Context) WriteString(content string) { // Once it has been called, any return value from the handler will // not be written to the response. func (ctx *Context) Abort(status int, body string) { - ctx.ResponseWriter.WriteHeader(status) - ctx.ResponseWriter.Write([]byte(body)) + ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true) + ctx.ResponseWriter.WriteHeader(status) + ctx.ResponseWriter.Write([]byte(body)) } // Redirect is a helper method for 3xx redirects. func (ctx *Context) Redirect(status int, url_ string) { - ctx.ResponseWriter.Header().Set("Location", url_) - ctx.ResponseWriter.WriteHeader(status) - ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_)) + ctx.ResponseWriter.Header().Set("Location", url_) + ctx.ResponseWriter.WriteHeader(status) + ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_)) } -// Notmodified writes a 304 HTTP response -func (ctx *Context) NotModified() { - ctx.ResponseWriter.WriteHeader(304) +//BadRequest writes a 400 HTTP response +func (ctx *Context) BadRequest() { + ctx.ResponseWriter.WriteHeader(400) } -// NotFound writes a 404 HTTP response -func (ctx *Context) NotFound(message string) { - ctx.ResponseWriter.WriteHeader(404) - ctx.ResponseWriter.Write([]byte(message)) +// Notmodified writes a 304 HTTP response +func (ctx *Context) NotModified() { + ctx.ResponseWriter.WriteHeader(304) } //Unauthorized writes a 401 HTTP response func (ctx *Context) Unauthorized() { - ctx.ResponseWriter.WriteHeader(401) + ctx.ResponseWriter.WriteHeader(401) } //Forbidden writes a 403 HTTP response func (ctx *Context) Forbidden() { - ctx.ResponseWriter.WriteHeader(403) + ctx.ResponseWriter.WriteHeader(403) +} + +// NotFound writes a 404 HTTP response +func (ctx *Context) NotFound(message string) { + ctx.ResponseWriter.WriteHeader(404) + ctx.ResponseWriter.Write([]byte(message)) } // ContentType sets the Content-Type header for an HTTP response. @@ -81,93 +79,34 @@ func (ctx *Context) Forbidden() { // verbatim. The return value is the content type as it was // set, or an empty string if none was found. func (ctx *Context) ContentType(val string) string { - var ctype string - if strings.ContainsRune(val, '/') { - ctype = val - } else { - if !strings.HasPrefix(val, ".") { - val = "." + val - } - ctype = mime.TypeByExtension(val) - } - if ctype != "" { - ctx.Header().Set("Content-Type", ctype) - } - return ctype + var ctype string + if strings.ContainsRune(val, '/') { + ctype = val + } else { + if !strings.HasPrefix(val, ".") { + val = "." + val + } + ctype = mime.TypeByExtension(val) + } + if ctype != "" { + ctx.Header().Set("Content-Type", ctype) + } + return ctype } // SetHeader sets a response header. If `unique` is true, the current value // of that header will be overwritten . If false, it will be appended. func (ctx *Context) SetHeader(hdr string, val string, unique bool) { - if unique { - ctx.Header().Set(hdr, val) - } else { - ctx.Header().Add(hdr, val) - } + if unique { + ctx.Header().Set(hdr, val) + } else { + ctx.Header().Add(hdr, val) + } } // SetCookie adds a cookie header to the response. func (ctx *Context) SetCookie(cookie *http.Cookie) { - ctx.SetHeader("Set-Cookie", cookie.String(), false) -} - -func getCookieSig(key string, val []byte, timestamp string) string { - hm := hmac.New(sha1.New, []byte(key)) - - hm.Write(val) - hm.Write([]byte(timestamp)) - - hex := fmt.Sprintf("%02x", hm.Sum(nil)) - return hex -} - -func (ctx *Context) SetSecureCookie(name string, val string, age int64) { - //base64 encode the val - if len(ctx.Server.Config.CookieSecret) == 0 { - ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.") - return - } - var buf bytes.Buffer - encoder := base64.NewEncoder(base64.StdEncoding, &buf) - encoder.Write([]byte(val)) - encoder.Close() - vs := buf.String() - vb := buf.Bytes() - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - sig := getCookieSig(ctx.Server.Config.CookieSecret, vb, timestamp) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - ctx.SetCookie(NewCookie(name, cookie, age)) -} - -func (ctx *Context) GetSecureCookie(name string) (string, bool) { - for _, cookie := range ctx.Request.Cookies() { - if cookie.Name != name { - continue - } - - parts := strings.SplitN(cookie.Value, "|", 3) - - val := parts[0] - timestamp := parts[1] - sig := parts[2] - - if getCookieSig(ctx.Server.Config.CookieSecret, []byte(val), timestamp) != sig { - return "", false - } - - ts, _ := strconv.ParseInt(timestamp, 0, 64) - - if time.Now().Unix()-31*86400 > ts { - return "", false - } - - buf := bytes.NewBufferString(val) - encoder := base64.NewDecoder(base64.StdEncoding, buf) - - res, _ := ioutil.ReadAll(encoder) - return string(res), true - } - return "", false + ctx.SetHeader("Set-Cookie", cookie.String(), false) } // small optimization: cache the context type instead of repeteadly calling reflect.Typeof @@ -176,96 +115,97 @@ var contextType reflect.Type var defaultStaticDirs []string func init() { - contextType = reflect.TypeOf(Context{}) - //find the location of the exe file - wd, _ := os.Getwd() - arg0 := path.Clean(os.Args[0]) - var exeFile string - if strings.HasPrefix(arg0, "/") { - exeFile = arg0 - } else { - //TODO for robustness, search each directory in $PATH - exeFile = path.Join(wd, arg0) - } - parent, _ := path.Split(exeFile) - defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static")) - defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static")) - return + contextType = reflect.TypeOf(Context{}) + //find the location of the exe file + wd, _ := os.Getwd() + arg0 := path.Clean(os.Args[0]) + var exeFile string + if strings.HasPrefix(arg0, "/") { + exeFile = arg0 + } else { + //TODO for robustness, search each directory in $PATH + exeFile = path.Join(wd, arg0) + } + parent, _ := path.Split(exeFile) + defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static")) + defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static")) + return } // Process invokes the main server's routing system. func Process(c http.ResponseWriter, req *http.Request) { - mainServer.Process(c, req) + mainServer.Process(c, req) } // Run starts the web application and serves HTTP requests for the main server. func Run(addr string) { - mainServer.Run(addr) + mainServer.Run(addr) } // RunTLS starts the web application and serves HTTPS requests for the main server. func RunTLS(addr string, config *tls.Config) { - mainServer.RunTLS(addr, config) + mainServer.RunTLS(addr, config) } // RunScgi starts the web application and serves SCGI requests for the main server. func RunScgi(addr string) { - mainServer.RunScgi(addr) + mainServer.RunScgi(addr) } // RunFcgi starts the web application and serves FastCGI requests for the main server. func RunFcgi(addr string) { - mainServer.RunFcgi(addr) + mainServer.RunFcgi(addr) } // Close stops the main server. func Close() { - mainServer.Close() + mainServer.Close() } // Get adds a handler for the 'GET' http method in the main server. func Get(route string, handler interface{}) { - mainServer.Get(route, handler) + mainServer.Get(route, handler) } // Post adds a handler for the 'POST' http method in the main server. func Post(route string, handler interface{}) { - mainServer.addRoute(route, "POST", handler) + mainServer.addRoute(route, "POST", handler) } // Put adds a handler for the 'PUT' http method in the main server. func Put(route string, handler interface{}) { - mainServer.addRoute(route, "PUT", handler) + mainServer.addRoute(route, "PUT", handler) } // Delete adds a handler for the 'DELETE' http method in the main server. func Delete(route string, handler interface{}) { - mainServer.addRoute(route, "DELETE", handler) + mainServer.addRoute(route, "DELETE", handler) } // Match adds a handler for an arbitrary http method in the main server. func Match(method string, route string, handler interface{}) { - mainServer.addRoute(route, method, handler) + mainServer.addRoute(route, method, handler) } -//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI. -func Handler(route string, method string, httpHandler http.Handler) { - mainServer.Handler(route, method, httpHandler) +// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI. +func Handle(route string, method string, httpHandler http.Handler) { + mainServer.Handle(route, method, httpHandler) } //Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI. func Websocket(route string, httpHandler websocket.Handler) { - mainServer.Websocket(route, httpHandler) + mainServer.Websocket(route, httpHandler) } // SetLogger sets the logger for the main server. func SetLogger(logger *log.Logger) { - mainServer.Logger = logger + mainServer.Logger = logger } // Config is the configuration of the main server. var Config = &ServerConfig{ - RecoverPanic: true, + RecoverPanic: true, + ColorOutput: true, } var mainServer = NewServer() diff --git a/web_test.go b/web_test.go index 609ca509..6468b796 100644 --- a/web_test.go +++ b/web_test.go @@ -1,580 +1,679 @@ package web import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "net/http" - "net/url" - "runtime" - "strconv" - "strings" - "testing" + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "testing" ) func init() { - runtime.GOMAXPROCS(4) + runtime.GOMAXPROCS(runtime.NumCPU()) } // ioBuffer is a helper that implements io.ReadWriteCloser, // which is helpful in imitating a net.Conn type ioBuffer struct { - input *bytes.Buffer - output *bytes.Buffer - closed bool + input *bytes.Buffer + output *bytes.Buffer + closed bool } func (buf *ioBuffer) Write(p []uint8) (n int, err error) { - if buf.closed { - return 0, errors.New("Write after Close on ioBuffer") - } - return buf.output.Write(p) + if buf.closed { + return 0, errors.New("Write after Close on ioBuffer") + } + return buf.output.Write(p) } func (buf *ioBuffer) Read(p []byte) (n int, err error) { - if buf.closed { - return 0, errors.New("Read after Close on ioBuffer") - } - return buf.input.Read(p) + if buf.closed { + return 0, errors.New("Read after Close on ioBuffer") + } + return buf.input.Read(p) } //noop func (buf *ioBuffer) Close() error { - buf.closed = true - return nil + buf.closed = true + return nil } type testResponse struct { - statusCode int - status string - body string - headers map[string][]string - cookies map[string]string + statusCode int + status string + body string + headers map[string][]string + cookies map[string]string } func buildTestResponse(buf *bytes.Buffer) *testResponse { - response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)} - s := buf.String() - contents := strings.SplitN(s, "\r\n\r\n", 2) + response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)} + s := buf.String() + contents := strings.SplitN(s, "\r\n\r\n", 2) - header := contents[0] + header := contents[0] - if len(contents) > 1 { - response.body = contents[1] - } + if len(contents) > 1 { + response.body = contents[1] + } - headers := strings.Split(header, "\r\n") + headers := strings.Split(header, "\r\n") - statusParts := strings.SplitN(headers[0], " ", 3) - response.statusCode, _ = strconv.Atoi(statusParts[1]) + statusParts := strings.SplitN(headers[0], " ", 3) + response.statusCode, _ = strconv.Atoi(statusParts[1]) - for _, h := range headers[1:] { - split := strings.SplitN(h, ":", 2) - name := strings.TrimSpace(split[0]) - value := strings.TrimSpace(split[1]) - if _, ok := response.headers[name]; !ok { - response.headers[name] = []string{} - } + for _, h := range headers[1:] { + split := strings.SplitN(h, ":", 2) + name := strings.TrimSpace(split[0]) + value := strings.TrimSpace(split[1]) + if _, ok := response.headers[name]; !ok { + response.headers[name] = []string{} + } - newheaders := make([]string, len(response.headers[name])+1) - copy(newheaders, response.headers[name]) - newheaders[len(newheaders)-1] = value - response.headers[name] = newheaders + newheaders := make([]string, len(response.headers[name])+1) + copy(newheaders, response.headers[name]) + newheaders[len(newheaders)-1] = value + response.headers[name] = newheaders - //if the header is a cookie, set it - if name == "Set-Cookie" { - i := strings.Index(value, ";") - cookie := value[0:i] - cookieParts := strings.SplitN(cookie, "=", 2) - response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1]) - } - } + //if the header is a cookie, set it + if name == "Set-Cookie" { + i := strings.Index(value, ";") + cookie := value[0:i] + cookieParts := strings.SplitN(cookie, "=", 2) + response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1]) + } + } - return &response + return &response } func getTestResponse(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *testResponse { - req := buildTestRequest(method, path, body, headers, cookies) - var buf bytes.Buffer + req := buildTestRequest(method, path, body, headers, cookies) + var buf bytes.Buffer - tcpb := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb} - mainServer.Process(&c, req) - return buildTestResponse(&buf) + tcpb := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb} + mainServer.Process(&c, req) + return buildTestResponse(&buf) } func testGet(path string, headers map[string]string) *testResponse { - var header http.Header - for k, v := range headers { - header.Set(k, v) - } - return getTestResponse("GET", path, "", header, nil) + var header http.Header + for k, v := range headers { + header.Set(k, v) + } + return getTestResponse("GET", path, "", header, nil) } type Test struct { - method string - path string - headers map[string][]string - body string - expectedStatus int - expectedBody string + method string + path string + headers map[string][]string + body string + expectedStatus int + expectedBody string } //initialize the routes func init() { - mainServer.SetLogger(log.New(ioutil.Discard, "", 0)) - Get("/", func() string { return "index" }) - Get("/panic", func() { panic(0) }) - Get("/echo/(.*)", func(s string) string { return s }) - Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d }) - Post("/post/echo/(.*)", func(s string) string { return s }) - Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] }) - - Get("/error/code/(.*)", func(ctx *Context, code string) string { - n, _ := strconv.Atoi(code) - message := statusText[n] - ctx.Abort(n, message) - return "" - }) - - Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) - - Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) - Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) - - Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) - Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) - - Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string { - n, _ := strconv.Atoi(code) - ctx.Abort(n, message) - return "" - }) - - Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") }) - - Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string { - ctx.SetSecureCookie(name, val, 60) - return "" - }) - - Get("/securecookie/get/(.+)", func(ctx *Context, name string) string { - val, ok := ctx.GetSecureCookie(name) - if !ok { - return "" - } - return val - }) - Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] }) - Get("/fullparams", func(ctx *Context) string { - return strings.Join(ctx.Request.Form["a"], ",") - }) - - Get("/json", func(ctx *Context) string { - ctx.ContentType("json") - data, _ := json.Marshal(ctx.Params) - return string(data) - }) - - Get("/jsonbytes", func(ctx *Context) []byte { - ctx.ContentType("json") - data, _ := json.Marshal(ctx.Params) - return data - }) - - Post("/parsejson", func(ctx *Context) string { - var tmp = struct { - A string - B string - }{} - json.NewDecoder(ctx.Request.Body).Decode(&tmp) - return tmp.A + " " + tmp.B - }) - - Match("OPTIONS", "/options", func(ctx *Context) { - ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true) - ctx.SetHeader("Access-Control-Max-Age", "1000", true) - ctx.WriteHeader(200) - }) - - Get("/dupeheader", func(ctx *Context) string { - ctx.SetHeader("Server", "myserver", true) - return "" - }) - - Get("/authorization", func(ctx *Context) string { - user, pass, err := ctx.GetBasicAuth() - if err != nil { - return "fail" - } - return user + pass - }) + mainServer.SetLogger(log.New(ioutil.Discard, "", 0)) + Get("/", func() string { return "index" }) + Get("/panic", func() { panic(0) }) + Get("/echo/(.*)", func(s string) string { return s }) + Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d }) + Post("/post/echo/(.*)", func(s string) string { return s }) + Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] }) + + Get("/error/code/(.*)", func(ctx *Context, code string) string { + n, _ := strconv.Atoi(code) + message := http.StatusText(n) + ctx.Abort(n, message) + return "" + }) + + Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) + + Get("/error/badrequest", func(ctx *Context) { ctx.BadRequest() }) + Post("/error/badrequest", func(ctx *Context) { ctx.BadRequest() }) + + Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) + Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) + + Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) + Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) + + Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string { + n, _ := strconv.Atoi(code) + ctx.Abort(n, message) + return "" + }) + + Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") }) + + Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string { + ctx.SetSecureCookie(name, val, 60) + return "" + }) + + Get("/securecookie/get/(.+)", func(ctx *Context, name string) string { + val, ok := ctx.GetSecureCookie(name) + if !ok { + return "" + } + return val + }) + Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] }) + Get("/fullparams", func(ctx *Context) string { + return strings.Join(ctx.Request.Form["a"], ",") + }) + + Get("/json", func(ctx *Context) string { + ctx.ContentType("json") + data, _ := json.Marshal(ctx.Params) + return string(data) + }) + + Get("/jsonbytes", func(ctx *Context) []byte { + ctx.ContentType("json") + data, _ := json.Marshal(ctx.Params) + return data + }) + + Post("/parsejson", func(ctx *Context) string { + var tmp = struct { + A string + B string + }{} + json.NewDecoder(ctx.Request.Body).Decode(&tmp) + return tmp.A + " " + tmp.B + }) + + Match("OPTIONS", "/options", func(ctx *Context) { + ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true) + ctx.SetHeader("Access-Control-Max-Age", "1000", true) + ctx.WriteHeader(200) + }) + + Get("/dupeheader", func(ctx *Context) string { + ctx.SetHeader("Server", "myserver", true) + return "" + }) + + Get("/authorization", func(ctx *Context) string { + user, pass, err := ctx.GetBasicAuth() + if err != nil { + return "fail" + } + return user + pass + }) } var tests = []Test{ - {"GET", "/", nil, "", 200, "index"}, - {"GET", "/echo/hello", nil, "", 200, "hello"}, - {"GET", "/echo/hello", nil, "", 200, "hello"}, - {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"}, - {"POST", "/post/echo/hello", nil, "", 200, "hello"}, - {"POST", "/post/echo/hello", nil, "", 200, "hello"}, - {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"}, - {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"}, - {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"}, - //long url - {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, - {"GET", "/writetest", nil, "", 200, "hello"}, - {"GET", "/error/unauthorized", nil, "", 401, ""}, - {"POST", "/error/unauthorized", nil, "", 401, ""}, - {"GET", "/error/forbidden", nil, "", 403, ""}, - {"POST", "/error/forbidden", nil, "", 403, ""}, - {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, - {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, - {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, - {"GET", "/error/code/500", nil, "", 500, statusText[500]}, - {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, - {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, - {"GET", "/getparam?b=abcd", nil, "", 200, ""}, - {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"}, - {"GET", "/panic", nil, "", 500, "Server Error"}, - {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, - {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, - {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, - //{"GET", "/testenv", "", 200, "hello world"}, - {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, + {"GET", "/", nil, "", 200, "index"}, + {"GET", "/echo/hello", nil, "", 200, "hello"}, + {"GET", "/echo/hello", nil, "", 200, "hello"}, + {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"}, + {"POST", "/post/echo/hello", nil, "", 200, "hello"}, + {"POST", "/post/echo/hello", nil, "", 200, "hello"}, + {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"}, + {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"}, + {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"}, + //long url + {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, + {"GET", "/writetest", nil, "", 200, "hello"}, + {"GET", "/error/badrequest", nil, "", 400, ""}, + {"POST", "/error/badrequest", nil, "", 400, ""}, + {"GET", "/error/unauthorized", nil, "", 401, ""}, + {"POST", "/error/unauthorized", nil, "", 401, ""}, + {"GET", "/error/forbidden", nil, "", 403, ""}, + {"POST", "/error/forbidden", nil, "", 403, ""}, + {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, + {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, + {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, + {"GET", "/error/code/500", nil, "", 500, http.StatusText(500)}, + {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, + {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, + {"GET", "/getparam?b=abcd", nil, "", 200, ""}, + {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"}, + {"GET", "/panic", nil, "", 500, "Server Error"}, + {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, + {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, + {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, + //{"GET", "/testenv", "", 200, "hello world"}, + {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, + {"GET", "/authorization", nil, "", 200, "fail"}, } func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request { - host := "127.0.0.1" - port := "80" - rawurl := "http://" + host + ":" + port + path - url_, _ := url.Parse(rawurl) - proto := "HTTP/1.1" - - if headers == nil { - headers = map[string][]string{} - } - - headers["User-Agent"] = []string{"web.go test"} - if method == "POST" { - headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))} - if headers["Content-Type"] == nil { - headers["Content-Type"] = []string{"text/plain"} - } - } - - req := http.Request{Method: method, - URL: url_, - Proto: proto, - Host: host, - Header: http.Header(headers), - Body: ioutil.NopCloser(bytes.NewBufferString(body)), - } - - for _, cookie := range cookies { - req.AddCookie(cookie) - } - return &req + host := "127.0.0.1" + port := "80" + rawurl := "http://" + host + ":" + port + path + url_, _ := url.Parse(rawurl) + proto := "HTTP/1.1" + + if headers == nil { + headers = map[string][]string{} + } + + headers["User-Agent"] = []string{"web.go test"} + if method == "POST" { + headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))} + if headers["Content-Type"] == nil { + headers["Content-Type"] = []string{"text/plain"} + } + } + + req := http.Request{Method: method, + URL: url_, + Proto: proto, + Host: host, + Header: http.Header(headers), + Body: ioutil.NopCloser(bytes.NewBufferString(body)), + } + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + return &req } func TestRouting(t *testing.T) { - for _, test := range tests { - resp := getTestResponse(test.method, test.path, test.body, test.headers, nil) - - if resp.statusCode != test.expectedStatus { - t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode) - } - if resp.body != test.expectedBody { - t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body) - } - if cl, ok := resp.headers["Content-Length"]; ok { - clExp, _ := strconv.Atoi(cl[0]) - clAct := len(resp.body) - if clExp != clAct { - t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct) - } - } - } + for _, test := range tests { + resp := getTestResponse(test.method, test.path, test.body, test.headers, nil) + + if resp.statusCode != test.expectedStatus { + t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode) + } + if resp.body != test.expectedBody { + t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body) + } + if cl, ok := resp.headers["Content-Length"]; ok { + clExp, _ := strconv.Atoi(cl[0]) + clAct := len(resp.body) + if clExp != clAct { + t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct) + } + } + } } func TestHead(t *testing.T) { - for _, test := range tests { - - if test.method != "GET" { - continue - } - getresp := getTestResponse("GET", test.path, test.body, test.headers, nil) - headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil) - - if getresp.statusCode != headresp.statusCode { - t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) - } - if len(headresp.body) != 0 { - t.Fatalf("head request arrived with a body") - } - - var cl []string - var getcl, headcl int - var hascl1, hascl2 bool - - if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { - getcl, _ = strconv.Atoi(cl[0]) - } - - if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { - headcl, _ = strconv.Atoi(cl[0]) - } - - if hascl1 != hascl2 { - t.Fatalf("head and get: one has content-length, one doesn't") - } - - if hascl1 == true && getcl != headcl { - t.Fatalf("head and get content-length differ") - } - } + for _, test := range tests { + + if test.method != "GET" { + continue + } + getresp := getTestResponse("GET", test.path, test.body, test.headers, nil) + headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil) + + if getresp.statusCode != headresp.statusCode { + t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) + } + if len(headresp.body) != 0 { + t.Fatalf("head request arrived with a body") + } + + var cl []string + var getcl, headcl int + var hascl1, hascl2 bool + + if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { + getcl, _ = strconv.Atoi(cl[0]) + } + + if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { + headcl, _ = strconv.Atoi(cl[0]) + } + + if hascl1 != hascl2 { + t.Fatalf("head and get: one has content-length, one doesn't") + } + + if hascl1 == true && getcl != headcl { + t.Fatalf("head and get content-length differ") + } + } } func buildTestScgiRequest(method string, path string, body string, headers map[string][]string) *bytes.Buffer { - var headerBuf bytes.Buffer - scgiHeaders := make(map[string]string) - - headerBuf.WriteString("CONTENT_LENGTH") - headerBuf.WriteByte(0) - headerBuf.WriteString(fmt.Sprintf("%d", len(body))) - headerBuf.WriteByte(0) - - scgiHeaders["REQUEST_METHOD"] = method - scgiHeaders["HTTP_HOST"] = "127.0.0.1" - scgiHeaders["REQUEST_URI"] = path - scgiHeaders["SERVER_PORT"] = "80" - scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1" - scgiHeaders["USER_AGENT"] = "web.go test framework" - - for k, v := range headers { - //Skip content-length - if k == "Content-Length" { - continue - } - key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1)) - scgiHeaders[key] = v[0] - } - for k, v := range scgiHeaders { - headerBuf.WriteString(k) - headerBuf.WriteByte(0) - headerBuf.WriteString(v) - headerBuf.WriteByte(0) - } - headerData := headerBuf.Bytes() - - var buf bytes.Buffer - //extra 1 is for the comma at the end - dlen := len(headerData) - fmt.Fprintf(&buf, "%d:", dlen) - buf.Write(headerData) - buf.WriteByte(',') - buf.WriteString(body) - return &buf + var headerBuf bytes.Buffer + scgiHeaders := make(map[string]string) + + headerBuf.WriteString("CONTENT_LENGTH") + headerBuf.WriteByte(0) + headerBuf.WriteString(fmt.Sprintf("%d", len(body))) + headerBuf.WriteByte(0) + + scgiHeaders["REQUEST_METHOD"] = method + scgiHeaders["HTTP_HOST"] = "127.0.0.1" + scgiHeaders["REQUEST_URI"] = path + scgiHeaders["SERVER_PORT"] = "80" + scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1" + scgiHeaders["USER_AGENT"] = "web.go test framework" + + for k, v := range headers { + //Skip content-length + if k == "Content-Length" { + continue + } + key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1)) + scgiHeaders[key] = v[0] + } + for k, v := range scgiHeaders { + headerBuf.WriteString(k) + headerBuf.WriteByte(0) + headerBuf.WriteString(v) + headerBuf.WriteByte(0) + } + headerData := headerBuf.Bytes() + + var buf bytes.Buffer + //extra 1 is for the comma at the end + dlen := len(headerData) + fmt.Fprintf(&buf, "%d:", dlen) + buf.Write(headerData) + buf.WriteByte(',') + buf.WriteString(body) + return &buf } func TestScgi(t *testing.T) { - for _, test := range tests { - req := buildTestScgiRequest(test.method, test.path, test.body, test.headers) - var output bytes.Buffer - nb := ioBuffer{input: req, output: &output} - mainServer.handleScgiRequest(&nb) - resp := buildTestResponse(&output) - - if resp.statusCode != test.expectedStatus { - t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode) - } - - if resp.body != test.expectedBody { - t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body) - } - } + for _, test := range tests { + req := buildTestScgiRequest(test.method, test.path, test.body, test.headers) + var output bytes.Buffer + nb := ioBuffer{input: req, output: &output} + mainServer.handleScgiRequest(&nb) + resp := buildTestResponse(&output) + + if resp.statusCode != test.expectedStatus { + t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode) + } + + if resp.body != test.expectedBody { + t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body) + } + } } func TestScgiHead(t *testing.T) { - for _, test := range tests { - - if test.method != "GET" { - continue - } - - req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string)) - var output bytes.Buffer - nb := ioBuffer{input: req, output: &output} - mainServer.handleScgiRequest(&nb) - getresp := buildTestResponse(&output) - - req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string)) - var output2 bytes.Buffer - nb = ioBuffer{input: req, output: &output2} - mainServer.handleScgiRequest(&nb) - headresp := buildTestResponse(&output2) - - if getresp.statusCode != headresp.statusCode { - t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) - } - if len(headresp.body) != 0 { - t.Fatalf("head request arrived with a body") - } - - var cl []string - var getcl, headcl int - var hascl1, hascl2 bool - - if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { - getcl, _ = strconv.Atoi(cl[0]) - } - - if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { - headcl, _ = strconv.Atoi(cl[0]) - } - - if hascl1 != hascl2 { - t.Fatalf("head and get: one has content-length, one doesn't") - } - - if hascl1 == true && getcl != headcl { - t.Fatalf("head and get content-length differ") - } - } + for _, test := range tests { + + if test.method != "GET" { + continue + } + + req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string)) + var output bytes.Buffer + nb := ioBuffer{input: req, output: &output} + mainServer.handleScgiRequest(&nb) + getresp := buildTestResponse(&output) + + req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string)) + var output2 bytes.Buffer + nb = ioBuffer{input: req, output: &output2} + mainServer.handleScgiRequest(&nb) + headresp := buildTestResponse(&output2) + + if getresp.statusCode != headresp.statusCode { + t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) + } + if len(headresp.body) != 0 { + t.Fatalf("head request arrived with a body") + } + + var cl []string + var getcl, headcl int + var hascl1, hascl2 bool + + if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { + getcl, _ = strconv.Atoi(cl[0]) + } + + if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { + headcl, _ = strconv.Atoi(cl[0]) + } + + if hascl1 != hascl2 { + t.Fatalf("head and get: one has content-length, one doesn't") + } + + if hascl1 == true && getcl != headcl { + t.Fatalf("head and get content-length differ") + } + } } func TestReadScgiRequest(t *testing.T) { - headers := map[string][]string{"User-Agent": {"web.go"}} - req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers) - var s Server - httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil}) - if err != nil { - t.Fatalf("Error while reading SCGI request: ", err.Error()) - } - if httpReq.ContentLength != 12 { - t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength) - } - var body bytes.Buffer - io.Copy(&body, httpReq.Body) - if body.String() != "Hello world!" { - t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String()) - } + headers := map[string][]string{"User-Agent": {"web.go"}} + req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers) + var s Server + httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil}) + if err != nil { + t.Fatalf("Error while reading SCGI request: ", err.Error()) + } + if httpReq.ContentLength != 12 { + t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength) + } + var body bytes.Buffer + io.Copy(&body, httpReq.Body) + if body.String() != "Hello world!" { + t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String()) + } } func makeCookie(vals map[string]string) []*http.Cookie { - var cookies []*http.Cookie - for k, v := range vals { - c := &http.Cookie{ - Name: k, - Value: v, - } - cookies = append(cookies, c) - } - return cookies + var cookies []*http.Cookie + for k, v := range vals { + c := &http.Cookie{ + Name: k, + Value: v, + } + cookies = append(cookies, c) + } + return cookies } func TestSecureCookie(t *testing.T) { - mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" - resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) - sval, ok := resp1.cookies["a"] - if !ok { - t.Fatalf("Failed to get cookie ") - } - cookies := makeCookie(map[string]string{"a": sval}) + mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + mainServer.initServer() + resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) + sval, ok := resp1.cookies["a"] + if !ok { + t.Fatalf("Failed to get cookie ") + } + cookies := makeCookie(map[string]string{"a": sval}) + + resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies) + + if resp2.body != "1" { + t.Fatalf("SecureCookie test failed") + } +} + +func TestEmptySecureCookie(t *testing.T) { + mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" + cookies := makeCookie(map[string]string{"empty": ""}) - resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies) + resp2 := getTestResponse("GET", "/securecookie/get/empty", "", nil, cookies) - if resp2.body != "1" { - t.Fatalf("SecureCookie test failed") - } + if resp2.body != "" { + t.Fatalf("Expected an empty secure cookie") + } } func TestEarlyClose(t *testing.T) { - var server1 Server - server1.Close() + var server1 Server + server1.Close() } func TestOptions(t *testing.T) { - resp := getTestResponse("OPTIONS", "/options", "", nil, nil) - if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" { - t.Fatalf("TestOptions - Access-Control-Allow-Methods failed") - } - if resp.headers["Access-Control-Max-Age"][0] != "1000" { - t.Fatalf("TestOptions - Access-Control-Max-Age failed") - } + resp := getTestResponse("OPTIONS", "/options", "", nil, nil) + if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" { + t.Fatalf("TestOptions - Access-Control-Allow-Methods failed") + } + if resp.headers["Access-Control-Max-Age"][0] != "1000" { + t.Fatalf("TestOptions - Access-Control-Max-Age failed") + } } func TestSlug(t *testing.T) { - tests := [][]string{ - {"", ""}, - {"a", "a"}, - {"a/b", "a-b"}, - {"a b", "a-b"}, - {"a////b", "a-b"}, - {" a////b ", "a-b"}, - {" Manowar / Friends ", "manowar-friends"}, - } - - for _, test := range tests { - v := Slug(test[0], "-") - if v != test[1] { - t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v) - } - } + tests := [][]string{ + {"", ""}, + {"a", "a"}, + {"a/b", "a-b"}, + {"a b", "a-b"}, + {"a////b", "a-b"}, + {" a////b ", "a-b"}, + {" Manowar / Friends ", "manowar-friends"}, + } + + for _, test := range tests { + v := Slug(test[0], "-") + if v != test[1] { + t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v) + } + } } // tests that we don't duplicate headers func TestDuplicateHeader(t *testing.T) { - resp := testGet("/dupeheader", nil) - if len(resp.headers["Server"]) > 1 { - t.Fatalf("Expected only one header, got %#v", resp.headers["Server"]) - } - if resp.headers["Server"][0] != "myserver" { - t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0]) - } + resp := testGet("/dupeheader", nil) + if len(resp.headers["Server"]) > 1 { + t.Fatalf("Expected only one header, got %#v", resp.headers["Server"]) + } + if resp.headers["Server"][0] != "myserver" { + t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0]) + } +} + +// test that output contains ASCII color codes by default +func TestColorOutputDefault(t *testing.T) { + s := NewServer() + var logOutput bytes.Buffer + logger := log.New(&logOutput, "", 0) + s.Logger = logger + s.Get("/test", func() string { + return "test" + }) + req := buildTestRequest("GET", "/test", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + s.Process(&c, req) + if !strings.Contains(logOutput.String(), "\x1b") { + t.Fatalf("The default log output does not seem to be colored") + } +} + +// test that output contains ASCII color codes by default +func TestNoColorOutput(t *testing.T) { + s := NewServer() + s.Config.ColorOutput = false + var logOutput bytes.Buffer + logger := log.New(&logOutput, "", 0) + s.Logger = logger + s.Get("/test", func() string { + return "test" + }) + req := buildTestRequest("GET", "/test", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + s.Process(&c, req) + if strings.Contains(logOutput.String(), "\x1b") { + t.Fatalf("The log contains color escape codes") + } +} + +// a malformed SCGI request should be discarded and not cause a panic +func TestMaformedScgiRequest(t *testing.T) { + var headerBuf bytes.Buffer + + headerBuf.WriteString("CONTENT_LENGTH") + headerBuf.WriteByte(0) + headerBuf.WriteString("0") + headerBuf.WriteByte(0) + headerData := headerBuf.Bytes() + + var buf bytes.Buffer + fmt.Fprintf(&buf, "%d:", len(headerData)) + buf.Write(headerData) + buf.WriteByte(',') + + var output bytes.Buffer + nb := ioBuffer{input: &buf, output: &output} + mainServer.handleScgiRequest(&nb) + if !nb.closed { + t.Fatalf("The connection should have been closed") + } +} + +type TestHandler struct{} + +func (t *TestHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) { +} + +// When a custom HTTP handler is used, the Content-Type header should not be set to a default. +// Go's FileHandler does not replace the Content-Type header if it is already set. +func TestCustomHandlerContentType(t *testing.T) { + s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) + s.Handle("/testHandler", "GET", &TestHandler{}) + req := buildTestRequest("GET", "/testHandler", "", nil, nil) + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: nil} + s.Process(&c, req) + if c.headers["Content-Type"] != nil { + t.Fatalf("A default Content-Type should not be present when using a custom HTTP handler") + } } func BuildBasicAuthCredentials(user string, pass string) string { - s := user + ":" + pass - return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) + s := user + ":" + pass + return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) } func BenchmarkProcessGet(b *testing.B) { - s := NewServer() - s.SetLogger(log.New(ioutil.Discard, "", 0)) - s.Get("/echo/(.*)", func(s string) string { - return s - }) - req := buildTestRequest("GET", "/echo/hi", "", nil, nil) - var buf bytes.Buffer - iob := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - s.Process(&c, req) - } + s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) + s.Get("/echo/(.*)", func(s string) string { + return s + }) + req := buildTestRequest("GET", "/echo/hi", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Process(&c, req) + } } func BenchmarkProcessPost(b *testing.B) { - s := NewServer() - s.SetLogger(log.New(ioutil.Discard, "", 0)) - s.Post("/echo/(.*)", func(s string) string { - return s - }) - req := buildTestRequest("POST", "/echo/hi", "", nil, nil) - var buf bytes.Buffer - iob := ioBuffer{input: nil, output: &buf} - c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - s.Process(&c, req) - } + s := NewServer() + s.SetLogger(log.New(ioutil.Discard, "", 0)) + s.Post("/echo/(.*)", func(s string) string { + return s + }) + req := buildTestRequest("POST", "/echo/hi", "", nil, nil) + var buf bytes.Buffer + iob := ioBuffer{input: nil, output: &buf} + c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Process(&c, req) + } }