diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..ebd12c1c --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,70 @@ +version: 2.1 + +jobs: + "test": + parameters: + version: + type: string + default: "latest" + golint: + type: boolean + default: true + modules: + type: boolean + default: true + goproxy: + type: string + default: "" + docker: + - image: "cimg/go:<< parameters.version >>" + working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket + environment: + GO111MODULE: "on" + GOPROXY: "<< parameters.goproxy >>" + steps: + - checkout + - run: + name: "Print the Go version" + command: > + go version + - run: + name: "Fetch dependencies" + command: > + if [[ << parameters.modules >> = true ]]; then + go mod download + export GO111MODULE=on + else + go get -v ./... + fi + # Only run gofmt, vet & lint against the latest Go version + - run: + name: "Run golint" + command: > + if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then + go get -u golang.org/x/lint/golint + golint ./... + fi + - run: + name: "Run gofmt" + command: > + if [[ << parameters.version >> = "latest" ]]; then + diff -u <(echo -n) <(gofmt -d -e .) + fi + - run: + name: "Run go vet" + command: > + if [[ << parameters.version >> = "latest" ]]; then + go vet -v ./... + fi + - run: + name: "Run go test (+ race detector)" + command: > + go test -v -race ./... + +workflows: + tests: + jobs: + - test: + matrix: + parameters: + version: ["1.22", "1.21", "1.20"] diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 00000000..0986b3e0 --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,7 @@ +# Config for https://github.com/apps/release-drafter +template: | + + + + ## CHANGELOG + $CHANGES diff --git a/.gitignore b/.gitignore index ac710204..cd3fcd1e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,4 @@ _testmain.go *.exe .idea/ -*.iml \ No newline at end of file +*.iml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 3d8d29cf..00000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: go -sudo: false - -matrix: - include: - - go: 1.4 - - go: 1.5 - - go: 1.6 - - go: 1.7 - - go: 1.8 - - go: tip - allow_failures: - - go: tip - -script: - - go get -t -v ./... - - diff -u <(echo -n) <(gofmt -d .) - - go vet $(go list ./... | grep -v /vendor/) - - go test -v -race ./... diff --git a/AUTHORS b/AUTHORS index b003eca0..1931f400 100644 --- a/AUTHORS +++ b/AUTHORS @@ -4,5 +4,6 @@ # Please keep the list sorted. Gary Burd +Google LLC (https://opensource.google.com/) Joachim Bauch diff --git a/README.md b/README.md index 33c3d2be..ff8bfab0 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,19 @@ # Gorilla WebSocket +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. -[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket) -[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) ### Documentation -* [API Reference](http://godoc.org/github.com/gorilla/websocket) -* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) -* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) -* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) -* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) +* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch) ### Status @@ -27,38 +28,5 @@ package API is stable. ### Protocol Compliance The Gorilla WebSocket package passes the server tests in the [Autobahn Test -Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn -subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). - -### Gorilla WebSocket compared with other packages - - - - - - - - - - - - - - - - - - -
github.com/gorillagolang.org/x/net
RFC 6455 Features
Passes Autobahn Test SuiteYesNo
Receive fragmented messageYesNo, see note 1
Send close messageYesNo
Send pings and receive pongsYesNo
Get the type of a received data messageYesYes, see note 2
Other Features
Compression ExtensionsExperimentalNo
Read message using io.ReaderYesNo, see note 3
Write message using io.WriteCloserYesNo, see note 3
- -Notes: - -1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). -2. The application can get the type of a received data message by implementing - a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) - function. -3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. - Read returns when the input buffer is full or a frame boundary is - encountered. Each call to Write sends a single frame message. The Gorilla - io.Reader and io.WriteCloser operate on a single WebSocket message. - +Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn). diff --git a/client.go b/client.go index 43a87c75..00917ea3 100644 --- a/client.go +++ b/client.go @@ -5,15 +5,15 @@ package websocket import ( - "bufio" "bytes" + "context" "crypto/tls" - "encoding/base64" "errors" + "fmt" "io" - "io/ioutil" "net" "net/http" + "net/http/httptrace" "net/url" "strings" "time" @@ -48,11 +48,39 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS } // A Dialer contains options for connecting to WebSocket server. +// +// It is safe to call Dialer's methods concurrently. type Dialer struct { + // The following custom dial functions can be set to establish + // connections to either the backend server or the proxy (if it + // exists). The scheme of the dialed entity (either backend or + // proxy) determines which custom dial function is selected: + // either NetDialTLSContext for HTTPS or NetDialContext/NetDial + // for HTTP. Since the "Proxy" function can determine the scheme + // dynamically, it can make sense to set multiple custom dial + // functions simultaneously. + // // NetDial specifies the dial function for creating TCP connections. If - // NetDial is nil, net.Dial is used. + // NetDial is nil, net.Dialer DialContext is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. NetDial func(network, addr string) (net.Conn, error) + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, NetDial is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + // If "Proxy" field is also set, this function dials the proxy (and performs + // the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy + // dialing case the TLSClientConfig could still be necessary for TLS to the backend server. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -61,16 +89,29 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration - // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer // size is zero, then a useful default size is used. The I/O buffer sizes // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + // Subprotocols specifies the client's requested subprotocols. Subprotocols []string @@ -86,52 +127,13 @@ type Dialer struct { Jar http.CookieJar } -var errMalformedURL = errors.New("malformed ws or wss URL") - -// parseURL parses the URL. -// -// This function is a replacement for the standard library url.Parse function. -// In Go 1.4 and earlier, url.Parse loses information from the path. -func parseURL(s string) (*url.URL, error) { - // From the RFC: - // - // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] - // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - var u url.URL - switch { - case strings.HasPrefix(s, "ws://"): - u.Scheme = "ws" - s = s[len("ws://"):] - case strings.HasPrefix(s, "wss://"): - u.Scheme = "wss" - s = s[len("wss://"):] - default: - return nil, errMalformedURL - } - - if i := strings.Index(s, "?"); i >= 0 { - u.RawQuery = s[i+1:] - s = s[:i] - } - - if i := strings.Index(s, "/"); i >= 0 { - u.Opaque = s[i:] - s = s[:i] - } else { - u.Opaque = "/" - } - - u.Host = s - - if strings.Contains(u.Host, "@") { - // Don't bother parsing user information because user information is - // not allowed in websocket URIs. - return nil, errMalformedURL - } - - return &u, nil +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) } +var errMalformedURL = errors.New("malformed ws or wss URL") + func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host @@ -150,26 +152,29 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { return hostPort, hostNoPort } -// DefaultDialer is a dialer with all fields set to the default zero values. +// DefaultDialer is a dialer with all fields set to the default values. var DefaultDialer = &Dialer{ - Proxy: http.ProxyFromEnvironment, + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, } -// Dial creates a new client connection. Use requestHeader to specify the +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // Use the response.Header to get the selected subprotocol // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // +// The context will be used in the request and in the Dialer. +// // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. -func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { - d = &Dialer{ - Proxy: http.ProxyFromEnvironment, - } + d = &nilDialer } challengeKey, err := generateChallengeKey() @@ -177,7 +182,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - u, err := parseURL(urlStr) + u, err := url.Parse(urlStr) if err != nil { return nil, nil, err } @@ -197,7 +202,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } req := &http.Request{ - Method: "GET", + Method: http.MethodGet, URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, @@ -205,6 +210,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re Header: make(http.Header), Host: u.Host, } + req = req.WithContext(ctx) // Set the cookies present in the cookie jar of the dialer if d.Jar != nil { @@ -237,116 +243,111 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs default: req.Header[k] = vs } } if d.EnableCompression { - req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} } - hostPort, hostNoPort := hostPortNoPort(u) + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } var proxyURL *url.URL - // Check wether the proxy method has been configured if d.Proxy != nil { proxyURL, err = d.Proxy(req) + if err != nil { + return nil, nil, err + } } + netDial, err := d.netDialFn(ctx, proxyURL, u) if err != nil { return nil, nil, err } - var targetHostPort string - if proxyURL != nil { - targetHostPort, _ = hostPortNoPort(proxyURL) - } else { - targetHostPort = hostPort - } - - var deadline time.Time - if d.HandshakeTimeout != 0 { - deadline = time.Now().Add(d.HandshakeTimeout) - } - - netDial := d.NetDial - if netDial == nil { - netDialer := &net.Dialer{Deadline: deadline} - netDial = netDialer.Dial + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) } - netConn, err := netDial("tcp", targetHostPort) + netConn, err := netDial(ctx, "tcp", hostPort) if err != nil { return nil, nil, err } + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. defer func() { if netConn != nil { - netConn.Close() + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() } }() - if err := netConn.SetDeadline(deadline); err != nil { - return nil, nil, err - } - - if proxyURL != nil { - connectHeader := make(http.Header) - if user := proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) - } - } - connectReq := &http.Request{ - Method: "CONNECT", - URL: &url.URL{Opaque: hostPort}, - Host: hostPort, - Header: connectHeader, - } - - connectReq.Write(netConn) + // Do TLS handshake over established connection if a proxy exists. + if proxyURL != nil && u.Scheme == "https" { - // Read response. - // Okay to use and discard buffered reader here, because - // TLS server will not speak until spoken to. - br := bufio.NewReader(netConn) - resp, err := http.ReadResponse(br, connectReq) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 200 { - f := strings.SplitN(resp.Status, " ", 2) - return nil, nil, errors.New(f[1]) - } - } - - if u.Scheme == "https" { cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - if err := tlsConn.Handshake(); err != nil { - return nil, nil, err + + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return nil, nil, err - } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + + if err != nil { + return nil, nil, err } } - conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) if err := req.Write(netConn); err != nil { return nil, nil, err } + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + resp, err := http.ReadResponse(conn.br, req) if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } return nil, nil, err } @@ -357,15 +358,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application // debugging. buf := make([]byte, 1024) n, _ := io.ReadFull(resp.Body, buf) - resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) return nil, resp, ErrBadHandshake } @@ -383,10 +384,134 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re break } - resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - netConn.SetDeadline(time.Time{}) - netConn = nil // to avoid close in defer. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return conn, resp, nil } + +// Returns the dial function to establish the connection to either the backend +// server or the proxy (if it exists). If the dialed entity is HTTPS, then the +// returned dial function *also* performs the TLS handshake to the dialed entity. +// NOTE: If a proxy exists, it is possible for a second TLS handshake to be +// necessary over the established connection. +func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) { + var netDial netDialerFunc + if proxyURL != nil { + netDial = d.netDialFromURL(proxyURL) + } else { + netDial = d.netDialFromURL(backendURL) + } + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + netDial = netDialWithDeadline(netDial, deadline) + } + // Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth. + if proxyURL != nil { + return proxyFromURL(proxyURL, netDial) + } + return netDial, nil +} + +// Returns function to create the connection depending on the Dialer's +// custom dialing functions and the passed URL of entity connecting to. +func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc { + var netDial netDialerFunc + switch { + case d.NetDialContext != nil: + netDial = d.NetDialContext + case d.NetDial != nil: + netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { + return d.NetDial(net, addr) + } + default: + netDial = (&net.Dialer{}).DialContext + } + // If dialed entity is HTTPS, then either use custom TLS dialing function (if exists) + // or wrap the previously computed "netDial" to use TLS config for handshake. + if u.Scheme == "https" { + if d.NetDialTLSContext != nil { + netDial = d.NetDialTLSContext + } else { + netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u) + } + } + return netDial +} + +// Returns wrapped "netDial" function, performing TLS handshake after connecting. +func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc { + return func(ctx context.Context, unused, addr string) (net.Conn, error) { + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + // Creates TCP connection to addr using passed "netDial" function. + conn, err := netDial(ctx, "tcp", addr) + if err != nil { + return nil, err + } + cfg := cloneTLSConfig(tlsConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(conn, cfg) + // Do the TLS handshake using TLSConfig over the wrapped connection. + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err = doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + if err != nil { + tlsConn.Close() + return nil, err + } + return tlsConn, nil + } +} + +// Returns wrapped "netDial" function, setting passed deadline. +func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := netDial(ctx, network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/client_clone.go b/client_clone.go deleted file mode 100644 index 4f0d9437..00000000 --- a/client_clone.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package websocket - -import "crypto/tls" - -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return cfg.Clone() -} diff --git a/client_clone_legacy.go b/client_clone_legacy.go deleted file mode 100644 index babb007f..00000000 --- a/client_clone_legacy.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !go1.8 - -package websocket - -import "crypto/tls" - -// cloneTLSConfig clones all public fields except the fields -// SessionTicketsDisabled and SessionTicketKey. This avoids copying the -// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a -// config in active use. -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } -} diff --git a/client_proxy_server_test.go b/client_proxy_server_test.go new file mode 100644 index 00000000..c8e6850f --- /dev/null +++ b/client_proxy_server_test.go @@ -0,0 +1,911 @@ +// Copyright 2025 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" +) + +// These test cases use a websocket client (Dialer)/proxy/websocket server (Upgrader) +// to validate the cases where a proxy is an intermediary between a websocket client +// and server. The test cases usually 1) create a websocket server which echoes any +// data received back to the client, 2) a basic duplex streaming proxy, and 3) a +// websocket client which sends random data to the server through the proxy, +// validating any subsequent data received is the same as the data sent. The various +// permutations include the proxy and backend schemes (HTTP or HTTPS), as well as +// the custom dial functions (e.g NetDialContext, NetDial) set on the Dialer. + +const ( + subprotocolV1 = "subprotocol-version-1" + subprotocolV2 = "subprotocol-version-2" +) + +// Permutation 1 +// +// Backend: HTTP +// Proxy: HTTP +func TestHTTPProxyAndBackend(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Dial the websocket server through the proxy server. + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 2 +// +// Backend: HTTP +// Proxy: HTTP +// DialFn: NetDial (dials proxy) +func TestHTTPProxyWithNetDial(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Dial the websocket server through the proxy server. + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(context.Background(), network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 3 +// +// Backend: HTTP +// Proxy: HTTP +// DialFn: NetDialContext (dials proxy) +func TestHTTPProxyWithNetDialContext(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Dial the websocket server through the proxy server. + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 4 +// +// Backend: HTTPS +// Proxy: HTTP +// DialFn: NetDialTLSConfig (set but *ignored*) +// TLS Config: set (used for backend TLS) +func TestHTTPProxyWithHTTPSBackend(t *testing.T) { + websocketTLS := true + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + var netDialTLSCalled atomic.Int64 + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + // This function should be ignored, because an HTTP proxy exists + // and the backend TLS handshake should use TLSClientConfig. + NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialTLSCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + // Used for the backend server TLS handshake. + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if numTLSDials := netDialTLSCalled.Load(); numTLSDials > 0 { + t.Errorf("NetDialTLS should have been ignored") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 5 +// +// Backend: HTTPS +// Proxy: HTTPS +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyAndBackend(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 6 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDial (used to dial proxy) +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyUsingNetDial(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(context.Background(), network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 7 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDialContext (used to dial proxy) +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyUsingNetDialContext(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 8 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDialTLSContext (used for proxy TLS) +// TLS Config: set (used for backend TLS) +func TestHTTPSProxyUsingNetDialTLSContext(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Configure the proxy dialing function which dials the proxy and + // performs the TLS handshake. + var proxyDialCalled atomic.Int64 + proxyCerts := x509.NewCertPool() + proxyCerts.AppendCertsFromPEM(proxyServerCert) + proxyTLSConfig := &tls.Config{RootCAs: proxyCerts} + proxyDial := func(ctx context.Context, network, addr string) (net.Conn, error) { + proxyDialCalled.Add(1) + return tls.Dial(network, addr, proxyTLSConfig) + } + // Configure the backend webscocket TLS configuration (handshake occurs + // over the previously created proxy connection). + websocketCerts := x509.NewCertPool() + websocketCerts.AppendCertsFromPEM(websocketServerCert) + websocketTLSConfig := &tls.Config{RootCAs: websocketCerts} + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + // Dial and TLS handshake function to proxy. + NetDialTLSContext: proxyDial, + // Used for second TLS handshake to backend server over previously + // established proxy connection. + TLSClientConfig: websocketTLSConfig, + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), proxyDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 9 +// +// Backend: HTTP +// Proxy: HTTPS +// TLS Config: set (used for proxy TLS) +func TestHTTPSProxyHTTPBackend(t *testing.T) { + websocketTLS := false + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 10 +// +// Backend: HTTP +// Proxy: HTTPS +// DialFn: NetDialTLSContext (used for proxy TLS) +// TLS Config: set (ignored) +func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend(t *testing.T) { + websocketTLS := false + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + var proxyDialCalled atomic.Int64 + dialer := Dialer{ + NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + proxyDialCalled.Add(1) + return tls.Dial(network, addr, tlsConfig(websocketTLS, proxyTLS)) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: &tls.Config{}, // Misconfigured, but ignored. + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), proxyDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy not called") + } +} + +func TestTLSValidationErrors(t *testing.T) { + // Both websocket and proxy servers are started with TLS. + websocketTLS := true + proxyTLS := true + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Dialer without proxy CA cert fails TLS verification. + tlsError := "tls: failed to verify certificate" + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(true, false), + Subprotocols: []string{subprotocolV1}, + } + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Errorf("expected proxy TLS verification error did not arrive") + } else if !strings.Contains(err.Error(), tlsError) { + t.Errorf("expected proxy TLS error (%s), got (%s)", err.Error(), tlsError) + } + // Validate the proxy handler was *NOT* called (because proxy + // server TLS validation failed). + if e, a := int64(0), proxyServer.numCalls(); e != a { + t.Errorf("proxy should not have been called") + } + // Dialer without websocket CA cert fails TLS verification. + dialer = Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(false, true), + Subprotocols: []string{subprotocolV1}, + } + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Errorf("expected websocket TLS verification error did not arrive") + } else if !strings.Contains(err.Error(), tlsError) { + t.Errorf("expected websocket TLS error (%s), got (%s)", err.Error(), tlsError) + } + // Validate the proxy server *was* called (but subsequent + // websocket server failed TLS validation). + if e, a := int64(1), proxyServer.numCalls(); e != a { + t.Errorf("proxy have been called") + } +} + +func TestProxyFnErrorIsPropagated(t *testing.T) { + websocketServer, websocketURL, err := newWebsocketServer(false) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Create a Dialer where Proxy function always returns an error. + proxyURLError := errors.New("proxy URL generation error") + dialer := Dialer{ + Proxy: func(r *http.Request) (*url.URL, error) { + return nil, proxyURLError + }, + Subprotocols: []string{subprotocolV1}, + } + // Proxy URL generation error should halt request and be propagated. + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Fatalf("expected websocket dial error, received none") + } else if !errors.Is(proxyURLError, err) { + t.Fatalf("expected error (%s), got (%s)", proxyURLError, err) + } +} + +func TestProxyFnNilMeansNoProxy(t *testing.T) { + // Both websocket and proxy servers are started. + websocketTLS := false + proxyTLS := false + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + proxyServer, _, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + // Dialer created with Proxy URL generation function returning nil + // proxy URL, which continues with backend server connection without + // proxying. + dialer := Dialer{ + Proxy: func(r *http.Request) (*url.URL, error) { + return nil, nil + }, + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + sendReceiveData(t, wsClient) + // Validate the proxy handler was *NOT* called (because proxy + // URL generation returned nil). + if e, a := int64(0), proxyServer.numCalls(); e != a { + t.Errorf("proxy should not have been called") + } +} + +// "counter" interface can be implemented by a server to keep track +// of the number of times a handler was called, as well as "Close". +type counter interface { + increment() + numCalls() int64 + closer +} + +type closer interface { + Close() +} + +// testServer implements "counter" interface. +type testServer struct { + server *httptest.Server + numHandled atomic.Int64 +} + +func (ts *testServer) numCalls() int64 { + return ts.numHandled.Load() +} + +func (ts *testServer) increment() { + ts.numHandled.Add(1) +} + +func (ts *testServer) Close() { + if ts.server != nil { + ts.server.Close() + } +} + +// websocketEchoHandler upgrades the connection associated with the request, and +// echoes binary messages read off the websocket connection back to the client. +var websocketEchoHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + upgrader := Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{ + subprotocolV1, + subprotocolV2, + }, + } + wsConn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + defer wsConn.Close() + for { + writer, err := wsConn.NextWriter(BinaryMessage) + if err != nil { + break + } + messageType, reader, err := wsConn.NextReader() + if err != nil { + break + } + if messageType != BinaryMessage { + http.Error(w, "websocket reader not binary message type", + http.StatusInternalServerError) + } + _, err = io.Copy(writer, reader) + if err != nil { + http.Error(w, "websocket server io copy error", + http.StatusInternalServerError) + } + } +}) + +// Returns a test backend websocket server as well as the URL pointing +// to the server, or an error if one occurred. Sets up a TLS endpoint +// on the server if the passed "tlsServer" is true. +// func newWebsocketServer(tlsServer bool) (*httptest.Server, *url.URL, error) { +func newWebsocketServer(tlsServer bool) (closer, *url.URL, error) { + // Start the websocket server, which echoes data back to sender. + websocketServer := httptest.NewUnstartedServer(websocketEchoHandler) + if tlsServer { + websocketKeyPair, err := tls.X509KeyPair(websocketServerCert, websocketServerKey) + if err != nil { + return nil, nil, err + } + websocketServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{websocketKeyPair}, + } + websocketServer.StartTLS() + } else { + websocketServer.Start() + } + websocketURL, err := url.Parse(websocketServer.URL) + if err != nil { + return nil, nil, err + } + if tlsServer { + websocketURL.Scheme = "wss" + } else { + websocketURL.Scheme = "ws" + } + return websocketServer, websocketURL, nil +} + +// proxyHandler creates a full duplex streaming connection between the client +// (hijacking the http request connection), and an "upstream" dialed connection +// to the "Host". Creates two goroutines to copy between connections in each direction. +var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Validate the CONNECT method. + if req.Method != http.MethodConnect { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + // Dial upstream server. + upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer upstream.Close() + // Return 200 OK to client. + w.WriteHeader(http.StatusOK) + // Hijack client connection. + client, _, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer client.Close() + // Create duplex streaming between client and upstream connections. + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, client) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(client, upstream) + done <- struct{}{} + }() + <-done +}) + +// Returns a new test HTTP server, as well as the URL to that server, or +// an error if one occurred. numProxyCalls keeps track of the number of +// times the proxy handler was called with this server. +func newProxyServer(tlsServer bool) (counter, *url.URL, error) { + // Start the proxy server, keeping track of how many times the handler is called. + ts := &testServer{} + proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ts.increment() + proxyHandler.ServeHTTP(w, req) + })) + if tlsServer { + proxyKeyPair, err := tls.X509KeyPair(proxyServerCert, proxyServerKey) + if err != nil { + return nil, nil, err + } + proxyServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{proxyKeyPair}, + } + proxyServer.StartTLS() + } else { + proxyServer.Start() + } + proxyURL, err := url.Parse(proxyServer.URL) + if err != nil { + return nil, nil, err + } + return ts, proxyURL, nil +} + +// Returns the TLS config with the RootCAs cert pool set. If +// neither websocket nor proxy server uses TLS, returns nil. +func tlsConfig(websocketTLS bool, proxyTLS bool) *tls.Config { + if !websocketTLS && !proxyTLS { + return nil + } + certPool := x509.NewCertPool() + tlsConfig := &tls.Config{ + RootCAs: certPool, + } + if websocketTLS { + tlsConfig.RootCAs.AppendCertsFromPEM(websocketServerCert) + } + if proxyTLS { + tlsConfig.RootCAs.AppendCertsFromPEM(proxyServerCert) + } + return tlsConfig +} + +// Sends, receives, and validates random data sent and received +// over the passed websocket connection. +const randomDataSize = 128 * 1024 + +func sendReceiveData(t *testing.T, wsConn *Conn) { + // Create the random data. + randomData := make([]byte, randomDataSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + // Send the random data. + err := wsConn.WriteMessage(BinaryMessage, randomData) + if err != nil { + t.Errorf("websocket write error: %v", err) + } + // Read from the websocket connection, and validate the + // read data is the same as the previously sent data. + _, received, err := wsConn.ReadMessage() + if !bytes.Equal(randomData, received) { + t.Errorf("unexpected data received: %d bytes sent, %d bytes received", + len(received), len(randomData)) + } +} + +// proxyServerCert was generated from crypto/tls/generate_cert.go with the following command: +// +// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +// +// proxyServerCert is a self-signed. +var proxyServerCert = []byte(`-----BEGIN CERTIFICATE----- +MIIDGTCCAgGgAwIBAgIRALL5AZcefF4kkYV1SEG6YrMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBALQ/FHcyVwdFHxARbbD2KBtDUT7Eni+8ioNdjtGcmtXqBv45EC1C +JOqqGJTroFGJ6Q9kQIZ9FqH5IJR2fOOJD9kOTueG4Vt1JY1rj1Kbpjefu8XleZ5L +SBwIWVnN/lEsEbuKmj7N2gLt5AH3zMZiBI1mg1u9Z5ZZHYbCiTpBrwsq6cTlvR9g +dyo1YkM5hRESCzsrL0aUByoo0qRMD8ZsgANJwgsiO0/M6idbxDwv1BnGwGmRYvOE +Hxpy3v0Jg7GJYrvnpnifJTs4nw91N5X9pXxR7FFzi/6HTYDWRljvTb0w6XciKYAz +bWZ0+cJr5F7wB7ovlbm7HrQIR7z7EIIu2d8CAwEAAaNoMGYwDgYDVR0PAQH/BAQD +AgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0R +BCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZI +hvcNAQELBQADggEBAFPPWopNEJtIA2VFAQcqN6uJK+JVFOnjGRoCrM6Xgzdm0wxY +XCGjsxY5dl+V7KzdGqu858rCaq5osEBqypBpYAnS9C38VyCDA1vPS1PsN8SYv48z +DyBwj+7R2qar0ADBhnhWxvYO9M72lN/wuCqFKYMeFSnJdQLv3AsrrHe9lYqOa36s +8wxSwVTFTYXBzljPEnSaaJMPqFD8JXaZK1ryJPkO5OsCNQNGtatNiWAf3DcmwHAT +MGYMzP0u4nw47aRz9shB8w+taPKHx2BVwE1m/yp3nHVioOjXqA1fwRQVGclCJSH1 +D2iq3hWVHRENgjTjANBPICLo9AZ4JfN6PH19mnU= +-----END CERTIFICATE-----`) + +// proxyServerKey is the private key for proxyServerCert. +var proxyServerKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAtD8UdzJXB0UfEBFtsPYoG0NRPsSeL7yKg12O0Zya1eoG/jkQ +LUIk6qoYlOugUYnpD2RAhn0WofkglHZ844kP2Q5O54bhW3UljWuPUpumN5+7xeV5 +nktIHAhZWc3+USwRu4qaPs3aAu3kAffMxmIEjWaDW71nllkdhsKJOkGvCyrpxOW9 +H2B3KjViQzmFERILOysvRpQHKijSpEwPxmyAA0nCCyI7T8zqJ1vEPC/UGcbAaZFi +84QfGnLe/QmDsYliu+emeJ8lOzifD3U3lf2lfFHsUXOL/odNgNZGWO9NvTDpdyIp +gDNtZnT5wmvkXvAHui+VubsetAhHvPsQgi7Z3wIDAQABAoIBAGmw93IxjYCQ0ncc +kSKMJNZfsdtJdaxuNRZ0nNNirhQzR2h403iGaZlEpmdkhzxozsWcto1l+gh+SdFk +bTUK4MUZM8FlgO2dEqkLYh5BcMT7ICMZvSfJ4v21E5eqR68XVUqQKoQbNvQyxFk3 +EddeEGdNrkb0GDK8DKlBlzAW5ep4gjG85wSTjR+J+muUv3R0BgLBFSuQnIDM/IMB +LWqsja/QbtB7yppe7jL5u8UCFdZG8BBKT9fcvFIu5PRLO3MO0uOI7LTc8+W1Xm23 +uv+j3SY0+v+6POjK0UlJFFi/wkSPTFIfrQO1qFBkTDQHhQ6q/7GnILYYOiGbIRg2 +NNuP52ECgYEAzXEoy50wSYh8xfFaBuxbm3ruuG2W49jgop7ZfoFrPWwOQKAZS441 +VIwV4+e5IcA6KkuYbtGSdTYqK1SMkgnUyD/VevwAqH5TJoEIGu0pDuKGwVuwqioZ +frCIAV5GllKyUJ55VZNbRr2vY2fCsWbaCSCHETn6C16DNuTCe5C0JBECgYEA4JqY +5GpNbMG8fOt4H7hU0Fbm2yd6SHJcQ3/9iimef7xG6ajxsYrIhg1ft+3IPHMjVI0+ +9brwHDnWg4bOOx/VO4VJBt6Dm/F33bndnZRkuIjfSNpLM51P+EnRdaFVHOJHwKqx +uF69kihifCAG7YATgCveeXImzBUSyZUz9UrETu8CgYARNBimdFNG1RcdvEg9rC0/ +p9u1tfecvNySwZqU7WF9kz7eSonTueTdX521qAHowaAdSpdJMGODTTXaywm6cPhQ +jIfj9JZZhbqQzt1O4+08Qdvm9TamCUB5S28YLjza+bHU7nBaqixKkDfPqzCyilpX +yVGGL8SwjwmN3zop/sQXAQKBgC0JMsESQ6YcDsRpnrOVjYQc+LtW5iEitTdfsaID +iGGKihmOI7B66IxgoCHMTws39wycKdSyADVYr5e97xpR3rrJlgQHmBIrz+Iow7Q2 +LiAGaec8xjl6QK/DdXmFuQBKqyKJ14rljFODP4QuE9WJid94bGqjpf3j99ltznZP +4J8HAoGAJb4eb4lu4UGwifDzqfAPzLGCoi0fE1/hSx34lfuLcc1G+LEu9YDKoOVJ +9suOh0b5K/bfEy9KrVMBBriduvdaERSD8S3pkIQaitIz0B029AbE4FLFf9lKQpP2 +KR8NJEkK99Vh/tew6jAMll70xFrE7aF8VLXJVE7w4sQzuvHxl9Q= +-----END RSA PRIVATE KEY----- +`) + +// websocketServerCert is self-signed. +var websocketServerCert = []byte(`-----BEGIN CERTIFICATE----- +MIIDOTCCAiGgAwIBAgIQYSN1VY/favsLUo+B7gJ5tTANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApBlintjkL1fO1Sk2pzNvl862CtTwU7/Jy6EZqWzI17wEbPn4sbSD +bHhfDlPl2nmw3hVkc6LNK+eqzm2GX/ai4tgMiaH7kyyNit1K3g7y7GISMf9poWIa +POJhid2wmhKHbEtHECSdQ5c/jEN1UVzB4go5LO7MEEVo9kyQ+yBqS6gISyFmfaT4 +qOsPJBir33bBpptSend1JSXaRTXqRa1p+oudw2ILa4U7KfuKK3emp21m5/HYAuSf +CV4WqqDoDiBPMpsQ0kPEPugWZKFeF3qanmqFFvptYx+zJbOznWYY2D3idWsvcg6q +VLPEB19oXaVBV0HXPFtObm5m1jCpl8FI1wIDAQABo4GIMIGFMA4GA1UdDwEB/wQE +AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud +DgQWBBQcSkjqA9rgos1daegNj49BpRCA0jAuBgNVHREEJzAlggtleGFtcGxlLmNv +bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAnk9i +9rogNTi9B1pn+Fbk3WALKdEjv/uyePsTnwdyvswVbeYbQweU9TrhYT2+eXbMA5kY +7TaQm46idRqxCKMgc3Ip3DADJdm8cJX9p2ExU4fKdkPc1KD/J+4QHHx1W2Ml5S2o +foOo6j1F0UdZP/rBj0UumEZp32qW+4DhVV/QQjUB8J0gaDC7yZBMdyMIeClR0RqE +YfZdCJbQHqtTwBXN+imQUHPGmksYkRDpFRvw/4crpcMIE04mVVd99nOpFCQnK61t +9US1y17VW1lYpkqlCS+rkcAtor4Z5naSf9/oLGCxEAwyW0pwHGO6MXtMxvB/JD20 +hJdlz1I7wlSfF4MiRQ== +-----END CERTIFICATE-----`) + +// websocketServerKey is the private key for websocketServerCert. +var websocketServerKey = []byte(`-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCkGWKe2OQvV87V +KTanM2+XzrYK1PBTv8nLoRmpbMjXvARs+fixtINseF8OU+XaebDeFWRzos0r56rO +bYZf9qLi2AyJofuTLI2K3UreDvLsYhIx/2mhYho84mGJ3bCaEodsS0cQJJ1Dlz+M +Q3VRXMHiCjks7swQRWj2TJD7IGpLqAhLIWZ9pPio6w8kGKvfdsGmm1J6d3UlJdpF +NepFrWn6i53DYgtrhTsp+4ord6anbWbn8dgC5J8JXhaqoOgOIE8ymxDSQ8Q+6BZk +oV4XepqeaoUW+m1jH7Mls7OdZhjYPeJ1ay9yDqpUs8QHX2hdpUFXQdc8W05ubmbW +MKmXwUjXAgMBAAECggEAE6BkTDgH//rnkP/Ej/Y17Zkv6qxnMLe/4evwZB7PsrBu +cxOUAWUOpvA1UO215bh87+2XvcDbUISnyC1kpKDyAGGeC5llER2DXE11VokWgtvZ +Q0OXavw5w83A+WVGFFdiUmXP0l10CxEm7OwQjFz6D21GQ1qC65tG9NZZghTxbFTe +iZKqgWqyHsaAWLOuDQbj1FTEBMFrY8f9RbclSh0luPZnzGc4BVI/t34jKPZBpH2N +NCkr8aB7MMHGhrNZFHAu/KAvq8UBrDTX+O8ERMwcwQWB4nne2+GOTN0MdcAUc72i +GryzIa8TgO+TpQOYoZ4NPnzFrsa+m3G2Tug3vbt62QKBgQDOPfM4/5/x/h/ggxQn +aRvEOC+8ldeqEOS1VTGiuDKJMWXrNkG+d+AsxfNP4k0QVNrpEAZSYcf0gnS9Odcl +luEsi/yPZDDnPg/cS+Z3336VKsggly7BWFs1Ct/9I+ZfSCl88TkVpIfeCBC34XEb +0mFUq/RdLqXj/mVLbBfr+H8cEwKBgQDLsJUm8lkWFAPJ8UMto8xeUMGk44VukYwx ++oI6KhplFntiI0C1Dd9wrxyCjySlJcc0NFt6IPN84d7pI9LQSbiKXQ1jMvsBzd4G +EMtG8SHpIY/mMU+KzWLHYVFS0FA4PvXXvPRNLOXas7hbALZdLshVKd7aDlkQAb5C +KWFHeIFwrQKBgA8r5Xl67HQrwoKMge4IQF+l1nUj/LJo/boNI1KaBDWtaZbs7dcq +EFaa1TQ6LHsYEuZ0JFLpGIF3G0lUOOxt9fCF97VApIxON3J4LuMAkNo+RGyJUoos +isETJLkFbAv0TgD/6bga21fM9hXgwqZOSpSk9ZvpM5DbBO6QbA4SwJ77AoGAX7h1 +/z14XAW/2hDE7xfAnLn6plA9jj5b0cjVlhvfF44/IVlLuUnxrPS9wyUdpXZhbMkG +DBicFB3ZMVqiYTuju3ILLojwqGJkahlOTeJXe0VIaHbX2HS4bNXw76fxat07jsy/ +Sd1Fj0dR5YIqMRQhFNR+Y57Gf90x2cm0a2/X9GkCgYANawYx9bNfcX0HMVG7vktK +6/80omnoBM0JUxA+V7DxS8kr9Cj2Y/kcS+VHb4yyoSkDgnsSdnCr1ZTctcj828MJ +8AUwskAtEjPkHRXEgRRnEl2oJGD1TT5iwBNnuPAQDXwzkGCRYBnlfZNbILbOoSUz +m+VDcqT5XzcRADa/TLlEXA== +-----END PRIVATE KEY----- +`) diff --git a/client_server_test.go b/client_server_test.go index 7d39da68..e4546aea 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -5,17 +5,26 @@ package websocket import ( + "bufio" + "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/binary" + "errors" + "fmt" "io" - "io/ioutil" + "log" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httptrace" "net/url" "reflect" "strings" + "sync" "testing" "time" ) @@ -31,16 +40,21 @@ var cstUpgrader = Upgrader{ } var cstDialer = Dialer{ - Subprotocols: []string{"p1", "p2"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, } -type cstHandler struct{ *testing.T } +type cstHandler struct { + *testing.T + s *cstServer +} type cstServer struct { - *httptest.Server - URL string + URL string + Server *httptest.Server + wg sync.WaitGroup } const ( @@ -49,9 +63,15 @@ const ( cstRequestURI = cstPath + "?" + cstRawQuery ) +func (s *cstServer) Close() { + s.Server.Close() + // Wait for handler functions to complete. + s.wg.Wait() +} + func newServer(t *testing.T) *cstServer { var s cstServer - s.Server = httptest.NewServer(cstHandler{t}) + s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s @@ -59,27 +79,33 @@ func newServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer { var s cstServer - s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Because tests wait for a response from a server, we are guaranteed that + // the wait group count is incremented before the test waits on the group + // in the call to (*cstServer).Close(). + t.s.wg.Add(1) + defer t.s.wg.Done() + if r.URL.Path != cstPath { t.Logf("path=%v, want %v", r.URL.Path, cstPath) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } if r.URL.RawQuery != cstRawQuery { t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } subprotos := Subprotocols(r) if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) - http.Error(w, "bad protocol", 400) + http.Error(w, "bad protocol", http.StatusBadRequest) return } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) @@ -143,8 +169,9 @@ func TestProxyDial(t *testing.T) { s := newServer(t) defer s.Close() - surl, _ := url.Parse(s.URL) + surl, _ := url.Parse(s.Server.URL) + cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) connect := false @@ -153,15 +180,15 @@ func TestProxyDial(t *testing.T) { // Capture the request Host header. s.Server.Config.Handler = http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - if r.Method == "CONNECT" { + if r.Method == http.MethodConnect { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { - t.Log("connect not recieved") - http.Error(w, "connect not recieved", 405) + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) @@ -173,16 +200,16 @@ func TestProxyDial(t *testing.T) { } defer ws.Close() sendRecv(t, ws) - - cstDialer.Proxy = http.ProxyFromEnvironment } func TestProxyAuthorizationDial(t *testing.T) { s := newServer(t) defer s.Close() - surl, _ := url.Parse(s.URL) + surl, _ := url.Parse(s.Server.URL) surl.User = url.UserPassword("username", "password") + + cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) connect := false @@ -193,15 +220,15 @@ func TestProxyAuthorizationDial(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { proxyAuth := r.Header.Get("Proxy-Authorization") expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) - if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { + if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { - t.Log("connect with proxy authorization not recieved") - http.Error(w, "connect with proxy authorization not recieved", 405) + t.Log("connect with proxy authorization not received") + http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) @@ -213,8 +240,6 @@ func TestProxyAuthorizationDial(t *testing.T) { } defer ws.Close() sendRecv(t, ws) - - cstDialer.Proxy = http.ProxyFromEnvironment } func TestDial(t *testing.T) { @@ -237,7 +262,7 @@ func TestDialCookieJar(t *testing.T) { d := cstDialer d.Jar = jar - u, _ := parseURL(s.URL) + u, _ := url.Parse(s.URL) switch u.Scheme { case "ws": @@ -246,7 +271,7 @@ func TestDialCookieJar(t *testing.T) { u.Scheme = "https" } - cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}} + cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} d.Jar.SetCookies(u, cookies) ws, _, err := d.Dial(s.URL, nil) @@ -277,10 +302,7 @@ func TestDialCookieJar(t *testing.T) { sendRecv(t, ws) } -func TestDialTLS(t *testing.T) { - s := newTLSServer(t) - defer s.Close() - +func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { certs := x509.NewCertPool() for _, c := range s.TLS.Certificates { roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) @@ -291,9 +313,15 @@ func TestDialTLS(t *testing.T) { certs.AddCert(root) } } + return certs +} + +func TestDialTLS(t *testing.T) { + s := newTLSServer(t) + defer s.Close() d := cstDialer - d.TLSClientConfig = &tls.Config{RootCAs: certs} + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} ws, _, err := d.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) @@ -302,43 +330,97 @@ func TestDialTLS(t *testing.T) { sendRecv(t, ws) } -func xTestDialTLSBadCert(t *testing.T) { - // This test is deactivated because of noisy logging from the net/http package. - s := newTLSServer(t) +func TestDialTimeout(t *testing.T) { + s := newServer(t) defer s.Close() - ws, _, err := cstDialer.Dial(s.URL, nil) + d := cstDialer + d.HandshakeTimeout = -1 + ws, _, err := d.Dial(s.URL, nil) if err == nil { ws.Close() t.Fatalf("Dial: nil") } } -func TestDialTLSNoVerify(t *testing.T) { - s := newTLSServer(t) +// requireDeadlineNetConn fails the current test when Read or Write are called +// with no deadline. +type requireDeadlineNetConn struct { + t *testing.T + c net.Conn + readDeadlineIsSet bool + writeDeadlineIsSet bool +} + +func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + c.readDeadlineIsSet = c.writeDeadlineIsSet + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { + c.readDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { + if !c.writeDeadlineIsSet { + c.t.Fatalf("write with no deadline") + } + return c.c.Write(p) +} + +func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { + if !c.readDeadlineIsSet { + c.t.Fatalf("read with no deadline") + } + return c.c.Read(p) +} + +func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } +func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } +func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } + +func TestHandshakeTimeout(t *testing.T) { + s := newServer(t) defer s.Close() d := cstDialer - d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + d.NetDial = func(n, a string) (net.Conn, error) { + c, err := net.Dial(n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } ws, _, err := d.Dial(s.URL, nil) if err != nil { - t.Fatalf("Dial: %v", err) + t.Fatal("Dial:", err) } - defer ws.Close() - sendRecv(t, ws) + ws.Close() } -func TestDialTimeout(t *testing.T) { +func TestHandshakeTimeoutInContext(t *testing.T) { s := newServer(t) defer s.Close() d := cstDialer - d.HandshakeTimeout = -1 - ws, _, err := d.Dial(s.URL, nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") + d.HandshakeTimeout = 0 + d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { + netDialer := &net.Dialer{} + c, err := netDialer.DialContext(ctx, n, a) + return &requireDeadlineNetConn{c: c, t: t}, err } + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() } func TestDialBadScheme(t *testing.T) { @@ -398,9 +480,17 @@ func TestBadMethod(t *testing.T) { })) defer s.Close() - resp, err := http.PostForm(s.URL, url.Values{}) + req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) if err != nil { - t.Fatalf("PostForm returned error %v", err) + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusMethodNotAllowed { @@ -408,6 +498,54 @@ func TestBadMethod(t *testing.T) { } } +func TestNoUpgrade(t *testing.T) { + t.Parallel() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if u := resp.Header.Get("Upgrade"); u != "websocket" { + t.Errorf("Upgrade response header is %q, want %q", u, "websocket") + } + if resp.StatusCode != http.StatusUpgradeRequired { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) + } +} + +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close() @@ -440,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - io.WriteString(w, expectedBody) + _, _ = io.WriteString(w, expectedBody) })) defer s.Close() @@ -458,7 +596,7 @@ func TestRespOnBadHandshake(t *testing.T) { t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) } - p, err := ioutil.ReadAll(resp.Body) + p, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("ReadFull(resp.Body) returned error %v", err) } @@ -468,45 +606,640 @@ func TestRespOnBadHandshake(t *testing.T) { } } -// TestHostHeader confirms that the host header provided in the call to Dial is -// sent to the server. -func TestHostHeader(t *testing.T) { +type testLogWriter struct { + t *testing.T +} + +func (w testLogWriter) Write(p []byte) (int, error) { + w.t.Logf("%s", p) + return len(p), nil +} + +// TestHost tests handling of host names and confirms that it matches net/http. +func TestHost(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} + wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} + httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} + + // Avoid log noise from net/http server by logging to testing.T + server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) + tlsServer.Config.ErrorLog = server.Config.ErrorLog + + cas := rootCAs(t, tlsServer) + + tests := []struct { + fail bool // true if dial / get should fail + server *httptest.Server // server to use + url string // host for request URI + header string // optional request host header + tls string // optional host for tls ServerName + wantAddr string // expected host for dial + wantHeader string // expected request header on server + insecureSkipVerify bool + }{ + { + server: server, + url: addrs[server], + wantAddr: addrs[server], + wantHeader: addrs[server], + }, + { + server: tlsServer, + url: addrs[tlsServer], + wantAddr: addrs[tlsServer], + wantHeader: addrs[tlsServer], + }, + + { + server: server, + url: addrs[server], + header: "badhost.com", + wantAddr: addrs[server], + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: addrs[tlsServer], + header: "badhost.com", + wantAddr: addrs[tlsServer], + wantHeader: "badhost.com", + }, + + { + server: server, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:80", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:443", + wantHeader: "badhost.com", + }, + + { + server: server, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:80", + wantHeader: "example.com", + }, + { + fail: true, + server: tlsServer, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:443", + }, + { + server: tlsServer, + url: "badhost.com", + insecureSkipVerify: true, + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "badhost.com", + tls: "example.com", + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + } + + for i, tt := range tests { + + tls := &tls.Config{ + RootCAs: cas, + ServerName: tt.tls, + InsecureSkipVerify: tt.insecureSkipVerify, + } + + var gotAddr string + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + gotAddr = addr + return net.Dial(network, addrs[tt.server]) + }, + TLSClientConfig: tls, + } + + // Test websocket dial + + h := http.Header{} + if tt.header != "" { + h.Set("Host", tt.header) + } + c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) + if err == nil { + c.Close() + } + + check := func(protos map[*httptest.Server]string) { + name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) + if gotAddr != tt.wantAddr { + t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) + } + switch { + case tt.fail && err == nil: + t.Errorf("%s: unexpected success", name) + case !tt.fail && err != nil: + t.Errorf("%s: unexpected error %v", name, err) + case !tt.fail && err == nil: + if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { + t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) + } + } + } + + check(wsProtos) + + // Confirm that net/http has same result + + transport := &http.Transport{ + Dial: dialer.NetDial, + TLSClientConfig: dialer.TLSClientConfig, + } + req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) + if tt.header != "" { + req.Host = tt.header + } + client := &http.Client{Transport: transport} + resp, err = client.Do(req) + if err == nil { + resp.Body.Close() + } + transport.CloseIdleConnections() + check(httpProtos) + } +} + +func TestDialCompression(t *testing.T) { s := newServer(t) defer s.Close() - specifiedHost := make(chan string, 1) - origHandler := s.Server.Config.Handler + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} - // Capture the request Host header. - s.Server.Config.Handler = http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - specifiedHost <- r.Host - origHandler.ServeHTTP(w, r) - }) +func TestSocksProxyDial(t *testing.T) { + s := newServer(t) + defer s.Close() - ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer proxyListener.Close() + go func() { + c1, err := proxyListener.Accept() + if err != nil { + t.Errorf("proxy accept failed: %v", err) + return + } + defer c1.Close() + + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 32) + if _, err := io.ReadFull(c1, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + if _, err := c1.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + if _, err := io.ReadFull(c1, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := c1.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + _, _ = io.Copy(c1, c2) + close(done) + }() + _, _ = io.Copy(c2, c1) + <-done + }() + + purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(purl) + + ws, _, err := cstDialer.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } defer ws.Close() + sendRecv(t, ws) +} + +func TestTracingDialWithContext(t *testing.T) { + + var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + headersWrote = true + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + requestWrote = true + }, + GetConn: func(hostPort string) { + getConn = true + }, + GotConn: func(info httptrace.GotConnInfo) { + gotConn = true + }, + ConnectDone: func(network, addr string, err error) { + connectDone = true + }, + GotFirstResponseByte: func() { + gotFirstResponseByte = true + }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } - if gotHost := <-specifiedHost; gotHost != "testhost" { - t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + if !headersWrote { + t.Fatal("Headers was not written") + } + if !requestWrote { + t.Fatal("Request was not written") + } + if !getConn { + t.Fatal("getConn was not called") + } + if !gotConn { + t.Fatal("gotConn was not called") + } + if !connectDone { + t.Fatal("connectDone was not called") + } + if !gotFirstResponseByte { + t.Fatal("GotFirstResponseByte was not called") } + defer ws.Close() sendRecv(t, ws) } -func TestDialCompression(t *testing.T) { - s := newServer(t) +func TestEmptyTracingDialWithContext(t *testing.T) { + + trace := &httptrace.ClientTrace{} + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) defer s.Close() - dialer := cstDialer - dialer.EnableCompression = true - ws, _, err := dialer.Dial(s.URL, nil) + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } + defer ws.Close() sendRecv(t, ws) } + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := []struct { + name string + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + + { + name: "HTTP server, all NetDial* defined, shall use NetDialContext", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTP server, all NetDial* undefined", + server: server, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + { + name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDial* undefined", + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for _, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) + } else { + c.Close() + } + } +} +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +} + +type dataBeforeHandshakeResponseWriter struct { + http.ResponseWriter +} + +type dataBeforeHandshakeConnection struct { + net.Conn + io.Reader +} + +func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) { + return c.Reader.Read(p) +} + +func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Example single-frame masked text message from section 5.7 of the RFC. + message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58} + n := len(message) / 2 + + c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack() + if rw != nil { + // Load first part of message into bufio.Reader. If the websocket + // connection reads more than n bytes from the bufio.Reader, then the + // test will fail with an unexpected EOF error. + rw.Reader.Reset(bytes.NewReader(message[:n])) + rw.Reader.Peek(n) + } + if c != nil { + // Inject second part of message before data read from the network connection. + c = &dataBeforeHandshakeConnection{ + Conn: c, + Reader: io.MultiReader(bytes.NewReader(message[n:]), c), + } + } + return c, rw, err +} + +func TestDataReceivedBeforeHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + origHandler := s.Server.Config.Handler + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r) + }) + + for _, readBufferSize := range []int{0, 1024} { + t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) { + dialer := cstDialer + dialer.ReadBufferSize = readBufferSize + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + _, m, err := ws.ReadMessage() + if err != nil || string(m) != "Hello" { + t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err) + } + }) + } +} diff --git a/client_test.go b/client_test.go index 7d2b0844..5aa27b37 100644 --- a/client_test.go +++ b/client_test.go @@ -6,49 +6,9 @@ package websocket import ( "net/url" - "reflect" "testing" ) -var parseURLTests = []struct { - s string - u *url.URL - rui string -}{ - {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, - {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, - {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, - {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, - {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, - {"ss://example.com/a/b", nil, ""}, - {"ws://webmaster@example.com/", nil, ""}, - {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, - {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, -} - -func TestParseURL(t *testing.T) { - for _, tt := range parseURLTests { - u, err := parseURL(tt.s) - if tt.u != nil && err != nil { - t.Errorf("parseURL(%q) returned error %v", tt.s, err) - continue - } - if tt.u == nil { - if err == nil { - t.Errorf("parseURL(%q) did not return error", tt.s) - } - continue - } - if !reflect.DeepEqual(u, tt.u) { - t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) - continue - } - if u.RequestURI() != tt.rui { - t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) - } - } -} - var hostPortNoPortTests = []struct { u *url.URL hostPort, hostNoPort string diff --git a/compression.go b/compression.go index 813ffb1e..fe1079ed 100644 --- a/compression.go +++ b/compression.go @@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } return &flateReadWrapper{fr} } diff --git a/compression_test.go b/compression_test.go index 659cf421..00ae42f1 100644 --- a/compression_test.go +++ b/compression_test.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "testing" ) @@ -23,7 +22,7 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - w.Write(p[:m]) + _, _ = w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -42,31 +41,31 @@ func textMessages(num int) [][]byte { } func BenchmarkWriteNoCompression(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + w := io.Discard + c := newTestConn(nil, w, false) messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } func BenchmarkWriteWithCompression(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + w := io.Discard + c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } func TestValidCompressionLevel(t *testing.T) { - c := newConn(fakeNetConn{}, false, 1024, 1024) + c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { t.Errorf("no error for level %d", level) diff --git a/conn.go b/conn.go index 97e1dbac..9562ffd4 100644 --- a/conn.go +++ b/conn.go @@ -6,13 +6,13 @@ package websocket import ( "bufio" + "crypto/rand" "encoding/binary" "errors" "io" - "io/ioutil" - "math/rand" "net" "strconv" + "strings" "sync" "time" "unicode/utf8" @@ -76,7 +76,7 @@ const ( // is UTF-8 encoded text. PingMessage = 9 - // PongMessage denotes a ping control message. The optional message payload + // PongMessage denotes a pong control message. The optional message payload // is UTF-8 encoded text. PongMessage = 10 ) @@ -100,9 +100,8 @@ func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Timeout() bool { return e.timeout } -// CloseError represents close frame. +// CloseError represents a close message. type CloseError struct { - // Code is defined in RFC 6455, section 11.7. Code int @@ -181,16 +180,16 @@ var ( errInvalidControlFrame = errors.New("websocket: invalid control frame") ) -func newMaskKey() [4]byte { - n := rand.Uint32() - return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} -} +// maskRand is an io.Reader for generating mask bytes. The reader is initialized +// to crypto/rand Reader. Tests swap the reader to a math/rand reader for +// reproducible results. +var maskRand = rand.Reader -func hideTempErr(err error) error { - if e, ok := err.(net.Error); ok && e.Temporary() { - err = &netError{msg: e.Error(), timeout: e.Timeout()} - } - return err +// newMaskKey returns a new 32 bit value for masking client frames. +func newMaskKey() [4]byte { + var k [4]byte + _, _ = io.ReadFull(maskRand, k[:]) + return k } func isControl(frameType int) bool { @@ -224,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + // The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn @@ -231,8 +244,10 @@ type Conn struct { subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn - writeBuf []byte // frame is constructed in this buffer. + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection @@ -245,10 +260,12 @@ type Conn struct { newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields - reader io.ReadCloser // the current reader returned to the application - readErr error - br *bufio.Reader - readRemaining int64 // bytes remaining in current frame. + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 readFinal bool // true the current message has more frames. readLength int64 // Message size. readLimit int64 // Maximum message size. @@ -264,64 +281,29 @@ type Conn struct { newDecompressionReader func(io.Reader) io.ReadCloser } -func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { - return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) -} - -type writeHook struct { - p []byte -} - -func (wh *writeHook) Write(p []byte) (int, error) { - wh.p = p - return len(p), nil -} +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { -func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { - mu := make(chan bool, 1) - mu <- true - - var br *bufio.Reader - if readBufferSize == 0 && brw != nil && brw.Reader != nil { - // Reuse the supplied bufio.Reader if the buffer has a useful size. - // This code assumes that peek on a reader returns - // bufio.Reader.buf[:0]. - brw.Reader.Reset(conn) - if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { - br = brw.Reader - } - } if br == nil { if readBufferSize == 0 { readBufferSize = defaultReadBufferSize - } - if readBufferSize < maxControlFramePayloadSize { + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame readBufferSize = maxControlFramePayloadSize } br = bufio.NewReaderSize(conn, readBufferSize) } - var writeBuf []byte - if writeBufferSize == 0 && brw != nil && brw.Writer != nil { - // Use the bufio.Writer's buffer if the buffer has a useful size. This - // code assumes that bufio.Writer.buf[:1] is passed to the - // bufio.Writer's underlying writer. - var wh writeHook - brw.Writer.Reset(&wh) - brw.Writer.WriteByte(0) - brw.Flush() - if cap(wh.p) >= maxFrameHeaderSize+256 { - writeBuf = wh.p[:cap(wh.p)] - } + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize } + writeBufferSize += maxFrameHeaderSize - if writeBuf == nil { - if writeBufferSize == 0 { - writeBufferSize = defaultWriteBufferSize - } - writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) } + mu := make(chan struct{}, 1) + mu <- struct{}{} c := &Conn{ isServer: isServer, br: br, @@ -329,6 +311,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in mu: mu, readFinal: true, writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, } @@ -338,12 +322,24 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in return c } +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol } -// Close closes the underlying network connection without sending or waiting for a close frame. +// Close closes the underlying network connection without sending or waiting +// for a close message. func (c *Conn) Close() error { return c.conn.Close() } @@ -361,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods func (c *Conn) writeFatal(err error) error { - err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { c.writeErr = err @@ -370,9 +365,20 @@ func (c *Conn) writeFatal(err error) error { return err } -func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { <-c.mu - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -381,22 +387,29 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { return err } - c.conn.SetWriteDeadline(deadline) - for _, buf := range bufs { - if len(buf) > 0 { - _, err := c.conn.Write(buf) - if err != nil { - return c.writeFatal(err) - } - } + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) } - if frameType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return nil } +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { @@ -425,22 +438,28 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := time.Hour * 1000 - if !deadline.IsZero() { - d = deadline.Sub(time.Now()) + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) if d < 0 { return errWriteTimeout } + select { + case <-c.mu: + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + } } - timer := time.NewTimer(d) - select { - case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout - } - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -449,18 +468,20 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - c.conn.SetWriteDeadline(deadline) - _, err = c.conn.Write(buf) - if err != nil { + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return err } -func (c *Conn) prepWrite(messageType int) error { +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. @@ -476,7 +497,23 @@ func (c *Conn) prepWrite(messageType int) error { c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() - return err + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil } // NextWriter returns a writer for the next message to send. The writer's Close @@ -484,17 +521,15 @@ func (c *Conn) prepWrite(messageType int) error { // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { - if err := c.prepWrite(messageType); err != nil { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { return nil, err } - - mw := &messageWriter{ - c: c, - frameType: messageType, - pos: maxFrameHeaderSize, - } - c.writer = mw + c.writer = &mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true @@ -511,10 +546,16 @@ type messageWriter struct { err error } -func (w *messageWriter) fatal(err error) error { +func (w *messageWriter) endMessage(err error) error { if w.err != nil { - w.err = err - w.c.writer = nil + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil } return err } @@ -528,7 +569,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { // Check for invalid control frames. if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { - return w.fatal(errInvalidControlFrame) + return w.endMessage(errInvalidControlFrame) } b0 := byte(w.frameType) @@ -573,7 +614,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { - return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) } } @@ -594,11 +635,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { c.isWriting = false if err != nil { - return w.fatal(err) + return w.endMessage(err) } if final { - c.writer = nil + _ = w.endMessage(errWriteClosed) return nil } @@ -696,11 +737,7 @@ func (w *messageWriter) Close() error { if w.err != nil { return w.err } - if err := w.flushFrame(true, nil); err != nil { - return err - } - w.err = errWriteClosed - return nil + return w.flushFrame(true, nil) } // WritePreparedMessage writes prepared message into connection. @@ -732,10 +769,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. - if err := c.prepWrite(messageType); err != nil { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { return err } - mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} n := copy(c.writeBuf[mw.pos:], data) mw.pos += n data = data[n:] @@ -764,60 +801,91 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods func (c *Conn) advanceFrame() (int, error) { - // 1. Skip remainder of previous frame. if c.readRemaining > 0 { - if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil { return noFrame, err } } // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string p, err := c.read(2) if err != nil { return noFrame, err } - final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.readRemaining = int64(p[1] & 0x7f) + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 c.readDecompress = false - if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { - c.readDecompress = true - p[0] &^= rsv1Bit + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } + } + + if rsv2 { + errors = append(errors, "RSV2 set") } - if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + if rsv3 { + errors = append(errors, "RSV3 set") } switch frameType { case CloseMessage, PingMessage, PongMessage: if c.readRemaining > maxControlFramePayloadSize { - return noFrame, c.handleProtocolError("control frame length > 125") + errors = append(errors, "len > 125 for control") } if !final { - return noFrame, c.handleProtocolError("control frame not final") + errors = append(errors, "FIN not set on control") } case TextMessage, BinaryMessage: if !c.readFinal { - return noFrame, c.handleProtocolError("message start before final message frame") + errors = append(errors, "data before FIN") } c.readFinal = final case continuationFrame: if c.readFinal { - return noFrame, c.handleProtocolError("continuation after final message frame") + errors = append(errors, "continuation after FIN") } c.readFinal = final default: - return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) } - // 3. Read and parse frame length. + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. switch c.readRemaining { case 126: @@ -825,21 +893,23 @@ func (c *Conn) advanceFrame() (int, error) { if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } case 127: p, err := c.read(8) if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } } // 4. Handle frame masking. - if mask != c.isServer { - return noFrame, c.handleProtocolError("incorrect mask flag") - } - if mask { c.readMaskPos = 0 p, err := c.read(len(c.readMaskKey)) @@ -854,8 +924,15 @@ func (c *Conn) advanceFrame() (int, error) { if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + if c.readLimit > 0 && c.readLength > c.readLimit { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } @@ -867,7 +944,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.readRemaining = 0 + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 if err != nil { return noFrame, err } @@ -893,7 +970,7 @@ func (c *Conn) advanceFrame() (int, error) { if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { - return noFrame, c.handleProtocolError("invalid close code") + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) } closeText = string(payload[2:]) if !utf8.ValidString(closeText) { @@ -910,7 +987,12 @@ func (c *Conn) advanceFrame() (int, error) { } func (c *Conn) handleProtocolError(message string) error { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -937,9 +1019,10 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { for c.readErr == nil { frameType, err := c.advanceFrame() if err != nil { - c.readErr = hideTempErr(err) + c.readErr = err break } + if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader @@ -976,11 +1059,13 @@ func (r *messageReader) Read(b []byte) (int, error) { b = b[:c.readRemaining] } n, err := c.br.Read(b) - c.readErr = hideTempErr(err) + c.readErr = err if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - c.readRemaining -= int64(n) + rem := c.readRemaining + rem -= int64(n) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -995,7 +1080,7 @@ func (r *messageReader) Read(b []byte) (int, error) { frameType, err := c.advanceFrame() switch { case err != nil: - c.readErr = hideTempErr(err) + c.readErr = err case frameType == TextMessage || frameType == BinaryMessage: c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } @@ -1020,7 +1105,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { if err != nil { return messageType, nil, err } - p, err = ioutil.ReadAll(r) + p, err = io.ReadAll(r) return messageType, p, err } @@ -1032,8 +1117,8 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -// SetReadLimit sets the maximum size for a message read from the peer. If a -// message exceeds the limit, the connection sends a close frame to the peer +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. func (c *Conn) SetReadLimit(limit int64) { c.readLimit = limit @@ -1046,25 +1131,24 @@ func (c *Conn) CloseHandler() func(code int, text string) error { // SetCloseHandler sets the handler for close messages received from the peer. // The code argument to h is the received close code or CloseNoStatusReceived -// if the close message is empty. The default close handler sends a close frame -// back to the peer. +// if the close message is empty. The default close handler sends a close +// message back to the peer. // -// The application must read the connection to process close messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. // -// The connection read methods return a CloseError when a close frame is +// The connection read methods return a CloseError when a close message is // received. Most applications should handle close messages as part of their // normal error handling. Applications should only set a close handler when the -// application must perform some action before sending a close frame back to +// application must perform some action before sending a close message back to // the peer. func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { - message := []byte{} - if code != CloseNoStatusReceived { - message = FormatCloseMessage(code, "") - } - c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + message := FormatCloseMessage(code, "") + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } @@ -1077,21 +1161,18 @@ func (c *Conn) PingHandler() func(appData string) error { } // SetPingHandler sets the handler for ping messages received from the peer. -// The appData argument to h is the PING frame application data. The default +// The appData argument to h is the PING message application data. The default // ping handler sends a pong to the peer. // -// The application must read the connection to process ping messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { - err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) - if err == ErrCloseSent { - return nil - } else if e, ok := err.(net.Error); ok && e.Temporary() { - return nil - } - return err + // Make a best effort to send the pong message. + _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + return nil } } c.handlePing = h @@ -1103,11 +1184,12 @@ func (c *Conn) PongHandler() func(appData string) error { } // SetPongHandler sets the handler for pong messages received from the peer. -// The appData argument to h is the PONG frame application data. The default +// The appData argument to h is the PONG message application data. The default // pong handler does nothing. // -// The application must read the connection to process ping messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. func (c *Conn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } @@ -1115,8 +1197,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) { c.handlePong = h } +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + // UnderlyingConn returns the internal net.Conn. This can be used to further // modifications to connection specific flags. +// Deprecated: Use the NetConn method. func (c *Conn) UnderlyingConn() net.Conn { return c.conn } @@ -1141,7 +1231,14 @@ func (c *Conn) SetCompressionLevel(level int) error { } // FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } buf := make([]byte, 2+len(text)) binary.BigEndian.PutUint16(buf, uint16(closeCode)) copy(buf[2:], text) diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index 45038e48..540be6ba 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -2,13 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.7 - package websocket import ( "io" - "io/ioutil" "sync/atomic" "testing" ) @@ -20,7 +17,6 @@ import ( // scenarios with many subscribers in one channel. type broadcastBench struct { w io.Writer - message *broadcastMessage closeCh chan struct{} doneCh chan struct{} count int32 @@ -48,20 +44,12 @@ func newBroadcastConn(c *Conn) *broadcastConn { func newBroadcastBench(usePrepared, compression bool) *broadcastBench { bench := &broadcastBench{ - w: ioutil.Discard, + w: io.Discard, doneCh: make(chan struct{}), closeCh: make(chan struct{}), usePrepared: usePrepared, compression: compression, } - msg := &broadcastMessage{ - payload: textMessages(1)[0], - } - if usePrepared { - pm, _ := NewPreparedMessage(TextMessage, msg.payload) - msg.prepared = pm - } - bench.message = msg bench.makeConns(10000) return bench } @@ -70,7 +58,7 @@ func (b *broadcastBench) makeConns(numConns int) { conns := make([]*broadcastConn, numConns) for i := 0; i < numConns; i++ { - c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) + c := newTestConn(nil, b.w, true) if b.compression { c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover @@ -80,10 +68,10 @@ func (b *broadcastBench) makeConns(numConns int) { for { select { case msg := <-c.msgCh: - if b.usePrepared { - c.conn.WritePreparedMessage(msg.prepared) + if msg.prepared != nil { + _ = c.conn.WritePreparedMessage(msg.prepared) } else { - c.conn.WriteMessage(TextMessage, msg.payload) + _ = c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { @@ -102,9 +90,9 @@ func (b *broadcastBench) close() { close(b.closeCh) } -func (b *broadcastBench) runOnce() { +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { for _, c := range b.conns { - c.msgCh <- b.message + c.msgCh <- msg } <-b.doneCh } @@ -116,17 +104,25 @@ func BenchmarkBroadcast(b *testing.B) { compression bool }{ {"NoCompression", false, false}, - {"WithCompression", false, true}, + {"Compression", false, true}, {"NoCompressionPrepared", true, false}, - {"WithCompressionPrepared", true, true}, + {"CompressionPrepared", true, true}, } + payload := textMessages(1)[0] for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { bench := newBroadcastBench(bm.usePrepared, bm.compression) defer bench.close() b.ResetTimer() for i := 0; i < b.N; i++ { - bench.runOnce() + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) } b.ReportAllocs() }) diff --git a/conn_read.go b/conn_read.go deleted file mode 100644 index 1ea15059..00000000 --- a/conn_read.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.5 - -package websocket - -import "io" - -func (c *Conn) read(n int) ([]byte, error) { - p, err := c.br.Peek(n) - if err == io.EOF { - err = errUnexpectedEOF - } - c.br.Discard(len(p)) - return p, err -} diff --git a/conn_read_legacy.go b/conn_read_legacy.go deleted file mode 100644 index 018541cf..00000000 --- a/conn_read_legacy.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !go1.5 - -package websocket - -import "io" - -func (c *Conn) read(n int) ([]byte, error) { - p, err := c.br.Peek(n) - if err == io.EOF { - err = errUnexpectedEOF - } - if len(p) > 0 { - // advance over the bytes just read - io.ReadFull(c.br, p) - } - return p, err -} diff --git a/conn_test.go b/conn_test.go index 06e9bc3f..28f5c4a3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,9 +10,9 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "reflect" + "sync" "testing" "testing/iotest" "time" @@ -47,8 +47,17 @@ func (a fakeAddr) String() string { return "str" } +// newTestConn creates a connection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + func TestFraming(t *testing.T) { - frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } var readChunkers = []struct { name string f func(io.Reader) io.Reader @@ -82,8 +91,8 @@ func TestFraming(t *testing.T) { for _, chunker := range readChunkers { var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) if compress { wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover @@ -113,7 +122,9 @@ func TestFraming(t *testing.T) { t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) continue } - rbuf, err := ioutil.ReadAll(r) + + t.Logf("frame size: %d", n) + rbuf, err := io.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) continue @@ -137,16 +148,57 @@ func TestFraming(t *testing.T) { } } -func TestControl(t *testing.T) { +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + +func TestConcurrencyWriteControl(t *testing.T) { const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong message" for _, isServer := range []bool{true, false} { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -163,7 +215,7 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() + _, _, _ = rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -173,25 +225,189 @@ func TestControl(t *testing.T) { } } +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if !reflect.DeepEqual(err, expectedErr) { t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) } @@ -206,11 +422,11 @@ func TestEOFWithinFrame(t *testing.T) { for n := 0; ; n++ { var b bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) + _, _ = w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { @@ -225,7 +441,7 @@ func TestEOFWithinFrame(t *testing.T) { if op != BinaryMessage || err != nil { t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) } @@ -240,17 +456,17 @@ func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } @@ -261,11 +477,11 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } func TestWriteAfterMessageWriterClose(t *testing.T) { - wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { - t.Fatalf("unxpected error closing message writer, %v", err) + t.Fatalf("unexpected error closing message writer, %v", err) } if _, err := io.WriteString(w, "world"); err == nil { @@ -273,7 +489,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -287,41 +503,97 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) - const readLimit = 512 - message := make([]byte, readLimit+1) + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) - var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) - rc.SetReadLimit(readLimit) + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) + w.Close() - // Send message at the limit with interleaved pong. - w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) - w.Close() + // Send message larger than the limit. + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) - // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 - op, _, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("1: NextReader() returned %d, %v", op, err) - } - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("2: NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != ErrReadLimit { - t.Fatalf("io.Copy() returned %v", err) - } + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) } func TestAddrs(t *testing.T) { - c := newConn(&fakeNetConn{}, true, 1024, 1024) + c := newTestConn(nil, nil, true) if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) } @@ -330,29 +602,38 @@ func TestAddrs(t *testing.T) { } } -func TestUnderlyingConn(t *testing.T) { +func TestDeprecatedUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} - c := newConn(fc, true, 1024, 1024) + c := newConn(fc, true, 1024, 1024, nil, nil, nil) ul := c.UnderlyingConn() if ul != fc { t.Fatalf("Underlying conn is not what it should be.") } } -func TestBufioReadBytes(t *testing.T) { +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} +func TestBufioReadBytes(t *testing.T) { // Test calling bufio.ReadBytes for value longer than read buffer size. m := make([]byte, 512) m[len(m)-1] = '\n' var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) + _, _ = w.Write(m) w.Close() op, r, err := rc.NextReader() @@ -366,7 +647,7 @@ func TestBufioReadBytes(t *testing.T) { t.Fatalf("ReadBytes() returned %v", err) } if len(p) != len(m) { - t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) } } @@ -424,9 +705,9 @@ func (w blockingWriter) Write(p []byte) (int, error) { func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) go func() { - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. @@ -439,7 +720,7 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } @@ -450,7 +731,7 @@ func (r failingReader) Read(p []byte) (int, error) { } func TestFailedConnectionReadPanic(t *testing.T) { - c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + c := newTestConn(failingReader{}, nil, false) defer func() { if v := recover(); v != nil { @@ -459,39 +740,7 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - c.ReadMessage() + _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } - -func TestBufioReuse(t *testing.T) { - brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil)) - c := newConnBRW(nil, false, 0, 0, brw) - - if c.br != brw.Reader { - t.Error("connection did not reuse bufio.Reader") - } - - var wh writeHook - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection did not reuse bufio.Writer") - } - - brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0)) - c = newConnBRW(nil, false, 0, 0, brw) - - if c.br == brw.Reader { - t.Error("connection used bufio.Reader with small size") - } - - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection used bufio.Writer with small size") - } - -} diff --git a/doc.go b/doc.go index e291a952..8db0cef9 100644 --- a/doc.go +++ b/doc.go @@ -6,9 +6,8 @@ // // Overview // -// The Conn type represents a WebSocket connection. A server application uses -// the Upgrade function from an Upgrader object with a HTTP request handler -// to get a pointer to a Conn: +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: // // var upgrader = websocket.Upgrader{ // ReadBufferSize: 1024, @@ -31,10 +30,12 @@ // for { // messageType, p, err := conn.ReadMessage() // if err != nil { +// log.Println(err) // return // } -// if err = conn.WriteMessage(messageType, p); err != nil { -// return err +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return // } // } // @@ -85,20 +86,26 @@ // and pong. Call the connection WriteControl, WriteMessage or NextWriter // methods to send a control message to the peer. // -// Connections handle received close messages by sending a close message to the -// peer and returning a *CloseError from the the NextReader, ReadMessage or the -// message Read method. +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. // -// Connections handle received ping and pong messages by invoking callback -// functions set with SetPingHandler and SetPongHandler methods. The callback -// functions are called from the NextReader, ReadMessage and the message Read -// methods. +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. // -// The default ping handler sends a pong to the peer. The application's reading -// goroutine can block for a short time while the handler writes the pong data -// to the connection. +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. // -// The application must read the connection to process ping, pong and close +// The application must read the connection to process close, ping and pong // messages sent from the peer. If the application is not otherwise interested // in messages from the peer, then the application should start a goroutine to // read and discard messages from the peer. A simple example is: @@ -137,19 +144,59 @@ // method fails the WebSocket handshake with HTTP status 403. // // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail -// the handshake if the Origin request header is present and not equal to the -// Host request header. -// -// An application can allow connections from any origin by specifying a -// function that always returns true: -// -// var upgrader = websocket.Upgrader{ -// CheckOrigin: func(r *http.Request) bool { return true }, -// } -// -// The deprecated Upgrade function does not enforce an origin policy. It's the -// application's responsibility to check the Origin header before calling -// Upgrade. +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. +// +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. // // Compression EXPERIMENTAL // diff --git a/example_test.go b/example_test.go index 96449eac..cd1883b8 100644 --- a/example_test.go +++ b/example_test.go @@ -23,10 +23,9 @@ var ( // This server application works with a client application running in the // browser. The client application does not explicitly close the websocket. The // only expected close message from the client has the code -// websocket.CloseGoingAway. All other other close messages are likely the +// websocket.CloseGoingAway. All other close messages are likely the // result of an application or protocol error and are logged to aid debugging. func ExampleIsUnexpectedCloseError() { - for { messageType, p, err := c.ReadMessage() if err != nil { @@ -35,11 +34,11 @@ func ExampleIsUnexpectedCloseError() { } return } - processMesage(messageType, p) + processMessage(messageType, p) } } -func processMesage(mt int, p []byte) {} +func processMessage(mt int, p []byte) {} // TestX prevents godoc from showing this entire file in the example. Remove // this function when a second example is added. diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md index 075ac153..cc954fea 100644 --- a/examples/autobahn/README.md +++ b/examples/autobahn/README.md @@ -1,6 +1,6 @@ # Test Server -This package contains a server for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite). +This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite). To test the server, run @@ -8,6 +8,11 @@ To test the server, run and start the client test driver - wstest -m fuzzingclient -s fuzzingclient.json + mkdir -p reports + docker run -it --rm \ + -v ${PWD}/config:/config \ + -v ${PWD}/reports:/reports \ + crossbario/autobahn-testsuite \ + wstest -m fuzzingclient -s /config/fuzzingclient.json -When the client completes, it writes a report to reports/clients/index.html. +When the client completes, it writes a report to reports/index.html. diff --git a/examples/autobahn/config/fuzzingclient.json b/examples/autobahn/config/fuzzingclient.json new file mode 100644 index 00000000..eda4e663 --- /dev/null +++ b/examples/autobahn/config/fuzzingclient.json @@ -0,0 +1,29 @@ +{ + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {}, + "outdir": "/reports", + "options": {"failByDrop": false}, + "servers": [ + { + "agent": "ReadAllWriteMessage", + "url": "ws://host.docker.internal:9000/m" + }, + { + "agent": "ReadAllWritePreparedMessage", + "url": "ws://host.docker.internal:9000/p" + }, + { + "agent": "CopyFull", + "url": "ws://host.docker.internal:9000/f" + }, + { + "agent": "ReadAllWrite", + "url": "ws://host.docker.internal:9000/r" + }, + { + "agent": "CopyWriterOnly", + "url": "ws://host.docker.internal:9000/c" + } + ] +} diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json deleted file mode 100644 index aa3a0bc0..00000000 --- a/examples/autobahn/fuzzingclient.json +++ /dev/null @@ -1,15 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "servers": [ - {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}}, - {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}}, - {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}}, - {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}}, - {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}} - ], - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index 3db880f9..2d6d36f8 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) { } // echoReadAll echoes messages from the client by reading the entire message -// with ioutil.ReadAll. +// with io.ReadAll. func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -157,11 +157,11 @@ func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found.", 404) + http.Error(w, "Not found.", http.StatusNotFound) return } - if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/examples/chat/README.md b/examples/chat/README.md index 47c82f90..f8aecba0 100644 --- a/examples/chat/README.md +++ b/examples/chat/README.md @@ -1,6 +1,6 @@ # Chat Example -This application shows how to use use the +This application shows how to use the [websocket](https://github.com/gorilla/websocket) package to implement a simple web chat application. @@ -38,7 +38,7 @@ sends them to the hub. ### Hub The code for the `Hub` type is in -[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). +[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go). The application's `main` function starts the hub's `run` method as a goroutine. Clients send requests to the hub using the `register`, `unregister` and `broadcast` channels. @@ -57,7 +57,7 @@ unregisters the client and closes the websocket. ### Client -The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). +The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go). The `serveWs` function is registered by the application's `main` function as an HTTP handler. The handler upgrades the HTTP connection to the WebSocket @@ -85,7 +85,7 @@ network. ## Frontend -The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). +The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html). On document load, the script checks for websocket functionality in the browser. If websocket functionality is available, then the script opens a connection to diff --git a/examples/chat/client.go b/examples/chat/client.go index 26468477..9461c1ea 100644 --- a/examples/chat/client.go +++ b/examples/chat/client.go @@ -64,7 +64,7 @@ func (c *Client) readPump() { for { _, message, err := c.conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("error: %v", err) } break @@ -113,7 +113,7 @@ func (c *Client) writePump() { } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } @@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { } client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client.hub.register <- client + + // Allow collection of memory referenced by the caller by doing all work in + // new goroutines. go client.writePump() - client.readPump() + go client.readPump() } diff --git a/examples/chat/home.html b/examples/chat/home.html index a39a0c27..bf866aff 100644 --- a/examples/chat/home.html +++ b/examples/chat/home.html @@ -92,7 +92,7 @@
- +
diff --git a/examples/chat/hub.go b/examples/chat/hub.go index 7f07ea07..bb5c0e3b 100644 --- a/examples/chat/hub.go +++ b/examples/chat/hub.go @@ -4,7 +4,7 @@ package main -// hub maintains the set of active clients and broadcasts messages to the +// Hub maintains the set of active clients and broadcasts messages to the // clients. type Hub struct { // Registered clients. diff --git a/examples/chat/main.go b/examples/chat/main.go index 74615d59..474709f6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -15,11 +15,11 @@ var addr = flag.String("addr", ":8080", "http service address") func serveHome(w http.ResponseWriter, r *http.Request) { log.Println(r.URL) if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } - if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/examples/command/main.go b/examples/command/main.go index 239c5c85..38d9f6c8 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -167,11 +167,11 @@ func serveWs(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } - if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/examples/echo/client.go b/examples/echo/client.go index 6578094e..7d870bdf 100644 --- a/examples/echo/client.go +++ b/examples/echo/client.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build ignore // +build ignore package main @@ -38,7 +39,6 @@ func main() { done := make(chan struct{}) go func() { - defer c.Close() defer close(done) for { _, message, err := c.ReadMessage() @@ -55,6 +55,8 @@ func main() { for { select { + case <-done: + return case t := <-ticker.C: err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) if err != nil { @@ -63,8 +65,9 @@ func main() { } case <-interrupt: log.Println("interrupt") - // To cleanly close a connection, a client should send a close - // frame and wait for the server to close the connection. + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { log.Println("write close:", err) @@ -74,7 +77,6 @@ func main() { case <-done: case <-time.After(time.Second): } - c.Close() return } } diff --git a/examples/echo/server.go b/examples/echo/server.go index a685b097..f9a0b7b5 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build ignore // +build ignore package main @@ -55,6 +56,7 @@ func main() { var homeTemplate = template.Must(template.New("").Parse(` +