diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 00000000..b5ebb9ce
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,13 @@
+language: go
+
+go:
+ - 1.3
+ - 1.4
+ - 1.5
+ - 1.6
+
+install: go get
+
+script:
+ - diff -u <(echo -n) <(gofmt -d -s .)
+ - go test -short
diff --git a/Makefile b/Makefile
deleted file mode 100644
index 25b06108..00000000
--- a/Makefile
+++ /dev/null
@@ -1,6 +0,0 @@
-GOFMT=gofmt -s -tabs=false -tabwidth=4
-
-GOFILES=$(wildcard *.go **/*.go)
-
-format:
- ${GOFMT} -w ${GOFILES}
diff --git a/Readme.md b/Readme.md
index 9750f4df..2be27536 100644
--- a/Readme.md
+++ b/Readme.md
@@ -1,3 +1,5 @@
+[](https://travis-ci.org/hoisie/web)
+
# web.go
web.go is the simplest way to write web applications in the Go programming language. It's ideal for writing simple, performant backend web services.
@@ -77,12 +79,12 @@ In this example, if you visit `http://localhost:9999/?a=1&b=2`, you'll see the f
## Documentation
-API docs are hosted at http://webgo.io
+API docs are hosted at https://hoisie.github.io/web/
If you use web.go, I'd greatly appreciate a quick message about what you're building with it. This will help me get a sense of usage patterns, and helps me focus development efforts on features that people will actually use.
## About
-web.go was written by [Michael Hoisie](http://hoisie.com).
+web.go was written by Michael Hoisie
diff --git a/examples/arcchallenge.go b/examples/arcchallenge.go
index 1de600c2..ea9feb4c 100644
--- a/examples/arcchallenge.go
+++ b/examples/arcchallenge.go
@@ -1,10 +1,10 @@
package main
import (
- "fmt"
- "github.com/hoisie/web"
- "math/rand"
- "time"
+ "fmt"
+ "github.com/hoisie/web"
+ "math/rand"
+ "time"
)
var form = `
`
@@ -12,22 +12,22 @@ var form = `
+`
+
+func index(ctx *web.Context) string {
+ cookie, ok := ctx.GetSecureCookie(cookieName)
+ var top string
+ if !ok {
+ top = fmt.Sprintf(notice, "The cookie has not been set")
+ } else {
+ var val = html.EscapeString(cookie)
+ top = fmt.Sprintf(notice, "The value of the cookie is '"+val+"'.")
+ }
+ return top + form
+}
+
+func update(ctx *web.Context) {
+ if ctx.Params["submit"] == "Delete" {
+ ctx.SetCookie(web.NewCookie(cookieName, "", -1))
+ } else {
+ ctx.SetSecureCookie(cookieName, ctx.Params["cookie"], 0)
+ }
+ ctx.Redirect(301, "/")
+}
+
+func main() {
+ web.Config.CookieSecret = "a long secure cookie secret"
+ web.Get("/", index)
+ web.Post("/update", update)
+ web.Run("0.0.0.0:9999")
+}
diff --git a/examples/streaming.go b/examples/streaming.go
index e17ab372..2a5b948f 100644
--- a/examples/streaming.go
+++ b/examples/streaming.go
@@ -1,24 +1,24 @@
package main
import (
- "github.com/hoisie/web"
- "net/http"
- "strconv"
- "time"
+ "github.com/hoisie/web"
+ "net/http"
+ "strconv"
+ "time"
)
func hello(ctx *web.Context, num string) {
- flusher, _ := ctx.ResponseWriter.(http.Flusher)
- flusher.Flush()
- n, _ := strconv.ParseInt(num, 10, 64)
- for i := int64(0); i < n; i++ {
- ctx.WriteString("
hello world")
- flusher.Flush()
- time.Sleep(1e9)
- }
+ flusher, _ := ctx.ResponseWriter.(http.Flusher)
+ flusher.Flush()
+ n, _ := strconv.ParseInt(num, 10, 64)
+ for i := int64(0); i < n; i++ {
+ ctx.WriteString("
hello world")
+ flusher.Flush()
+ time.Sleep(1e9)
+ }
}
func main() {
- web.Get("/([0-9]+)", hello)
- web.Run("0.0.0.0:9999")
+ web.Get("/([0-9]+)", hello)
+ web.Run("0.0.0.0:9999")
}
diff --git a/examples/tls.go b/examples/tls.go
index 7517fe61..951418fd 100644
--- a/examples/tls.go
+++ b/examples/tls.go
@@ -1,8 +1,8 @@
package main
import (
- "crypto/tls"
- "github.com/hoisie/web"
+ "crypto/tls"
+ "github.com/hoisie/web"
)
// an arbitrary self-signed certificate, generated with
@@ -47,22 +47,22 @@ gWrxykqyLToIiAuL+pvC3Jv8IOPIiVFsY032rOqcwSGdVUyhTsG28+7KnR6744tM
-----END CERTIFICATE-----
`
-func hello(val string) string { return "hello " + val }
+func hello(val string) string { return "hello " + val + "\n" }
func main() {
- config := tls.Config{
- Time: nil,
- }
+ config := tls.Config{
+ Time: nil,
+ }
- config.Certificates = make([]tls.Certificate, 1)
- var err error
- config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey))
- if err != nil {
- println(err.Error())
- return
- }
+ config.Certificates = make([]tls.Certificate, 1)
+ var err error
+ config.Certificates[0], err = tls.X509KeyPair([]byte(cert), []byte(pkey))
+ if err != nil {
+ println(err.Error())
+ return
+ }
- // you must access the server with an HTTP address, i.e https://localhost:9999/world
- web.Get("/(.*)", hello)
- web.RunTLS("0.0.0.0:9999", &config)
+ // you must access the server with an HTTPS address, i.e https://localhost:9999/world
+ web.Get("/(.*)", hello)
+ web.RunTLS("0.0.0.0:9999", &config)
}
diff --git a/fcgi.go b/fcgi.go
index fa380dc9..017b8024 100644
--- a/fcgi.go
+++ b/fcgi.go
@@ -1,27 +1,27 @@
package web
import (
- "net"
- "net/http/fcgi"
+ "net"
+ "net/http/fcgi"
)
func (s *Server) listenAndServeFcgi(addr string) error {
- var l net.Listener
- var err error
+ var l net.Listener
+ var err error
- //if the path begins with a "/", assume it's a unix address
- if addr[0] == '/' {
- l, err = net.Listen("unix", addr)
- } else {
- l, err = net.Listen("tcp", addr)
- }
+ //if the path begins with a "/", assume it's a unix address
+ if addr[0] == '/' {
+ l, err = net.Listen("unix", addr)
+ } else {
+ l, err = net.Listen("tcp", addr)
+ }
- //save the listener so it can be closed
- s.l = l
+ //save the listener so it can be closed
+ s.l = l
- if err != nil {
- s.Logger.Println("FCGI listen error", err.Error())
- return err
- }
- return fcgi.Serve(s.l, s)
+ if err != nil {
+ s.Logger.Println("FCGI listen error", err.Error())
+ return err
+ }
+ return fcgi.Serve(s.l, s)
}
diff --git a/helpers.go b/helpers.go
index 93c8d5ad..a87e93f4 100644
--- a/helpers.go
+++ b/helpers.go
@@ -1,59 +1,59 @@
package web
import (
- "bytes"
- "encoding/base64"
- "errors"
- "net/http"
- "net/url"
- "os"
- "regexp"
- "strings"
- "time"
+ "bytes"
+ "encoding/base64"
+ "errors"
+ "net/http"
+ "net/url"
+ "os"
+ "regexp"
+ "strings"
+ "time"
)
// internal utility methods
func webTime(t time.Time) string {
- ftime := t.Format(time.RFC1123)
- if strings.HasSuffix(ftime, "UTC") {
- ftime = ftime[0:len(ftime)-3] + "GMT"
- }
- return ftime
+ ftime := t.Format(time.RFC1123)
+ if strings.HasSuffix(ftime, "UTC") {
+ ftime = ftime[0:len(ftime)-3] + "GMT"
+ }
+ return ftime
}
func dirExists(dir string) bool {
- d, e := os.Stat(dir)
- switch {
- case e != nil:
- return false
- case !d.IsDir():
- return false
- }
+ d, e := os.Stat(dir)
+ switch {
+ case e != nil:
+ return false
+ case !d.IsDir():
+ return false
+ }
- return true
+ return true
}
func fileExists(dir string) bool {
- info, err := os.Stat(dir)
- if err != nil {
- return false
- }
+ info, err := os.Stat(dir)
+ if err != nil {
+ return false
+ }
- return !info.IsDir()
+ return !info.IsDir()
}
// Urlencode is a helper method that converts a map into URL-encoded form data.
// It is a useful when constructing HTTP POST requests.
func Urlencode(data map[string]string) string {
- var buf bytes.Buffer
- for k, v := range data {
- buf.WriteString(url.QueryEscape(k))
- buf.WriteByte('=')
- buf.WriteString(url.QueryEscape(v))
- buf.WriteByte('&')
- }
- s := buf.String()
- return s[0 : len(s)-1]
+ var buf bytes.Buffer
+ for k, v := range data {
+ buf.WriteString(url.QueryEscape(k))
+ buf.WriteByte('=')
+ buf.WriteString(url.QueryEscape(v))
+ buf.WriteByte('&')
+ }
+ s := buf.String()
+ return s[0 : len(s)-1]
}
var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`)
@@ -62,50 +62,53 @@ var slugRegex = regexp.MustCompile(`(?i:[^a-z0-9\-_])`)
// It's used to return clean, URL-friendly strings that can be
// used in routing.
func Slug(s string, sep string) string {
- if s == "" {
- return ""
- }
- slug := slugRegex.ReplaceAllString(s, sep)
- if slug == "" {
- return ""
- }
- quoted := regexp.QuoteMeta(sep)
- sepRegex := regexp.MustCompile("(" + quoted + "){2,}")
- slug = sepRegex.ReplaceAllString(slug, sep)
- sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$")
- slug = sepEnds.ReplaceAllString(slug, "")
- return strings.ToLower(slug)
+ if s == "" {
+ return ""
+ }
+ slug := slugRegex.ReplaceAllString(s, sep)
+ if slug == "" {
+ return ""
+ }
+ quoted := regexp.QuoteMeta(sep)
+ sepRegex := regexp.MustCompile("(" + quoted + "){2,}")
+ slug = sepRegex.ReplaceAllString(slug, sep)
+ sepEnds := regexp.MustCompile("^" + quoted + "|" + quoted + "$")
+ slug = sepEnds.ReplaceAllString(slug, "")
+ return strings.ToLower(slug)
}
// NewCookie is a helper method that returns a new http.Cookie object.
// Duration is specified in seconds. If the duration is zero, the cookie is permanent.
// This can be used in conjunction with ctx.SetCookie.
func NewCookie(name string, value string, age int64) *http.Cookie {
- var utctime time.Time
- if age == 0 {
- // 2^31 - 1 seconds (roughly 2038)
- utctime = time.Unix(2147483647, 0)
- } else {
- utctime = time.Unix(time.Now().Unix()+age, 0)
- }
- return &http.Cookie{Name: name, Value: value, Expires: utctime}
+ var utctime time.Time
+ if age == 0 {
+ // 2^31 - 1 seconds (roughly 2038)
+ utctime = time.Unix(2147483647, 0)
+ } else {
+ utctime = time.Unix(time.Now().Unix()+age, 0)
+ }
+ return &http.Cookie{Name: name, Value: value, Expires: utctime}
}
-// GetBasicAuth is a helper method of *Context that returns the decoded
-// user and password from the *Context's authorization header
+// GetBasicAuth returns the decoded user and password from the context's
+// 'Authorization' header.
func (ctx *Context) GetBasicAuth() (string, string, error) {
- authHeader := ctx.Request.Header["Authorization"][0]
- authString := strings.Split(string(authHeader), " ")
- if authString[0] != "Basic" {
- return "", "", errors.New("Not Basic Authentication")
- }
- decodedAuth, err := base64.StdEncoding.DecodeString(authString[1])
- if err != nil {
- return "", "", err
- }
- authSlice := strings.Split(string(decodedAuth), ":")
- if len(authSlice) != 2 {
- return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1])
- }
- return authSlice[0], authSlice[1], nil
+ if len(ctx.Request.Header["Authorization"]) == 0 {
+ return "", "", errors.New("No Authorization header provided")
+ }
+ authHeader := ctx.Request.Header["Authorization"][0]
+ authString := strings.Split(string(authHeader), " ")
+ if authString[0] != "Basic" {
+ return "", "", errors.New("Not Basic Authentication")
+ }
+ decodedAuth, err := base64.StdEncoding.DecodeString(authString[1])
+ if err != nil {
+ return "", "", err
+ }
+ authSlice := strings.Split(string(decodedAuth), ":")
+ if len(authSlice) != 2 {
+ return "", "", errors.New("Error delimiting authString into username/password. Malformed input: " + authString[1])
+ }
+ return authSlice[0], authSlice[1], nil
}
diff --git a/scgi.go b/scgi.go
index 45d3aec5..eea2b4fe 100644
--- a/scgi.go
+++ b/scgi.go
@@ -1,180 +1,183 @@
package web
import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/http/cgi"
- "strconv"
- "strings"
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/cgi"
+ "strconv"
+ "strings"
)
type scgiBody struct {
- reader io.Reader
- conn io.ReadWriteCloser
- closed bool
+ reader io.Reader
+ conn io.ReadWriteCloser
+ closed bool
}
func (b *scgiBody) Read(p []byte) (n int, err error) {
- if b.closed {
- return 0, errors.New("SCGI read after close")
- }
- return b.reader.Read(p)
+ if b.closed {
+ return 0, errors.New("SCGI read after close")
+ }
+ return b.reader.Read(p)
}
func (b *scgiBody) Close() error {
- b.closed = true
- return b.conn.Close()
+ b.closed = true
+ return b.conn.Close()
}
type scgiConn struct {
- fd io.ReadWriteCloser
- req *http.Request
- headers http.Header
- wroteHeaders bool
+ fd io.ReadWriteCloser
+ req *http.Request
+ headers http.Header
+ wroteHeaders bool
}
func (conn *scgiConn) WriteHeader(status int) {
- if !conn.wroteHeaders {
- conn.wroteHeaders = true
+ if !conn.wroteHeaders {
+ conn.wroteHeaders = true
- var buf bytes.Buffer
- text := statusText[status]
+ var buf bytes.Buffer
+ text := http.StatusText(status)
- fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text)
+ fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text)
- for k, v := range conn.headers {
- for _, i := range v {
- buf.WriteString(k + ": " + i + "\r\n")
- }
- }
+ for k, v := range conn.headers {
+ for _, i := range v {
+ buf.WriteString(k + ": " + i + "\r\n")
+ }
+ }
- buf.WriteString("\r\n")
- conn.fd.Write(buf.Bytes())
- }
+ buf.WriteString("\r\n")
+ conn.fd.Write(buf.Bytes())
+ }
}
func (conn *scgiConn) Header() http.Header {
- return conn.headers
+ return conn.headers
}
func (conn *scgiConn) Write(data []byte) (n int, err error) {
- if !conn.wroteHeaders {
- conn.WriteHeader(200)
- }
+ if !conn.wroteHeaders {
+ conn.WriteHeader(200)
+ }
- if conn.req.Method == "HEAD" {
- return 0, errors.New("Body Not Allowed")
- }
+ if conn.req.Method == "HEAD" {
+ return 0, errors.New("Body Not Allowed")
+ }
- return conn.fd.Write(data)
+ return conn.fd.Write(data)
}
func (conn *scgiConn) Close() { conn.fd.Close() }
func (conn *scgiConn) finishRequest() error {
- var buf bytes.Buffer
- if !conn.wroteHeaders {
- conn.wroteHeaders = true
-
- for k, v := range conn.headers {
- for _, i := range v {
- buf.WriteString(k + ": " + i + "\r\n")
- }
- }
-
- buf.WriteString("\r\n")
- conn.fd.Write(buf.Bytes())
- }
- return nil
+ var buf bytes.Buffer
+ if !conn.wroteHeaders {
+ conn.wroteHeaders = true
+
+ for k, v := range conn.headers {
+ for _, i := range v {
+ buf.WriteString(k + ": " + i + "\r\n")
+ }
+ }
+
+ buf.WriteString("\r\n")
+ conn.fd.Write(buf.Bytes())
+ }
+ return nil
}
func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) {
- reader := bufio.NewReader(fd)
- line, err := reader.ReadString(':')
- if err != nil {
- s.Logger.Println("Error during SCGI read: ", err.Error())
- }
- length, _ := strconv.Atoi(line[0 : len(line)-1])
- if length > 16384 {
- s.Logger.Println("Error: max header size is 16k")
- }
- headerData := make([]byte, length)
- _, err = reader.Read(headerData)
- if err != nil {
- return nil, err
- }
-
- b, err := reader.ReadByte()
- if err != nil {
- return nil, err
- }
- // discard the trailing comma
- if b != ',' {
- return nil, errors.New("SCGI protocol error: missing comma")
- }
- headerList := bytes.Split(headerData, []byte{0})
- headers := map[string]string{}
- for i := 0; i < len(headerList)-1; i += 2 {
- headers[string(headerList[i])] = string(headerList[i+1])
- }
- httpReq, err := cgi.RequestFromMap(headers)
- if err != nil {
- return nil, err
- }
- if httpReq.ContentLength > 0 {
- httpReq.Body = &scgiBody{
- reader: io.LimitReader(reader, httpReq.ContentLength),
- conn: fd,
- }
- } else {
- httpReq.Body = &scgiBody{reader: reader, conn: fd}
- }
- return httpReq, nil
+ reader := bufio.NewReader(fd)
+ line, err := reader.ReadString(':')
+ if err != nil {
+ return nil, err
+ }
+ length, err := strconv.Atoi(line[0 : len(line)-1])
+ if err != nil {
+ return nil, err
+ }
+ if length > 16384 {
+ return nil, errors.New("Max header size is 16k")
+ }
+ headerData := make([]byte, length)
+ _, err = reader.Read(headerData)
+ if err != nil {
+ return nil, err
+ }
+ b, err := reader.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ // discard the trailing comma
+ if b != ',' {
+ return nil, errors.New("SCGI protocol error: missing comma")
+ }
+ headerList := bytes.Split(headerData, []byte{0})
+ headers := map[string]string{}
+ for i := 0; i < len(headerList)-1; i += 2 {
+ headers[string(headerList[i])] = string(headerList[i+1])
+ }
+ httpReq, err := cgi.RequestFromMap(headers)
+ if err != nil {
+ return nil, err
+ }
+ if httpReq.ContentLength > 0 {
+ httpReq.Body = &scgiBody{
+ reader: io.LimitReader(reader, httpReq.ContentLength),
+ conn: fd,
+ }
+ } else {
+ httpReq.Body = &scgiBody{reader: reader, conn: fd}
+ }
+ return httpReq, nil
}
func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) {
- req, err := s.readScgiRequest(fd)
- if err != nil {
- s.Logger.Println("SCGI error: %q", err.Error())
- }
- sc := scgiConn{fd, req, make(map[string][]string), false}
- s.routeHandler(req, &sc)
- sc.finishRequest()
- fd.Close()
+ defer fd.Close()
+ req, err := s.readScgiRequest(fd)
+ if err != nil {
+ s.Logger.Println("Error reading SCGI request: %q", err.Error())
+ return
+ }
+ sc := scgiConn{fd, req, make(map[string][]string), false}
+ s.routeHandler(req, &sc)
+ sc.finishRequest()
}
func (s *Server) listenAndServeScgi(addr string) error {
- var l net.Listener
- var err error
-
- //if the path begins with a "/", assume it's a unix address
- if strings.HasPrefix(addr, "/") {
- l, err = net.Listen("unix", addr)
- } else {
- l, err = net.Listen("tcp", addr)
- }
-
- //save the listener so it can be closed
- s.l = l
-
- if err != nil {
- s.Logger.Println("SCGI listen error", err.Error())
- return err
- }
-
- for {
- fd, err := l.Accept()
- if err != nil {
- s.Logger.Println("SCGI accept error", err.Error())
- return err
- }
- go s.handleScgiRequest(fd)
- }
- return nil
+ var l net.Listener
+ var err error
+
+ //if the path begins with a "/", assume it's a unix address
+ if strings.HasPrefix(addr, "/") {
+ l, err = net.Listen("unix", addr)
+ } else {
+ l, err = net.Listen("tcp", addr)
+ }
+
+ //save the listener so it can be closed
+ s.l = l
+
+ if err != nil {
+ s.Logger.Println("SCGI listen error", err.Error())
+ return err
+ }
+
+ for {
+ fd, err := l.Accept()
+ if err != nil {
+ s.Logger.Println("SCGI accept error", err.Error())
+ return err
+ }
+ go s.handleScgiRequest(fd)
+ }
+ return nil
}
diff --git a/secure_cookie.go b/secure_cookie.go
new file mode 100644
index 00000000..dfd5d196
--- /dev/null
+++ b/secure_cookie.go
@@ -0,0 +1,112 @@
+package web
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha512"
+ "encoding/base64"
+ "errors"
+ "golang.org/x/crypto/pbkdf2"
+ "io"
+ "strings"
+)
+
+const (
+ pbkdf2Iterations = 64000
+ keySize = 32
+)
+
+var (
+ ErrMissingCookieSecret = errors.New("Secret Key for secure cookies has not been set. Assign one to web.Config.CookieSecret.")
+ ErrInvalidKey = errors.New("The keys for secure cookies have not been initialized. Ensure that a Run* method is being called")
+)
+
+func (ctx *Context) SetSecureCookie(name string, val string, age int64) error {
+ server := ctx.Server
+ if len(server.Config.CookieSecret) == 0 {
+ return ErrMissingCookieSecret
+ }
+ if len(server.encKey) == 0 || len(server.signKey) == 0 {
+ return ErrInvalidKey
+ }
+ ciphertext, err := encrypt([]byte(val), server.encKey)
+ if err != nil {
+ return err
+ }
+ sig := sign(ciphertext, server.signKey)
+ data := base64.StdEncoding.EncodeToString(ciphertext) + "|" + base64.StdEncoding.EncodeToString(sig)
+ ctx.SetCookie(NewCookie(name, data, age))
+ return nil
+}
+
+func (ctx *Context) GetSecureCookie(name string) (string, bool) {
+ for _, cookie := range ctx.Request.Cookies() {
+ if cookie.Name != name {
+ continue
+ }
+ parts := strings.SplitN(cookie.Value, "|", 2)
+ if len(parts) != 2 {
+ return "", false
+ }
+ ciphertext, err := base64.StdEncoding.DecodeString(parts[0])
+ if err != nil {
+ return "", false
+ }
+ sig, err := base64.StdEncoding.DecodeString(parts[1])
+ if err != nil {
+ return "", false
+ }
+ expectedSig := sign([]byte(ciphertext), ctx.Server.signKey)
+ if !bytes.Equal(expectedSig, sig) {
+ return "", false
+ }
+ plaintext, err := decrypt(ciphertext, ctx.Server.encKey)
+ if err != nil {
+ return "", false
+ }
+ return string(plaintext), true
+ }
+ return "", false
+}
+
+func genKey(password string, salt string) []byte {
+ return pbkdf2.Key([]byte(password), []byte(salt), pbkdf2Iterations, keySize, sha512.New)
+}
+
+func encrypt(plaintext []byte, key []byte) ([]byte, error) {
+ aesCipher, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ ciphertext := make([]byte, aes.BlockSize+len(plaintext))
+ iv := ciphertext[:aes.BlockSize]
+ if _, err := io.ReadFull(rand.Reader, iv); err != nil {
+ return nil, err
+ }
+ stream := cipher.NewCTR(aesCipher, iv)
+ stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)
+ return ciphertext, nil
+}
+
+func decrypt(ciphertext []byte, key []byte) ([]byte, error) {
+ if len(ciphertext) <= aes.BlockSize {
+ return nil, errors.New("Invalid cipher text")
+ }
+ aesCipher, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ plaintext := make([]byte, len(ciphertext)-aes.BlockSize)
+ stream := cipher.NewCTR(aesCipher, ciphertext[:aes.BlockSize])
+ stream.XORKeyStream(plaintext, ciphertext[aes.BlockSize:])
+ return plaintext, nil
+}
+
+func sign(data []byte, key []byte) []byte {
+ mac := hmac.New(sha512.New, key)
+ mac.Write(data)
+ return mac.Sum(nil)
+}
diff --git a/server.go b/server.go
index 184f01c2..0e97a4dd 100644
--- a/server.go
+++ b/server.go
@@ -1,240 +1,252 @@
package web
import (
- "bytes"
- "code.google.com/p/go.net/websocket"
- "crypto/tls"
- "fmt"
- "log"
- "net"
- "net/http"
- "net/http/pprof"
- "os"
- "path"
- "reflect"
- "regexp"
- "runtime"
- "strconv"
- "strings"
- "time"
+ "bytes"
+ "crypto/tls"
+ "fmt"
+ "golang.org/x/net/websocket"
+ "log"
+ "net"
+ "net/http"
+ "net/http/pprof"
+ "os"
+ "path"
+ "reflect"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
)
// ServerConfig is configuration for server objects.
type ServerConfig struct {
- StaticDir string
- Addr string
- Port int
- CookieSecret string
- RecoverPanic bool
- Profiler bool
+ StaticDir string
+ Addr string
+ Port int
+ CookieSecret string
+ RecoverPanic bool
+ Profiler bool
+ ColorOutput bool
}
// Server represents a web.go server.
type Server struct {
- Config *ServerConfig
- routes []route
- Logger *log.Logger
- Env map[string]interface{}
- //save the listener so it can be closed
- l net.Listener
+ Config *ServerConfig
+ routes []route
+ Logger *log.Logger
+ Env map[string]interface{}
+ //save the listener so it can be closed
+ l net.Listener
+ encKey []byte
+ signKey []byte
}
func NewServer() *Server {
- return &Server{
- Config: Config,
- Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime),
- Env: map[string]interface{}{},
- }
+ return &Server{
+ Config: Config,
+ Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime),
+ Env: map[string]interface{}{},
+ }
}
func (s *Server) initServer() {
- if s.Config == nil {
- s.Config = &ServerConfig{}
- }
-
- if s.Logger == nil {
- s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime)
- }
+ if s.Config == nil {
+ s.Config = &ServerConfig{}
+ }
+
+ if s.Logger == nil {
+ s.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime)
+ }
+
+ if len(s.Config.CookieSecret) > 0 {
+ s.Logger.Println("Generating cookie encryption keys")
+ s.encKey = genKey(s.Config.CookieSecret, "encryption key salt")
+ s.signKey = genKey(s.Config.CookieSecret, "signature key salt")
+ }
}
type route struct {
- r string
- cr *regexp.Regexp
- method string
- handler reflect.Value
- httpHandler http.Handler
+ r string
+ cr *regexp.Regexp
+ method string
+ handler reflect.Value
+ httpHandler http.Handler
}
func (s *Server) addRoute(r string, method string, handler interface{}) {
- cr, err := regexp.Compile(r)
- if err != nil {
- s.Logger.Printf("Error in route regex %q\n", r)
- return
- }
-
- switch handler.(type) {
- case http.Handler:
- s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)})
- case reflect.Value:
- fv := handler.(reflect.Value)
- s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv})
- default:
- fv := reflect.ValueOf(handler)
- s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv})
- }
+ cr, err := regexp.Compile(r)
+ if err != nil {
+ s.Logger.Printf("Error in route regex %q\n", r)
+ return
+ }
+
+ switch handler.(type) {
+ case http.Handler:
+ s.routes = append(s.routes, route{r: r, cr: cr, method: method, httpHandler: handler.(http.Handler)})
+ case reflect.Value:
+ fv := handler.(reflect.Value)
+ s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv})
+ default:
+ fv := reflect.ValueOf(handler)
+ s.routes = append(s.routes, route{r: r, cr: cr, method: method, handler: fv})
+ }
}
// ServeHTTP is the interface method for Go's http server package
func (s *Server) ServeHTTP(c http.ResponseWriter, req *http.Request) {
- s.Process(c, req)
+ s.Process(c, req)
}
// Process invokes the routing system for server s
func (s *Server) Process(c http.ResponseWriter, req *http.Request) {
- route := s.routeHandler(req, c)
- if route != nil {
- route.httpHandler.ServeHTTP(c, req)
- }
+ route := s.routeHandler(req, c)
+ if route != nil {
+ route.httpHandler.ServeHTTP(c, req)
+ }
}
// Get adds a handler for the 'GET' http method for server s.
func (s *Server) Get(route string, handler interface{}) {
- s.addRoute(route, "GET", handler)
+ s.addRoute(route, "GET", handler)
}
// Post adds a handler for the 'POST' http method for server s.
func (s *Server) Post(route string, handler interface{}) {
- s.addRoute(route, "POST", handler)
+ s.addRoute(route, "POST", handler)
}
// Put adds a handler for the 'PUT' http method for server s.
func (s *Server) Put(route string, handler interface{}) {
- s.addRoute(route, "PUT", handler)
+ s.addRoute(route, "PUT", handler)
}
// Delete adds a handler for the 'DELETE' http method for server s.
func (s *Server) Delete(route string, handler interface{}) {
- s.addRoute(route, "DELETE", handler)
+ s.addRoute(route, "DELETE", handler)
}
// Match adds a handler for an arbitrary http method for server s.
func (s *Server) Match(method string, route string, handler interface{}) {
- s.addRoute(route, method, handler)
+ s.addRoute(route, method, handler)
}
-//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI.
-func (s *Server) Handler(route string, method string, httpHandler http.Handler) {
- s.addRoute(route, method, httpHandler)
+// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI.
+func (s *Server) Handle(route string, method string, httpHandler http.Handler) {
+ s.addRoute(route, method, httpHandler)
}
//Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI.
func (s *Server) Websocket(route string, httpHandler websocket.Handler) {
- s.addRoute(route, "GET", httpHandler)
+ s.addRoute(route, "GET", httpHandler)
}
// Run starts the web application and serves HTTP requests for s
func (s *Server) Run(addr string) {
- s.initServer()
-
- mux := http.NewServeMux()
- if s.Config.Profiler {
- mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
- mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile))
- mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
- mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol))
- }
- mux.Handle("/", s)
-
- s.Logger.Printf("web.go serving %s\n", addr)
-
- l, err := net.Listen("tcp", addr)
- if err != nil {
- log.Fatal("ListenAndServe:", err)
- }
- s.l = l
- err = http.Serve(s.l, mux)
- s.l.Close()
+ s.initServer()
+
+ mux := http.NewServeMux()
+ if s.Config.Profiler {
+ mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
+ mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile))
+ mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
+ mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol))
+ }
+ mux.Handle("/", s)
+
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ log.Fatal("ListenAndServe:", err)
+ }
+
+ s.Logger.Printf("web.go serving %s\n", l.Addr())
+
+ s.l = l
+ err = http.Serve(s.l, mux)
+ s.l.Close()
}
// RunFcgi starts the web application and serves FastCGI requests for s.
func (s *Server) RunFcgi(addr string) {
- s.initServer()
- s.Logger.Printf("web.go serving fcgi %s\n", addr)
- s.listenAndServeFcgi(addr)
+ s.initServer()
+ s.Logger.Printf("web.go serving fcgi %s\n", addr)
+ s.listenAndServeFcgi(addr)
}
// RunScgi starts the web application and serves SCGI requests for s.
func (s *Server) RunScgi(addr string) {
- s.initServer()
- s.Logger.Printf("web.go serving scgi %s\n", addr)
- s.listenAndServeScgi(addr)
+ s.initServer()
+ s.Logger.Printf("web.go serving scgi %s\n", addr)
+ s.listenAndServeScgi(addr)
}
// RunTLS starts the web application and serves HTTPS requests for s.
func (s *Server) RunTLS(addr string, config *tls.Config) error {
- s.initServer()
- mux := http.NewServeMux()
- mux.Handle("/", s)
- l, err := tls.Listen("tcp", addr, config)
- if err != nil {
- log.Fatal("Listen:", err)
- return err
- }
-
- s.l = l
- return http.Serve(s.l, mux)
+ s.initServer()
+ mux := http.NewServeMux()
+ mux.Handle("/", s)
+
+ l, err := tls.Listen("tcp", addr, config)
+ if err != nil {
+ log.Fatal("Listen:", err)
+ return err
+ }
+ s.Logger.Printf("web.go serving %s\n", l.Addr())
+
+ s.l = l
+ return http.Serve(s.l, mux)
}
// Close stops server s.
func (s *Server) Close() {
- if s.l != nil {
- s.l.Close()
- }
+ if s.l != nil {
+ s.l.Close()
+ }
}
// safelyCall invokes `function` in recover block
func (s *Server) safelyCall(function reflect.Value, args []reflect.Value) (resp []reflect.Value, e interface{}) {
- defer func() {
- if err := recover(); err != nil {
- if !s.Config.RecoverPanic {
- // go back to panic
- panic(err)
- } else {
- e = err
- resp = nil
- s.Logger.Println("Handler crashed with error", err)
- for i := 1; ; i += 1 {
- _, file, line, ok := runtime.Caller(i)
- if !ok {
- break
- }
- s.Logger.Println(file, line)
- }
- }
- }
- }()
- return function.Call(args), nil
+ defer func() {
+ if err := recover(); err != nil {
+ if !s.Config.RecoverPanic {
+ // go back to panic
+ panic(err)
+ } else {
+ e = err
+ resp = nil
+ s.Logger.Println("Handler crashed with error", err)
+ for i := 1; ; i += 1 {
+ _, file, line, ok := runtime.Caller(i)
+ if !ok {
+ break
+ }
+ s.Logger.Println(file, line)
+ }
+ }
+ }
+ }()
+ return function.Call(args), nil
}
// requiresContext determines whether 'handlerType' contains
// an argument to 'web.Ctx' as its first argument
func requiresContext(handlerType reflect.Type) bool {
- //if the method doesn't take arguments, no
- if handlerType.NumIn() == 0 {
- return false
- }
-
- //if the first argument is not a pointer, no
- a0 := handlerType.In(0)
- if a0.Kind() != reflect.Ptr {
- return false
- }
- //if the first argument is a context, yes
- if a0.Elem() == contextType {
- return true
- }
-
- return false
+ //if the method doesn't take arguments, no
+ if handlerType.NumIn() == 0 {
+ return false
+ }
+
+ //if the first argument is not a pointer, no
+ a0 := handlerType.In(0)
+ if a0.Kind() != reflect.Ptr {
+ return false
+ }
+ //if the first argument is a context, yes
+ if a0.Elem() == contextType {
+ return true
+ }
+
+ return false
}
// tryServingFile attempts to serve a static file, and returns
@@ -244,51 +256,66 @@ func requiresContext(handlerType reflect.Type) bool {
// 2) The 'static' directory in the parent directory of the executable.
// 3) The 'static' directory in the current working directory
func (s *Server) tryServingFile(name string, req *http.Request, w http.ResponseWriter) bool {
- //try to serve a static file
- if s.Config.StaticDir != "" {
- staticFile := path.Join(s.Config.StaticDir, name)
- if fileExists(staticFile) {
- http.ServeFile(w, req, staticFile)
- return true
- }
- } else {
- for _, staticDir := range defaultStaticDirs {
- staticFile := path.Join(staticDir, name)
- if fileExists(staticFile) {
- http.ServeFile(w, req, staticFile)
- return true
- }
- }
- }
- return false
+ //try to serve a static file
+ if s.Config.StaticDir != "" {
+ staticFile := path.Join(s.Config.StaticDir, name)
+ if fileExists(staticFile) {
+ http.ServeFile(w, req, staticFile)
+ return true
+ }
+ } else {
+ for _, staticDir := range defaultStaticDirs {
+ staticFile := path.Join(staticDir, name)
+ if fileExists(staticFile) {
+ http.ServeFile(w, req, staticFile)
+ return true
+ }
+ }
+ }
+ return false
}
func (s *Server) logRequest(ctx Context, sTime time.Time) {
- //log the request
- var logEntry bytes.Buffer
- req := ctx.Request
- requestPath := req.URL.Path
-
- duration := time.Now().Sub(sTime)
- var client string
-
- // We suppose RemoteAddr is of the form Ip:Port as specified in the Request
- // documentation at http://golang.org/pkg/net/http/#Request
- pos := strings.LastIndex(req.RemoteAddr, ":")
- if pos > 0 {
- client = req.RemoteAddr[0:pos]
- } else {
- client = req.RemoteAddr
- }
-
- fmt.Fprintf(&logEntry, "%s - \033[32;1m %s %s\033[0m - %v", client, req.Method, requestPath, duration)
+ //log the request
+ req := ctx.Request
+ requestPath := req.URL.Path
+
+ duration := time.Now().Sub(sTime)
+ var client string
+
+ // We suppose RemoteAddr is of the form Ip:Port as specified in the Request
+ // documentation at http://golang.org/pkg/net/http/#Request
+ pos := strings.LastIndex(req.RemoteAddr, ":")
+ if pos > 0 {
+ client = req.RemoteAddr[0:pos]
+ } else {
+ client = req.RemoteAddr
+ }
+
+ var logEntry bytes.Buffer
+ logEntry.WriteString(client)
+ logEntry.WriteString(" - " + s.ttyGreen(req.Method+" "+requestPath))
+ logEntry.WriteString(" - " + duration.String())
+ if len(ctx.Params) > 0 {
+ logEntry.WriteString(" - " + s.ttyWhite(fmt.Sprintf("Params: %v\n", ctx.Params)))
+ }
+ ctx.Server.Logger.Print(logEntry.String())
+}
- if len(ctx.Params) > 0 {
- fmt.Fprintf(&logEntry, " - \033[37;1mParams: %v\033[0m\n", ctx.Params)
- }
+func (s *Server) ttyGreen(msg string) string {
+ return s.ttyColor(msg, ttyCodes.green)
+}
- ctx.Server.Logger.Print(logEntry.String())
+func (s *Server) ttyWhite(msg string) string {
+ return s.ttyColor(msg, ttyCodes.white)
+}
+func (s *Server) ttyColor(msg string, colorCode string) string {
+ if s.Config.ColorOutput {
+ return colorCode + msg + ttyCodes.reset
+ } else {
+ return msg
+ }
}
// the main route handler in web.go
@@ -298,105 +325,105 @@ func (s *Server) logRequest(ctx Context, sTime time.Time) {
// route. The caller is then responsible for calling the httpHandler associated
// with the returned route.
func (s *Server) routeHandler(req *http.Request, w http.ResponseWriter) (unused *route) {
- requestPath := req.URL.Path
- ctx := Context{req, map[string]string{}, s, w}
-
- //set some default headers
- ctx.SetHeader("Server", "web.go", true)
- tm := time.Now().UTC()
-
- //ignore errors from ParseForm because it's usually harmless.
- req.ParseForm()
- if len(req.Form) > 0 {
- for k, v := range req.Form {
- ctx.Params[k] = v[0]
- }
- }
-
- defer s.logRequest(ctx, tm)
-
- ctx.SetHeader("Date", webTime(tm), true)
-
- if req.Method == "GET" || req.Method == "HEAD" {
- if s.tryServingFile(requestPath, req, w) {
- return
- }
- }
-
- //Set the default content-type
- ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true)
-
- for i := 0; i < len(s.routes); i++ {
- route := s.routes[i]
- cr := route.cr
- //if the methods don't match, skip this handler (except HEAD can be used in place of GET)
- if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") {
- continue
- }
-
- if !cr.MatchString(requestPath) {
- continue
- }
- match := cr.FindStringSubmatch(requestPath)
-
- if len(match[0]) != len(requestPath) {
- continue
- }
-
- if route.httpHandler != nil {
- unused = &route
- // We can not handle custom http handlers here, give back to the caller.
- return
- }
-
- var args []reflect.Value
- handlerType := route.handler.Type()
- if requiresContext(handlerType) {
- args = append(args, reflect.ValueOf(&ctx))
- }
- for _, arg := range match[1:] {
- args = append(args, reflect.ValueOf(arg))
- }
-
- ret, err := s.safelyCall(route.handler, args)
- if err != nil {
- //there was an error or panic while calling the handler
- ctx.Abort(500, "Server Error")
- }
- if len(ret) == 0 {
- return
- }
-
- sval := ret[0]
-
- var content []byte
-
- if sval.Kind() == reflect.String {
- content = []byte(sval.String())
- } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 {
- content = sval.Interface().([]byte)
- }
- ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true)
- _, err = ctx.ResponseWriter.Write(content)
- if err != nil {
- ctx.Server.Logger.Println("Error during write: ", err)
- }
- return
- }
-
- // try serving index.html or index.htm
- if req.Method == "GET" || req.Method == "HEAD" {
- if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) {
- return
- } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) {
- return
- }
- }
- ctx.Abort(404, "Page not found")
- return
+ requestPath := req.URL.Path
+ ctx := Context{req, map[string]string{}, s, w}
+
+ //set some default headers
+ ctx.SetHeader("Server", "web.go", true)
+ tm := time.Now().UTC()
+
+ //ignore errors from ParseForm because it's usually harmless.
+ req.ParseForm()
+ if len(req.Form) > 0 {
+ for k, v := range req.Form {
+ ctx.Params[k] = v[0]
+ }
+ }
+
+ defer s.logRequest(ctx, tm)
+
+ ctx.SetHeader("Date", webTime(tm), true)
+
+ if req.Method == "GET" || req.Method == "HEAD" {
+ if s.tryServingFile(requestPath, req, w) {
+ return
+ }
+ }
+
+ for i := 0; i < len(s.routes); i++ {
+ route := s.routes[i]
+ cr := route.cr
+ //if the methods don't match, skip this handler (except HEAD can be used in place of GET)
+ if req.Method != route.method && !(req.Method == "HEAD" && route.method == "GET") {
+ continue
+ }
+
+ if !cr.MatchString(requestPath) {
+ continue
+ }
+ match := cr.FindStringSubmatch(requestPath)
+
+ if len(match[0]) != len(requestPath) {
+ continue
+ }
+
+ if route.httpHandler != nil {
+ unused = &route
+ // We can not handle custom http handlers here, give back to the caller.
+ return
+ }
+
+ // set the default content-type
+ ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true)
+
+ var args []reflect.Value
+ handlerType := route.handler.Type()
+ if requiresContext(handlerType) {
+ args = append(args, reflect.ValueOf(&ctx))
+ }
+ for _, arg := range match[1:] {
+ args = append(args, reflect.ValueOf(arg))
+ }
+
+ ret, err := s.safelyCall(route.handler, args)
+ if err != nil {
+ //there was an error or panic while calling the handler
+ ctx.Abort(500, "Server Error")
+ }
+ if len(ret) == 0 {
+ return
+ }
+
+ sval := ret[0]
+
+ var content []byte
+
+ if sval.Kind() == reflect.String {
+ content = []byte(sval.String())
+ } else if sval.Kind() == reflect.Slice && sval.Type().Elem().Kind() == reflect.Uint8 {
+ content = sval.Interface().([]byte)
+ }
+ ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true)
+ _, err = ctx.ResponseWriter.Write(content)
+ if err != nil {
+ ctx.Server.Logger.Println("Error during write: ", err)
+ }
+ return
+ }
+
+ // try serving index.html or index.htm
+ if req.Method == "GET" || req.Method == "HEAD" {
+ if s.tryServingFile(path.Join(requestPath, "index.html"), req, w) {
+ return
+ } else if s.tryServingFile(path.Join(requestPath, "index.htm"), req, w) {
+ return
+ }
+ }
+ ctx.Abort(404, "Page not found")
+ return
}
// SetLogger sets the logger for server s
func (s *Server) SetLogger(logger *log.Logger) {
- s.Logger = logger
+ s.Logger = logger
}
diff --git a/status.go b/status.go
deleted file mode 100644
index 83053cce..00000000
--- a/status.go
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package web
-
-import "net/http"
-
-var statusText = map[int]string{
- http.StatusContinue: "Continue",
- http.StatusSwitchingProtocols: "Switching Protocols",
-
- http.StatusOK: "OK",
- http.StatusCreated: "Created",
- http.StatusAccepted: "Accepted",
- http.StatusNonAuthoritativeInfo: "Non-Authoritative Information",
- http.StatusNoContent: "No Content",
- http.StatusResetContent: "Reset Content",
- http.StatusPartialContent: "Partial Content",
-
- http.StatusMultipleChoices: "Multiple Choices",
- http.StatusMovedPermanently: "Moved Permanently",
- http.StatusFound: "Found",
- http.StatusSeeOther: "See Other",
- http.StatusNotModified: "Not Modified",
- http.StatusUseProxy: "Use Proxy",
- http.StatusTemporaryRedirect: "Temporary Redirect",
-
- http.StatusBadRequest: "Bad Request",
- http.StatusUnauthorized: "Unauthorized",
- http.StatusPaymentRequired: "Payment Required",
- http.StatusForbidden: "Forbidden",
- http.StatusNotFound: "Not Found",
- http.StatusMethodNotAllowed: "Method Not Allowed",
- http.StatusNotAcceptable: "Not Acceptable",
- http.StatusProxyAuthRequired: "Proxy Authentication Required",
- http.StatusRequestTimeout: "Request Timeout",
- http.StatusConflict: "Conflict",
- http.StatusGone: "Gone",
- http.StatusLengthRequired: "Length Required",
- http.StatusPreconditionFailed: "Precondition Failed",
- http.StatusRequestEntityTooLarge: "Request Entity Too Large",
- http.StatusRequestURITooLong: "Request URI Too Long",
- http.StatusUnsupportedMediaType: "Unsupported Media Type",
- http.StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
- http.StatusExpectationFailed: "Expectation Failed",
-
- http.StatusInternalServerError: "Internal Server Error",
- http.StatusNotImplemented: "Not Implemented",
- http.StatusBadGateway: "Bad Gateway",
- http.StatusServiceUnavailable: "Service Unavailable",
- http.StatusGatewayTimeout: "Gateway Timeout",
- http.StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
-}
diff --git a/ttycolors.go b/ttycolors.go
new file mode 100644
index 00000000..fe63c1af
--- /dev/null
+++ b/ttycolors.go
@@ -0,0 +1,30 @@
+package web
+
+import (
+ "golang.org/x/crypto/ssh/terminal"
+ "syscall"
+)
+
+var ttyCodes struct {
+ green string
+ white string
+ reset string
+}
+
+func init() {
+ ttyCodes.green = ttyBold("32")
+ ttyCodes.white = ttyBold("37")
+ ttyCodes.reset = ttyEscape("0")
+}
+
+func ttyBold(code string) string {
+ return ttyEscape("1;" + code)
+}
+
+func ttyEscape(code string) string {
+ if terminal.IsTerminal(syscall.Stdout) {
+ return "\x1b[" + code + "m"
+ } else {
+ return ""
+ }
+}
diff --git a/web.go b/web.go
index 41bc8ec4..8d36130d 100644
--- a/web.go
+++ b/web.go
@@ -3,23 +3,15 @@
package web
import (
- "bytes"
- "code.google.com/p/go.net/websocket"
- "crypto/hmac"
- "crypto/sha1"
- "crypto/tls"
- "encoding/base64"
- "fmt"
- "io/ioutil"
- "log"
- "mime"
- "net/http"
- "os"
- "path"
- "reflect"
- "strconv"
- "strings"
- "time"
+ "crypto/tls"
+ "golang.org/x/net/websocket"
+ "log"
+ "mime"
+ "net/http"
+ "os"
+ "path"
+ "reflect"
+ "strings"
)
// A Context object is created for every incoming HTTP request, and is
@@ -27,15 +19,15 @@ import (
// about the request, including the http.Request object, the GET and POST params,
// and acts as a Writer for the response.
type Context struct {
- Request *http.Request
- Params map[string]string
- Server *Server
- http.ResponseWriter
+ Request *http.Request
+ Params map[string]string
+ Server *Server
+ http.ResponseWriter
}
// WriteString writes string data into the response object.
func (ctx *Context) WriteString(content string) {
- ctx.ResponseWriter.Write([]byte(content))
+ ctx.ResponseWriter.Write([]byte(content))
}
// Abort is a helper method that sends an HTTP header and an optional
@@ -43,36 +35,42 @@ func (ctx *Context) WriteString(content string) {
// Once it has been called, any return value from the handler will
// not be written to the response.
func (ctx *Context) Abort(status int, body string) {
- ctx.ResponseWriter.WriteHeader(status)
- ctx.ResponseWriter.Write([]byte(body))
+ ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true)
+ ctx.ResponseWriter.WriteHeader(status)
+ ctx.ResponseWriter.Write([]byte(body))
}
// Redirect is a helper method for 3xx redirects.
func (ctx *Context) Redirect(status int, url_ string) {
- ctx.ResponseWriter.Header().Set("Location", url_)
- ctx.ResponseWriter.WriteHeader(status)
- ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_))
+ ctx.ResponseWriter.Header().Set("Location", url_)
+ ctx.ResponseWriter.WriteHeader(status)
+ ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_))
}
-// Notmodified writes a 304 HTTP response
-func (ctx *Context) NotModified() {
- ctx.ResponseWriter.WriteHeader(304)
+//BadRequest writes a 400 HTTP response
+func (ctx *Context) BadRequest() {
+ ctx.ResponseWriter.WriteHeader(400)
}
-// NotFound writes a 404 HTTP response
-func (ctx *Context) NotFound(message string) {
- ctx.ResponseWriter.WriteHeader(404)
- ctx.ResponseWriter.Write([]byte(message))
+// Notmodified writes a 304 HTTP response
+func (ctx *Context) NotModified() {
+ ctx.ResponseWriter.WriteHeader(304)
}
//Unauthorized writes a 401 HTTP response
func (ctx *Context) Unauthorized() {
- ctx.ResponseWriter.WriteHeader(401)
+ ctx.ResponseWriter.WriteHeader(401)
}
//Forbidden writes a 403 HTTP response
func (ctx *Context) Forbidden() {
- ctx.ResponseWriter.WriteHeader(403)
+ ctx.ResponseWriter.WriteHeader(403)
+}
+
+// NotFound writes a 404 HTTP response
+func (ctx *Context) NotFound(message string) {
+ ctx.ResponseWriter.WriteHeader(404)
+ ctx.ResponseWriter.Write([]byte(message))
}
// ContentType sets the Content-Type header for an HTTP response.
@@ -81,93 +79,34 @@ func (ctx *Context) Forbidden() {
// verbatim. The return value is the content type as it was
// set, or an empty string if none was found.
func (ctx *Context) ContentType(val string) string {
- var ctype string
- if strings.ContainsRune(val, '/') {
- ctype = val
- } else {
- if !strings.HasPrefix(val, ".") {
- val = "." + val
- }
- ctype = mime.TypeByExtension(val)
- }
- if ctype != "" {
- ctx.Header().Set("Content-Type", ctype)
- }
- return ctype
+ var ctype string
+ if strings.ContainsRune(val, '/') {
+ ctype = val
+ } else {
+ if !strings.HasPrefix(val, ".") {
+ val = "." + val
+ }
+ ctype = mime.TypeByExtension(val)
+ }
+ if ctype != "" {
+ ctx.Header().Set("Content-Type", ctype)
+ }
+ return ctype
}
// SetHeader sets a response header. If `unique` is true, the current value
// of that header will be overwritten . If false, it will be appended.
func (ctx *Context) SetHeader(hdr string, val string, unique bool) {
- if unique {
- ctx.Header().Set(hdr, val)
- } else {
- ctx.Header().Add(hdr, val)
- }
+ if unique {
+ ctx.Header().Set(hdr, val)
+ } else {
+ ctx.Header().Add(hdr, val)
+ }
}
// SetCookie adds a cookie header to the response.
func (ctx *Context) SetCookie(cookie *http.Cookie) {
- ctx.SetHeader("Set-Cookie", cookie.String(), false)
-}
-
-func getCookieSig(key string, val []byte, timestamp string) string {
- hm := hmac.New(sha1.New, []byte(key))
-
- hm.Write(val)
- hm.Write([]byte(timestamp))
-
- hex := fmt.Sprintf("%02x", hm.Sum(nil))
- return hex
-}
-
-func (ctx *Context) SetSecureCookie(name string, val string, age int64) {
- //base64 encode the val
- if len(ctx.Server.Config.CookieSecret) == 0 {
- ctx.Server.Logger.Println("Secret Key for secure cookies has not been set. Please assign a cookie secret to web.Config.CookieSecret.")
- return
- }
- var buf bytes.Buffer
- encoder := base64.NewEncoder(base64.StdEncoding, &buf)
- encoder.Write([]byte(val))
- encoder.Close()
- vs := buf.String()
- vb := buf.Bytes()
- timestamp := strconv.FormatInt(time.Now().Unix(), 10)
- sig := getCookieSig(ctx.Server.Config.CookieSecret, vb, timestamp)
- cookie := strings.Join([]string{vs, timestamp, sig}, "|")
- ctx.SetCookie(NewCookie(name, cookie, age))
-}
-
-func (ctx *Context) GetSecureCookie(name string) (string, bool) {
- for _, cookie := range ctx.Request.Cookies() {
- if cookie.Name != name {
- continue
- }
-
- parts := strings.SplitN(cookie.Value, "|", 3)
-
- val := parts[0]
- timestamp := parts[1]
- sig := parts[2]
-
- if getCookieSig(ctx.Server.Config.CookieSecret, []byte(val), timestamp) != sig {
- return "", false
- }
-
- ts, _ := strconv.ParseInt(timestamp, 0, 64)
-
- if time.Now().Unix()-31*86400 > ts {
- return "", false
- }
-
- buf := bytes.NewBufferString(val)
- encoder := base64.NewDecoder(base64.StdEncoding, buf)
-
- res, _ := ioutil.ReadAll(encoder)
- return string(res), true
- }
- return "", false
+ ctx.SetHeader("Set-Cookie", cookie.String(), false)
}
// small optimization: cache the context type instead of repeteadly calling reflect.Typeof
@@ -176,96 +115,97 @@ var contextType reflect.Type
var defaultStaticDirs []string
func init() {
- contextType = reflect.TypeOf(Context{})
- //find the location of the exe file
- wd, _ := os.Getwd()
- arg0 := path.Clean(os.Args[0])
- var exeFile string
- if strings.HasPrefix(arg0, "/") {
- exeFile = arg0
- } else {
- //TODO for robustness, search each directory in $PATH
- exeFile = path.Join(wd, arg0)
- }
- parent, _ := path.Split(exeFile)
- defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static"))
- defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static"))
- return
+ contextType = reflect.TypeOf(Context{})
+ //find the location of the exe file
+ wd, _ := os.Getwd()
+ arg0 := path.Clean(os.Args[0])
+ var exeFile string
+ if strings.HasPrefix(arg0, "/") {
+ exeFile = arg0
+ } else {
+ //TODO for robustness, search each directory in $PATH
+ exeFile = path.Join(wd, arg0)
+ }
+ parent, _ := path.Split(exeFile)
+ defaultStaticDirs = append(defaultStaticDirs, path.Join(parent, "static"))
+ defaultStaticDirs = append(defaultStaticDirs, path.Join(wd, "static"))
+ return
}
// Process invokes the main server's routing system.
func Process(c http.ResponseWriter, req *http.Request) {
- mainServer.Process(c, req)
+ mainServer.Process(c, req)
}
// Run starts the web application and serves HTTP requests for the main server.
func Run(addr string) {
- mainServer.Run(addr)
+ mainServer.Run(addr)
}
// RunTLS starts the web application and serves HTTPS requests for the main server.
func RunTLS(addr string, config *tls.Config) {
- mainServer.RunTLS(addr, config)
+ mainServer.RunTLS(addr, config)
}
// RunScgi starts the web application and serves SCGI requests for the main server.
func RunScgi(addr string) {
- mainServer.RunScgi(addr)
+ mainServer.RunScgi(addr)
}
// RunFcgi starts the web application and serves FastCGI requests for the main server.
func RunFcgi(addr string) {
- mainServer.RunFcgi(addr)
+ mainServer.RunFcgi(addr)
}
// Close stops the main server.
func Close() {
- mainServer.Close()
+ mainServer.Close()
}
// Get adds a handler for the 'GET' http method in the main server.
func Get(route string, handler interface{}) {
- mainServer.Get(route, handler)
+ mainServer.Get(route, handler)
}
// Post adds a handler for the 'POST' http method in the main server.
func Post(route string, handler interface{}) {
- mainServer.addRoute(route, "POST", handler)
+ mainServer.addRoute(route, "POST", handler)
}
// Put adds a handler for the 'PUT' http method in the main server.
func Put(route string, handler interface{}) {
- mainServer.addRoute(route, "PUT", handler)
+ mainServer.addRoute(route, "PUT", handler)
}
// Delete adds a handler for the 'DELETE' http method in the main server.
func Delete(route string, handler interface{}) {
- mainServer.addRoute(route, "DELETE", handler)
+ mainServer.addRoute(route, "DELETE", handler)
}
// Match adds a handler for an arbitrary http method in the main server.
func Match(method string, route string, handler interface{}) {
- mainServer.addRoute(route, method, handler)
+ mainServer.addRoute(route, method, handler)
}
-//Adds a custom handler. Only for webserver mode. Will have no effect when running as FCGI or SCGI.
-func Handler(route string, method string, httpHandler http.Handler) {
- mainServer.Handler(route, method, httpHandler)
+// Add a custom http.Handler. Will have no effect when running as FCGI or SCGI.
+func Handle(route string, method string, httpHandler http.Handler) {
+ mainServer.Handle(route, method, httpHandler)
}
//Adds a handler for websockets. Only for webserver mode. Will have no effect when running as FCGI or SCGI.
func Websocket(route string, httpHandler websocket.Handler) {
- mainServer.Websocket(route, httpHandler)
+ mainServer.Websocket(route, httpHandler)
}
// SetLogger sets the logger for the main server.
func SetLogger(logger *log.Logger) {
- mainServer.Logger = logger
+ mainServer.Logger = logger
}
// Config is the configuration of the main server.
var Config = &ServerConfig{
- RecoverPanic: true,
+ RecoverPanic: true,
+ ColorOutput: true,
}
var mainServer = NewServer()
diff --git a/web_test.go b/web_test.go
index 609ca509..6468b796 100644
--- a/web_test.go
+++ b/web_test.go
@@ -1,580 +1,679 @@
package web
import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "log"
- "net/http"
- "net/url"
- "runtime"
- "strconv"
- "strings"
- "testing"
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/url"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
)
func init() {
- runtime.GOMAXPROCS(4)
+ runtime.GOMAXPROCS(runtime.NumCPU())
}
// ioBuffer is a helper that implements io.ReadWriteCloser,
// which is helpful in imitating a net.Conn
type ioBuffer struct {
- input *bytes.Buffer
- output *bytes.Buffer
- closed bool
+ input *bytes.Buffer
+ output *bytes.Buffer
+ closed bool
}
func (buf *ioBuffer) Write(p []uint8) (n int, err error) {
- if buf.closed {
- return 0, errors.New("Write after Close on ioBuffer")
- }
- return buf.output.Write(p)
+ if buf.closed {
+ return 0, errors.New("Write after Close on ioBuffer")
+ }
+ return buf.output.Write(p)
}
func (buf *ioBuffer) Read(p []byte) (n int, err error) {
- if buf.closed {
- return 0, errors.New("Read after Close on ioBuffer")
- }
- return buf.input.Read(p)
+ if buf.closed {
+ return 0, errors.New("Read after Close on ioBuffer")
+ }
+ return buf.input.Read(p)
}
//noop
func (buf *ioBuffer) Close() error {
- buf.closed = true
- return nil
+ buf.closed = true
+ return nil
}
type testResponse struct {
- statusCode int
- status string
- body string
- headers map[string][]string
- cookies map[string]string
+ statusCode int
+ status string
+ body string
+ headers map[string][]string
+ cookies map[string]string
}
func buildTestResponse(buf *bytes.Buffer) *testResponse {
- response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)}
- s := buf.String()
- contents := strings.SplitN(s, "\r\n\r\n", 2)
+ response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)}
+ s := buf.String()
+ contents := strings.SplitN(s, "\r\n\r\n", 2)
- header := contents[0]
+ header := contents[0]
- if len(contents) > 1 {
- response.body = contents[1]
- }
+ if len(contents) > 1 {
+ response.body = contents[1]
+ }
- headers := strings.Split(header, "\r\n")
+ headers := strings.Split(header, "\r\n")
- statusParts := strings.SplitN(headers[0], " ", 3)
- response.statusCode, _ = strconv.Atoi(statusParts[1])
+ statusParts := strings.SplitN(headers[0], " ", 3)
+ response.statusCode, _ = strconv.Atoi(statusParts[1])
- for _, h := range headers[1:] {
- split := strings.SplitN(h, ":", 2)
- name := strings.TrimSpace(split[0])
- value := strings.TrimSpace(split[1])
- if _, ok := response.headers[name]; !ok {
- response.headers[name] = []string{}
- }
+ for _, h := range headers[1:] {
+ split := strings.SplitN(h, ":", 2)
+ name := strings.TrimSpace(split[0])
+ value := strings.TrimSpace(split[1])
+ if _, ok := response.headers[name]; !ok {
+ response.headers[name] = []string{}
+ }
- newheaders := make([]string, len(response.headers[name])+1)
- copy(newheaders, response.headers[name])
- newheaders[len(newheaders)-1] = value
- response.headers[name] = newheaders
+ newheaders := make([]string, len(response.headers[name])+1)
+ copy(newheaders, response.headers[name])
+ newheaders[len(newheaders)-1] = value
+ response.headers[name] = newheaders
- //if the header is a cookie, set it
- if name == "Set-Cookie" {
- i := strings.Index(value, ";")
- cookie := value[0:i]
- cookieParts := strings.SplitN(cookie, "=", 2)
- response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1])
- }
- }
+ //if the header is a cookie, set it
+ if name == "Set-Cookie" {
+ i := strings.Index(value, ";")
+ cookie := value[0:i]
+ cookieParts := strings.SplitN(cookie, "=", 2)
+ response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1])
+ }
+ }
- return &response
+ return &response
}
func getTestResponse(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *testResponse {
- req := buildTestRequest(method, path, body, headers, cookies)
- var buf bytes.Buffer
+ req := buildTestRequest(method, path, body, headers, cookies)
+ var buf bytes.Buffer
- tcpb := ioBuffer{input: nil, output: &buf}
- c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb}
- mainServer.Process(&c, req)
- return buildTestResponse(&buf)
+ tcpb := ioBuffer{input: nil, output: &buf}
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb}
+ mainServer.Process(&c, req)
+ return buildTestResponse(&buf)
}
func testGet(path string, headers map[string]string) *testResponse {
- var header http.Header
- for k, v := range headers {
- header.Set(k, v)
- }
- return getTestResponse("GET", path, "", header, nil)
+ var header http.Header
+ for k, v := range headers {
+ header.Set(k, v)
+ }
+ return getTestResponse("GET", path, "", header, nil)
}
type Test struct {
- method string
- path string
- headers map[string][]string
- body string
- expectedStatus int
- expectedBody string
+ method string
+ path string
+ headers map[string][]string
+ body string
+ expectedStatus int
+ expectedBody string
}
//initialize the routes
func init() {
- mainServer.SetLogger(log.New(ioutil.Discard, "", 0))
- Get("/", func() string { return "index" })
- Get("/panic", func() { panic(0) })
- Get("/echo/(.*)", func(s string) string { return s })
- Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d })
- Post("/post/echo/(.*)", func(s string) string { return s })
- Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] })
-
- Get("/error/code/(.*)", func(ctx *Context, code string) string {
- n, _ := strconv.Atoi(code)
- message := statusText[n]
- ctx.Abort(n, message)
- return ""
- })
-
- Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) })
-
- Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })
- Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })
-
- Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })
- Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })
-
- Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string {
- n, _ := strconv.Atoi(code)
- ctx.Abort(n, message)
- return ""
- })
-
- Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") })
-
- Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string {
- ctx.SetSecureCookie(name, val, 60)
- return ""
- })
-
- Get("/securecookie/get/(.+)", func(ctx *Context, name string) string {
- val, ok := ctx.GetSecureCookie(name)
- if !ok {
- return ""
- }
- return val
- })
- Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] })
- Get("/fullparams", func(ctx *Context) string {
- return strings.Join(ctx.Request.Form["a"], ",")
- })
-
- Get("/json", func(ctx *Context) string {
- ctx.ContentType("json")
- data, _ := json.Marshal(ctx.Params)
- return string(data)
- })
-
- Get("/jsonbytes", func(ctx *Context) []byte {
- ctx.ContentType("json")
- data, _ := json.Marshal(ctx.Params)
- return data
- })
-
- Post("/parsejson", func(ctx *Context) string {
- var tmp = struct {
- A string
- B string
- }{}
- json.NewDecoder(ctx.Request.Body).Decode(&tmp)
- return tmp.A + " " + tmp.B
- })
-
- Match("OPTIONS", "/options", func(ctx *Context) {
- ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true)
- ctx.SetHeader("Access-Control-Max-Age", "1000", true)
- ctx.WriteHeader(200)
- })
-
- Get("/dupeheader", func(ctx *Context) string {
- ctx.SetHeader("Server", "myserver", true)
- return ""
- })
-
- Get("/authorization", func(ctx *Context) string {
- user, pass, err := ctx.GetBasicAuth()
- if err != nil {
- return "fail"
- }
- return user + pass
- })
+ mainServer.SetLogger(log.New(ioutil.Discard, "", 0))
+ Get("/", func() string { return "index" })
+ Get("/panic", func() { panic(0) })
+ Get("/echo/(.*)", func(s string) string { return s })
+ Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d })
+ Post("/post/echo/(.*)", func(s string) string { return s })
+ Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] })
+
+ Get("/error/code/(.*)", func(ctx *Context, code string) string {
+ n, _ := strconv.Atoi(code)
+ message := http.StatusText(n)
+ ctx.Abort(n, message)
+ return ""
+ })
+
+ Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) })
+
+ Get("/error/badrequest", func(ctx *Context) { ctx.BadRequest() })
+ Post("/error/badrequest", func(ctx *Context) { ctx.BadRequest() })
+
+ Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })
+ Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })
+
+ Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })
+ Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })
+
+ Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string {
+ n, _ := strconv.Atoi(code)
+ ctx.Abort(n, message)
+ return ""
+ })
+
+ Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") })
+
+ Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string {
+ ctx.SetSecureCookie(name, val, 60)
+ return ""
+ })
+
+ Get("/securecookie/get/(.+)", func(ctx *Context, name string) string {
+ val, ok := ctx.GetSecureCookie(name)
+ if !ok {
+ return ""
+ }
+ return val
+ })
+ Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] })
+ Get("/fullparams", func(ctx *Context) string {
+ return strings.Join(ctx.Request.Form["a"], ",")
+ })
+
+ Get("/json", func(ctx *Context) string {
+ ctx.ContentType("json")
+ data, _ := json.Marshal(ctx.Params)
+ return string(data)
+ })
+
+ Get("/jsonbytes", func(ctx *Context) []byte {
+ ctx.ContentType("json")
+ data, _ := json.Marshal(ctx.Params)
+ return data
+ })
+
+ Post("/parsejson", func(ctx *Context) string {
+ var tmp = struct {
+ A string
+ B string
+ }{}
+ json.NewDecoder(ctx.Request.Body).Decode(&tmp)
+ return tmp.A + " " + tmp.B
+ })
+
+ Match("OPTIONS", "/options", func(ctx *Context) {
+ ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true)
+ ctx.SetHeader("Access-Control-Max-Age", "1000", true)
+ ctx.WriteHeader(200)
+ })
+
+ Get("/dupeheader", func(ctx *Context) string {
+ ctx.SetHeader("Server", "myserver", true)
+ return ""
+ })
+
+ Get("/authorization", func(ctx *Context) string {
+ user, pass, err := ctx.GetBasicAuth()
+ if err != nil {
+ return "fail"
+ }
+ return user + pass
+ })
}
var tests = []Test{
- {"GET", "/", nil, "", 200, "index"},
- {"GET", "/echo/hello", nil, "", 200, "hello"},
- {"GET", "/echo/hello", nil, "", 200, "hello"},
- {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"},
- {"POST", "/post/echo/hello", nil, "", 200, "hello"},
- {"POST", "/post/echo/hello", nil, "", 200, "hello"},
- {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"},
- {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"},
- {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"},
- //long url
- {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)},
- {"GET", "/writetest", nil, "", 200, "hello"},
- {"GET", "/error/unauthorized", nil, "", 401, ""},
- {"POST", "/error/unauthorized", nil, "", 401, ""},
- {"GET", "/error/forbidden", nil, "", 403, ""},
- {"POST", "/error/forbidden", nil, "", 403, ""},
- {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"},
- {"GET", "/doesnotexist", nil, "", 404, "Page not found"},
- {"POST", "/doesnotexist", nil, "", 404, "Page not found"},
- {"GET", "/error/code/500", nil, "", 500, statusText[500]},
- {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"},
- {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"},
- {"GET", "/getparam?b=abcd", nil, "", 200, ""},
- {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"},
- {"GET", "/panic", nil, "", 500, "Server Error"},
- {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
- {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
- {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"},
- //{"GET", "/testenv", "", 200, "hello world"},
- {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"},
+ {"GET", "/", nil, "", 200, "index"},
+ {"GET", "/echo/hello", nil, "", 200, "hello"},
+ {"GET", "/echo/hello", nil, "", 200, "hello"},
+ {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"},
+ {"POST", "/post/echo/hello", nil, "", 200, "hello"},
+ {"POST", "/post/echo/hello", nil, "", 200, "hello"},
+ {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"},
+ {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"},
+ {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"},
+ //long url
+ {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)},
+ {"GET", "/writetest", nil, "", 200, "hello"},
+ {"GET", "/error/badrequest", nil, "", 400, ""},
+ {"POST", "/error/badrequest", nil, "", 400, ""},
+ {"GET", "/error/unauthorized", nil, "", 401, ""},
+ {"POST", "/error/unauthorized", nil, "", 401, ""},
+ {"GET", "/error/forbidden", nil, "", 403, ""},
+ {"POST", "/error/forbidden", nil, "", 403, ""},
+ {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"},
+ {"GET", "/doesnotexist", nil, "", 404, "Page not found"},
+ {"POST", "/doesnotexist", nil, "", 404, "Page not found"},
+ {"GET", "/error/code/500", nil, "", 500, http.StatusText(500)},
+ {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"},
+ {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"},
+ {"GET", "/getparam?b=abcd", nil, "", 200, ""},
+ {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"},
+ {"GET", "/panic", nil, "", 500, "Server Error"},
+ {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
+ {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
+ {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"},
+ //{"GET", "/testenv", "", 200, "hello world"},
+ {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"},
+ {"GET", "/authorization", nil, "", 200, "fail"},
}
func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request {
- host := "127.0.0.1"
- port := "80"
- rawurl := "http://" + host + ":" + port + path
- url_, _ := url.Parse(rawurl)
- proto := "HTTP/1.1"
-
- if headers == nil {
- headers = map[string][]string{}
- }
-
- headers["User-Agent"] = []string{"web.go test"}
- if method == "POST" {
- headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))}
- if headers["Content-Type"] == nil {
- headers["Content-Type"] = []string{"text/plain"}
- }
- }
-
- req := http.Request{Method: method,
- URL: url_,
- Proto: proto,
- Host: host,
- Header: http.Header(headers),
- Body: ioutil.NopCloser(bytes.NewBufferString(body)),
- }
-
- for _, cookie := range cookies {
- req.AddCookie(cookie)
- }
- return &req
+ host := "127.0.0.1"
+ port := "80"
+ rawurl := "http://" + host + ":" + port + path
+ url_, _ := url.Parse(rawurl)
+ proto := "HTTP/1.1"
+
+ if headers == nil {
+ headers = map[string][]string{}
+ }
+
+ headers["User-Agent"] = []string{"web.go test"}
+ if method == "POST" {
+ headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))}
+ if headers["Content-Type"] == nil {
+ headers["Content-Type"] = []string{"text/plain"}
+ }
+ }
+
+ req := http.Request{Method: method,
+ URL: url_,
+ Proto: proto,
+ Host: host,
+ Header: http.Header(headers),
+ Body: ioutil.NopCloser(bytes.NewBufferString(body)),
+ }
+
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ return &req
}
func TestRouting(t *testing.T) {
- for _, test := range tests {
- resp := getTestResponse(test.method, test.path, test.body, test.headers, nil)
-
- if resp.statusCode != test.expectedStatus {
- t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode)
- }
- if resp.body != test.expectedBody {
- t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body)
- }
- if cl, ok := resp.headers["Content-Length"]; ok {
- clExp, _ := strconv.Atoi(cl[0])
- clAct := len(resp.body)
- if clExp != clAct {
- t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct)
- }
- }
- }
+ for _, test := range tests {
+ resp := getTestResponse(test.method, test.path, test.body, test.headers, nil)
+
+ if resp.statusCode != test.expectedStatus {
+ t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode)
+ }
+ if resp.body != test.expectedBody {
+ t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body)
+ }
+ if cl, ok := resp.headers["Content-Length"]; ok {
+ clExp, _ := strconv.Atoi(cl[0])
+ clAct := len(resp.body)
+ if clExp != clAct {
+ t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct)
+ }
+ }
+ }
}
func TestHead(t *testing.T) {
- for _, test := range tests {
-
- if test.method != "GET" {
- continue
- }
- getresp := getTestResponse("GET", test.path, test.body, test.headers, nil)
- headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil)
-
- if getresp.statusCode != headresp.statusCode {
- t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
- }
- if len(headresp.body) != 0 {
- t.Fatalf("head request arrived with a body")
- }
-
- var cl []string
- var getcl, headcl int
- var hascl1, hascl2 bool
-
- if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
- getcl, _ = strconv.Atoi(cl[0])
- }
-
- if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
- headcl, _ = strconv.Atoi(cl[0])
- }
-
- if hascl1 != hascl2 {
- t.Fatalf("head and get: one has content-length, one doesn't")
- }
-
- if hascl1 == true && getcl != headcl {
- t.Fatalf("head and get content-length differ")
- }
- }
+ for _, test := range tests {
+
+ if test.method != "GET" {
+ continue
+ }
+ getresp := getTestResponse("GET", test.path, test.body, test.headers, nil)
+ headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil)
+
+ if getresp.statusCode != headresp.statusCode {
+ t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
+ }
+ if len(headresp.body) != 0 {
+ t.Fatalf("head request arrived with a body")
+ }
+
+ var cl []string
+ var getcl, headcl int
+ var hascl1, hascl2 bool
+
+ if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
+ getcl, _ = strconv.Atoi(cl[0])
+ }
+
+ if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
+ headcl, _ = strconv.Atoi(cl[0])
+ }
+
+ if hascl1 != hascl2 {
+ t.Fatalf("head and get: one has content-length, one doesn't")
+ }
+
+ if hascl1 == true && getcl != headcl {
+ t.Fatalf("head and get content-length differ")
+ }
+ }
}
func buildTestScgiRequest(method string, path string, body string, headers map[string][]string) *bytes.Buffer {
- var headerBuf bytes.Buffer
- scgiHeaders := make(map[string]string)
-
- headerBuf.WriteString("CONTENT_LENGTH")
- headerBuf.WriteByte(0)
- headerBuf.WriteString(fmt.Sprintf("%d", len(body)))
- headerBuf.WriteByte(0)
-
- scgiHeaders["REQUEST_METHOD"] = method
- scgiHeaders["HTTP_HOST"] = "127.0.0.1"
- scgiHeaders["REQUEST_URI"] = path
- scgiHeaders["SERVER_PORT"] = "80"
- scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1"
- scgiHeaders["USER_AGENT"] = "web.go test framework"
-
- for k, v := range headers {
- //Skip content-length
- if k == "Content-Length" {
- continue
- }
- key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1))
- scgiHeaders[key] = v[0]
- }
- for k, v := range scgiHeaders {
- headerBuf.WriteString(k)
- headerBuf.WriteByte(0)
- headerBuf.WriteString(v)
- headerBuf.WriteByte(0)
- }
- headerData := headerBuf.Bytes()
-
- var buf bytes.Buffer
- //extra 1 is for the comma at the end
- dlen := len(headerData)
- fmt.Fprintf(&buf, "%d:", dlen)
- buf.Write(headerData)
- buf.WriteByte(',')
- buf.WriteString(body)
- return &buf
+ var headerBuf bytes.Buffer
+ scgiHeaders := make(map[string]string)
+
+ headerBuf.WriteString("CONTENT_LENGTH")
+ headerBuf.WriteByte(0)
+ headerBuf.WriteString(fmt.Sprintf("%d", len(body)))
+ headerBuf.WriteByte(0)
+
+ scgiHeaders["REQUEST_METHOD"] = method
+ scgiHeaders["HTTP_HOST"] = "127.0.0.1"
+ scgiHeaders["REQUEST_URI"] = path
+ scgiHeaders["SERVER_PORT"] = "80"
+ scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1"
+ scgiHeaders["USER_AGENT"] = "web.go test framework"
+
+ for k, v := range headers {
+ //Skip content-length
+ if k == "Content-Length" {
+ continue
+ }
+ key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1))
+ scgiHeaders[key] = v[0]
+ }
+ for k, v := range scgiHeaders {
+ headerBuf.WriteString(k)
+ headerBuf.WriteByte(0)
+ headerBuf.WriteString(v)
+ headerBuf.WriteByte(0)
+ }
+ headerData := headerBuf.Bytes()
+
+ var buf bytes.Buffer
+ //extra 1 is for the comma at the end
+ dlen := len(headerData)
+ fmt.Fprintf(&buf, "%d:", dlen)
+ buf.Write(headerData)
+ buf.WriteByte(',')
+ buf.WriteString(body)
+ return &buf
}
func TestScgi(t *testing.T) {
- for _, test := range tests {
- req := buildTestScgiRequest(test.method, test.path, test.body, test.headers)
- var output bytes.Buffer
- nb := ioBuffer{input: req, output: &output}
- mainServer.handleScgiRequest(&nb)
- resp := buildTestResponse(&output)
-
- if resp.statusCode != test.expectedStatus {
- t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode)
- }
-
- if resp.body != test.expectedBody {
- t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body)
- }
- }
+ for _, test := range tests {
+ req := buildTestScgiRequest(test.method, test.path, test.body, test.headers)
+ var output bytes.Buffer
+ nb := ioBuffer{input: req, output: &output}
+ mainServer.handleScgiRequest(&nb)
+ resp := buildTestResponse(&output)
+
+ if resp.statusCode != test.expectedStatus {
+ t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode)
+ }
+
+ if resp.body != test.expectedBody {
+ t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body)
+ }
+ }
}
func TestScgiHead(t *testing.T) {
- for _, test := range tests {
-
- if test.method != "GET" {
- continue
- }
-
- req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string))
- var output bytes.Buffer
- nb := ioBuffer{input: req, output: &output}
- mainServer.handleScgiRequest(&nb)
- getresp := buildTestResponse(&output)
-
- req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string))
- var output2 bytes.Buffer
- nb = ioBuffer{input: req, output: &output2}
- mainServer.handleScgiRequest(&nb)
- headresp := buildTestResponse(&output2)
-
- if getresp.statusCode != headresp.statusCode {
- t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
- }
- if len(headresp.body) != 0 {
- t.Fatalf("head request arrived with a body")
- }
-
- var cl []string
- var getcl, headcl int
- var hascl1, hascl2 bool
-
- if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
- getcl, _ = strconv.Atoi(cl[0])
- }
-
- if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
- headcl, _ = strconv.Atoi(cl[0])
- }
-
- if hascl1 != hascl2 {
- t.Fatalf("head and get: one has content-length, one doesn't")
- }
-
- if hascl1 == true && getcl != headcl {
- t.Fatalf("head and get content-length differ")
- }
- }
+ for _, test := range tests {
+
+ if test.method != "GET" {
+ continue
+ }
+
+ req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string))
+ var output bytes.Buffer
+ nb := ioBuffer{input: req, output: &output}
+ mainServer.handleScgiRequest(&nb)
+ getresp := buildTestResponse(&output)
+
+ req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string))
+ var output2 bytes.Buffer
+ nb = ioBuffer{input: req, output: &output2}
+ mainServer.handleScgiRequest(&nb)
+ headresp := buildTestResponse(&output2)
+
+ if getresp.statusCode != headresp.statusCode {
+ t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
+ }
+ if len(headresp.body) != 0 {
+ t.Fatalf("head request arrived with a body")
+ }
+
+ var cl []string
+ var getcl, headcl int
+ var hascl1, hascl2 bool
+
+ if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
+ getcl, _ = strconv.Atoi(cl[0])
+ }
+
+ if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
+ headcl, _ = strconv.Atoi(cl[0])
+ }
+
+ if hascl1 != hascl2 {
+ t.Fatalf("head and get: one has content-length, one doesn't")
+ }
+
+ if hascl1 == true && getcl != headcl {
+ t.Fatalf("head and get content-length differ")
+ }
+ }
}
func TestReadScgiRequest(t *testing.T) {
- headers := map[string][]string{"User-Agent": {"web.go"}}
- req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers)
- var s Server
- httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil})
- if err != nil {
- t.Fatalf("Error while reading SCGI request: ", err.Error())
- }
- if httpReq.ContentLength != 12 {
- t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength)
- }
- var body bytes.Buffer
- io.Copy(&body, httpReq.Body)
- if body.String() != "Hello world!" {
- t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String())
- }
+ headers := map[string][]string{"User-Agent": {"web.go"}}
+ req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers)
+ var s Server
+ httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil})
+ if err != nil {
+ t.Fatalf("Error while reading SCGI request: ", err.Error())
+ }
+ if httpReq.ContentLength != 12 {
+ t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength)
+ }
+ var body bytes.Buffer
+ io.Copy(&body, httpReq.Body)
+ if body.String() != "Hello world!" {
+ t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String())
+ }
}
func makeCookie(vals map[string]string) []*http.Cookie {
- var cookies []*http.Cookie
- for k, v := range vals {
- c := &http.Cookie{
- Name: k,
- Value: v,
- }
- cookies = append(cookies, c)
- }
- return cookies
+ var cookies []*http.Cookie
+ for k, v := range vals {
+ c := &http.Cookie{
+ Name: k,
+ Value: v,
+ }
+ cookies = append(cookies, c)
+ }
+ return cookies
}
func TestSecureCookie(t *testing.T) {
- mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd"
- resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil)
- sval, ok := resp1.cookies["a"]
- if !ok {
- t.Fatalf("Failed to get cookie ")
- }
- cookies := makeCookie(map[string]string{"a": sval})
+ mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd"
+ mainServer.initServer()
+ resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil)
+ sval, ok := resp1.cookies["a"]
+ if !ok {
+ t.Fatalf("Failed to get cookie ")
+ }
+ cookies := makeCookie(map[string]string{"a": sval})
+
+ resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies)
+
+ if resp2.body != "1" {
+ t.Fatalf("SecureCookie test failed")
+ }
+}
+
+func TestEmptySecureCookie(t *testing.T) {
+ mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd"
+ cookies := makeCookie(map[string]string{"empty": ""})
- resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies)
+ resp2 := getTestResponse("GET", "/securecookie/get/empty", "", nil, cookies)
- if resp2.body != "1" {
- t.Fatalf("SecureCookie test failed")
- }
+ if resp2.body != "" {
+ t.Fatalf("Expected an empty secure cookie")
+ }
}
func TestEarlyClose(t *testing.T) {
- var server1 Server
- server1.Close()
+ var server1 Server
+ server1.Close()
}
func TestOptions(t *testing.T) {
- resp := getTestResponse("OPTIONS", "/options", "", nil, nil)
- if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" {
- t.Fatalf("TestOptions - Access-Control-Allow-Methods failed")
- }
- if resp.headers["Access-Control-Max-Age"][0] != "1000" {
- t.Fatalf("TestOptions - Access-Control-Max-Age failed")
- }
+ resp := getTestResponse("OPTIONS", "/options", "", nil, nil)
+ if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" {
+ t.Fatalf("TestOptions - Access-Control-Allow-Methods failed")
+ }
+ if resp.headers["Access-Control-Max-Age"][0] != "1000" {
+ t.Fatalf("TestOptions - Access-Control-Max-Age failed")
+ }
}
func TestSlug(t *testing.T) {
- tests := [][]string{
- {"", ""},
- {"a", "a"},
- {"a/b", "a-b"},
- {"a b", "a-b"},
- {"a////b", "a-b"},
- {" a////b ", "a-b"},
- {" Manowar / Friends ", "manowar-friends"},
- }
-
- for _, test := range tests {
- v := Slug(test[0], "-")
- if v != test[1] {
- t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v)
- }
- }
+ tests := [][]string{
+ {"", ""},
+ {"a", "a"},
+ {"a/b", "a-b"},
+ {"a b", "a-b"},
+ {"a////b", "a-b"},
+ {" a////b ", "a-b"},
+ {" Manowar / Friends ", "manowar-friends"},
+ }
+
+ for _, test := range tests {
+ v := Slug(test[0], "-")
+ if v != test[1] {
+ t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v)
+ }
+ }
}
// tests that we don't duplicate headers
func TestDuplicateHeader(t *testing.T) {
- resp := testGet("/dupeheader", nil)
- if len(resp.headers["Server"]) > 1 {
- t.Fatalf("Expected only one header, got %#v", resp.headers["Server"])
- }
- if resp.headers["Server"][0] != "myserver" {
- t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0])
- }
+ resp := testGet("/dupeheader", nil)
+ if len(resp.headers["Server"]) > 1 {
+ t.Fatalf("Expected only one header, got %#v", resp.headers["Server"])
+ }
+ if resp.headers["Server"][0] != "myserver" {
+ t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0])
+ }
+}
+
+// test that output contains ASCII color codes by default
+func TestColorOutputDefault(t *testing.T) {
+ s := NewServer()
+ var logOutput bytes.Buffer
+ logger := log.New(&logOutput, "", 0)
+ s.Logger = logger
+ s.Get("/test", func() string {
+ return "test"
+ })
+ req := buildTestRequest("GET", "/test", "", nil, nil)
+ var buf bytes.Buffer
+ iob := ioBuffer{input: nil, output: &buf}
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
+ s.Process(&c, req)
+ if !strings.Contains(logOutput.String(), "\x1b") {
+ t.Fatalf("The default log output does not seem to be colored")
+ }
+}
+
+// test that output contains ASCII color codes by default
+func TestNoColorOutput(t *testing.T) {
+ s := NewServer()
+ s.Config.ColorOutput = false
+ var logOutput bytes.Buffer
+ logger := log.New(&logOutput, "", 0)
+ s.Logger = logger
+ s.Get("/test", func() string {
+ return "test"
+ })
+ req := buildTestRequest("GET", "/test", "", nil, nil)
+ var buf bytes.Buffer
+ iob := ioBuffer{input: nil, output: &buf}
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
+ s.Process(&c, req)
+ if strings.Contains(logOutput.String(), "\x1b") {
+ t.Fatalf("The log contains color escape codes")
+ }
+}
+
+// a malformed SCGI request should be discarded and not cause a panic
+func TestMaformedScgiRequest(t *testing.T) {
+ var headerBuf bytes.Buffer
+
+ headerBuf.WriteString("CONTENT_LENGTH")
+ headerBuf.WriteByte(0)
+ headerBuf.WriteString("0")
+ headerBuf.WriteByte(0)
+ headerData := headerBuf.Bytes()
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d:", len(headerData))
+ buf.Write(headerData)
+ buf.WriteByte(',')
+
+ var output bytes.Buffer
+ nb := ioBuffer{input: &buf, output: &output}
+ mainServer.handleScgiRequest(&nb)
+ if !nb.closed {
+ t.Fatalf("The connection should have been closed")
+ }
+}
+
+type TestHandler struct{}
+
+func (t *TestHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
+}
+
+// When a custom HTTP handler is used, the Content-Type header should not be set to a default.
+// Go's FileHandler does not replace the Content-Type header if it is already set.
+func TestCustomHandlerContentType(t *testing.T) {
+ s := NewServer()
+ s.SetLogger(log.New(ioutil.Discard, "", 0))
+ s.Handle("/testHandler", "GET", &TestHandler{})
+ req := buildTestRequest("GET", "/testHandler", "", nil, nil)
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: nil}
+ s.Process(&c, req)
+ if c.headers["Content-Type"] != nil {
+ t.Fatalf("A default Content-Type should not be present when using a custom HTTP handler")
+ }
}
func BuildBasicAuthCredentials(user string, pass string) string {
- s := user + ":" + pass
- return "Basic " + base64.StdEncoding.EncodeToString([]byte(s))
+ s := user + ":" + pass
+ return "Basic " + base64.StdEncoding.EncodeToString([]byte(s))
}
func BenchmarkProcessGet(b *testing.B) {
- s := NewServer()
- s.SetLogger(log.New(ioutil.Discard, "", 0))
- s.Get("/echo/(.*)", func(s string) string {
- return s
- })
- req := buildTestRequest("GET", "/echo/hi", "", nil, nil)
- var buf bytes.Buffer
- iob := ioBuffer{input: nil, output: &buf}
- c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- s.Process(&c, req)
- }
+ s := NewServer()
+ s.SetLogger(log.New(ioutil.Discard, "", 0))
+ s.Get("/echo/(.*)", func(s string) string {
+ return s
+ })
+ req := buildTestRequest("GET", "/echo/hi", "", nil, nil)
+ var buf bytes.Buffer
+ iob := ioBuffer{input: nil, output: &buf}
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s.Process(&c, req)
+ }
}
func BenchmarkProcessPost(b *testing.B) {
- s := NewServer()
- s.SetLogger(log.New(ioutil.Discard, "", 0))
- s.Post("/echo/(.*)", func(s string) string {
- return s
- })
- req := buildTestRequest("POST", "/echo/hi", "", nil, nil)
- var buf bytes.Buffer
- iob := ioBuffer{input: nil, output: &buf}
- c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- s.Process(&c, req)
- }
+ s := NewServer()
+ s.SetLogger(log.New(ioutil.Discard, "", 0))
+ s.Post("/echo/(.*)", func(s string) string {
+ return s
+ })
+ req := buildTestRequest("POST", "/echo/hi", "", nil, nil)
+ var buf bytes.Buffer
+ iob := ioBuffer{input: nil, output: &buf}
+ c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s.Process(&c, req)
+ }
}