diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..d2eae33e --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @nhooyr diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md deleted file mode 100644 index 357c314a..00000000 --- a/.github/CONTRIBUTING.md +++ /dev/null @@ -1,45 +0,0 @@ -# Contributing - -## Issues - -Please be as descriptive as possible. - -Reproducible examples are key to finding and fixing bugs. - -## Pull requests - -Good issues for first time contributors are marked as such. Feel free to -reach out for clarification on what needs to be done. - -Split up large changes into several small descriptive commits. - -Capitalize the first word in the commit message title. - -The commit message title should use the verb tense + phrase that completes the blank in - -> This change modifies websocket to \_\_\_\_\_\_\_\_\_ - -Be sure to [correctly link](https://help.github.com/en/articles/closing-issues-using-keywords) -to an existing issue if one exists. In general, create an issue before a PR to get some -discussion going and to make sure you do not spend time on a PR that may be rejected. - -CI must pass on your changes for them to be merged. - -### CI - -CI will ensure your code is formatted, lints and passes tests. -It will collect coverage and report it to [coveralls](https://coveralls.io/github/nhooyr/websocket) -and also upload a html `coverage` artifact that you can download to browse coverage. - -You can run CI locally. - -See [ci/image/Dockerfile](../ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu. - -1. `make fmt` performs code generation and formatting. -1. `make lint` performs linting. -1. `make test` runs tests. -1. `make` runs the above targets. - -For coverage details locally, see `ci/out/coverage.html` after running `make test`. - -You can run tests normally with `go test`. `make test` wraps around `go test` to collect coverage. diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index fce01709..00000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,4 +0,0 @@ - diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 901c994a..00000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,4 +0,0 @@ - diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2cc69828..4534425f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,24 +4,50 @@ on: [push, pull_request] jobs: fmt: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make fmt + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Run make fmt + uses: ./ci/image + with: + args: make fmt + lint: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make lint + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Run make lint + uses: ./ci/image + with: + args: make lint + test: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make test + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Run make test + uses: ./ci/image + with: + args: make test env: - COVERALLS_TOKEN: ${{ secrets.github_token }} + COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} - name: Upload coverage.html uses: actions/upload-artifact@master with: diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..6961e5c8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +websocket.test diff --git a/Makefile b/Makefile index 8c8e1a08..ad1ba257 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,3 @@ SHELL = bash include ci/fmt.mk include ci/lint.mk include ci/test.mk - -ci-image: - docker build -f ./ci/Dockerfile -t nhooyr/websocket-ci . - docker push nhooyr/websocket-ci diff --git a/README.md b/README.md index c426423a..631a14c9 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # websocket -[![GitHub Release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) -[![GoDoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) -[![Coveralls](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) -[![CI Status](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) +[![release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) +[![coverage](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) +[![ci](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) websocket is a minimal and idiomatic WebSocket library for Go. @@ -16,28 +16,25 @@ go get nhooyr.io/websocket ## Features - Minimal and idiomatic API -- Tiny codebase at 2200 lines - First class [context.Context](https://blog.golang.org/context) support -- Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) -- JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages -- Highly optimized by default -- Concurrent writes out of the box -- [Complete Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) support +- Thorough tests, fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports) +- JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- Zero alloc reads and writes +- Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper +- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API +- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap -- [ ] Compression Extensions [#163](https://github.com/nhooyr/websocket/pull/163) - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) ## Examples -For a production quality example that shows off the full API, see the [echo example on the godoc](https://godoc.org/nhooyr.io/websocket#example-package--Echo). On github, the example is at [example_echo_test.go](./example_echo_test.go). - -Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). -There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). -See the [CloseStatus godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseStatus). +For a production quality example that demonstrates the complete API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). ### Server @@ -84,98 +81,52 @@ if err != nil { c.Close(websocket.StatusNormalClosure, "") ``` -## Design justifications - -- A minimal API is easier to maintain due to less docs, tests and bugs -- A minimal API is also easier to use and learn -- Context based cancellation is more ergonomic and robust than setting deadlines -- net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn. -- Using net/http's Client for dialing means we do not have to reinvent dialing hooks - and configurations like other WebSocket libraries - ## Comparison -Before the comparison, I want to point out that both gorilla/websocket and gobwas/ws were -extremely useful in implementing the WebSocket protocol correctly so _big thanks_ to the -authors of both. In particular, I made sure to go through the issue tracker of gorilla/websocket -to ensure I implemented details correctly and understood how people were using WebSockets in -production. - ### gorilla/websocket -https://github.com/gorilla/websocket - -The implementation of gorilla/websocket is 6 years old. As such, it is -widely used and very mature compared to nhooyr.io/websocket. - -On the other hand, it has grown organically and now there are too many ways to do -the same thing. Compare the godoc of -[nhooyr/websocket](https://godoc.org/nhooyr.io/websocket) with -[gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. - -The API for nhooyr.io/websocket has been designed such that there is only one way to do things. -This makes it easy to use correctly. Not only is the API simpler, the implementation is -only 2200 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, -more code to test, more code to document and more surface area for bugs. - -Moreover, nhooyr.io/websocket supports newer Go idioms such as context.Context. -It also uses net/http's Client and ResponseWriter directly for WebSocket handshakes. -gorilla/websocket writes its handshakes to the underlying net.Conn. -Thus it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. - -Some more advantages of nhooyr.io/websocket are that it supports concurrent writes and -makes it very easy to close the connection with a status code and reason. In fact, -nhooyr.io/websocket even implements the complete WebSocket close handshake for you whereas -with gorilla/websocket you have to perform it manually. See [gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448). - -The ping API is also nicer. gorilla/websocket requires registering a pong handler on the Conn -which results in awkward control flow. With nhooyr.io/websocket you use the Ping method on the Conn -that sends a ping and also waits for the pong. - -Additionally, nhooyr.io/websocket can compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) for the browser. - -In terms of performance, the differences mostly depend on your application code. nhooyr.io/websocket -reuses message buffers out of the box if you use the wsjson and wspb subpackages. -As mentioned above, nhooyr.io/websocket also supports concurrent writers. - -The WebSocket masking algorithm used by this package is also [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) -faster than gorilla/websocket or gobwas/ws while using only pure safe Go. +Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): -The only performance con to nhooyr.io/websocket is that it uses one extra goroutine to support -cancellation with context.Context. This costs 2 KB of memory which is cheap compared to -the benefits. +- Mature and widely used +- [Prepared writes](https://godoc.org/github.com/gorilla/websocket#PreparedMessage) +- Configurable [buffer sizes](https://godoc.org/github.com/gorilla/websocket#hdr-Buffers) -### x/net/websocket +Advantages of nhooyr.io/websocket: -https://godoc.org/golang.org/x/net/websocket - -Unmaintained and the API does not reflect WebSocket semantics. Should never be used. - -See https://github.com/golang/go/issues/18152 - -### gobwas/ws - -https://github.com/gobwas/ws - -This library has an extremely flexible API but that comes at the cost of usability -and clarity. - -This library is fantastic in terms of performance. The author put in significant -effort to ensure its speed and I have applied as many of its optimizations as -I could into nhooyr.io/websocket. Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) -about performant WebSocket servers. - -If you want a library that gives you absolute control over everything, this is the library. -But for 99.9% of use cases, nhooyr.io/websocket will fit better. It's nearly as performant -but much easier to use. - -## Contributing - -See [.github/CONTRIBUTING.md](.github/CONTRIBUTING.md). - -## Users - -If your company or project is using this library, feel free to open an issue or PR to amend this list. - -- [Coder](https://github.com/cdr) -- [Tatsu Works](https://github.com/tatsuworks) - Ingresses 20 TB in websocket data every month on their Discord bot. +- Minimal and idiomatic API + - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. +- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper +- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) +- Full [context.Context](https://blog.golang.org/context) support +- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) + - Will enable easy HTTP/2 support in the future + - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. +- Concurrent writes +- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) +- Idiomatic [ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API + - Gorilla requires registering a pong callback before sending a Ping +- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) +- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go + - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). +- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support + - Gorilla only supports no context takeover mode + - Uses [klauspost/compress](https://github.com/klauspost/compress) for optimized compression + - See [gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203) +- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) +- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) + +#### golang.org/x/net/websocket + +[golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. +See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). + +The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning +to nhooyr.io/websocket. + +#### gobwas/ws + +[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used +in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). + +However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/accept.go b/accept.go new file mode 100644 index 00000000..75d6d643 --- /dev/null +++ b/accept.go @@ -0,0 +1,330 @@ +// +build !js + +package websocket + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "net/textproto" + "net/url" + "strings" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. + // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to + // reject it, close the connection when c.Subprotocol() == "". + Subprotocols []string + + // InsecureSkipVerify disables Accept's origin verification behaviour. By default, + // the connection will only be accepted if the request origin is equal to the request + // host. + // + // This is only required if you want javascript served from a different domain + // to access your WebSocket server. + // + // See https://stackoverflow.com/a/37837709/4283659 + // + // Please ensure you understand the ramifications of enabling this. + // If used incorrectly your WebSocket server will be open to CSRF attacks. + InsecureSkipVerify bool + + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int +} + +// Accept accepts a WebSocket handshake from a client and upgrades the +// the connection to a WebSocket. +// +// Accept will not allow cross origin requests by default. +// See the InsecureSkipVerify option to allow cross origin requests. +// +// Accept will write a response to w on all errors. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return accept(w, r, opts) +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { + defer errd.Wrap(&err, "failed to accept WebSocket connection") + + if opts == nil { + opts = &AcceptOptions{} + } + opts = &*opts + + errCode, err := verifyClientRequest(w, r) + if err != nil { + http.Error(w, err.Error(), errCode) + return nil, err + } + + if !opts.InsecureSkipVerify { + err = authenticateOrigin(r) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = xerrors.New("http.ResponseWriter does not implement http.Hijacker") + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return nil, err + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + + subproto := selectSubprotocol(r, opts.Subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } + + copts, err := acceptCompression(r, w, opts.CompressionMode) + if err != nil { + return nil, err + } + + w.WriteHeader(http.StatusSwitchingProtocols) + + netConn, brw, err := hj.Hijack() + if err != nil { + err = xerrors.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return nil, err + } + + // https://github.com/golang/go/issues/32314 + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + + return newConn(connConfig{ + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, + flateThreshold: opts.CompressionThreshold, + + br: brw.Reader, + bw: brw.Writer, + }), nil +} + +func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { + if !r.ProtoAtLeast(1, 1) { + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + } + + if !headerContainsToken(r.Header, "Connection", "Upgrade") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + } + + if !headerContainsToken(r.Header, "Upgrade", "websocket") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + } + + if r.Method != "GET" { + return http.StatusMethodNotAllowed, xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + } + + if r.Header.Get("Sec-WebSocket-Version") != "13" { + w.Header().Set("Sec-WebSocket-Version", "13") + return http.StatusBadRequest, xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + } + + if r.Header.Get("Sec-WebSocket-Key") == "" { + return http.StatusBadRequest, xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + } + + return 0, nil +} + +func authenticateOrigin(r *http.Request) error { + origin := r.Header.Get("Origin") + if origin != "" { + u, err := url.Parse(origin) + if err != nil { + return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } + } + return nil +} + +func selectSubprotocol(r *http.Request, subprotocols []string) string { + cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") + for _, sp := range subprotocols { + for _, cp := range cps { + if strings.EqualFold(sp, cp) { + return cp + } + } + } + return "" +} + +func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { + if mode == CompressionDisabled { + return nil, nil + } + + for _, ext := range websocketExtensions(r.Header) { + switch ext.name { + case "permessage-deflate": + return acceptDeflate(w, ext, mode) + case "x-webkit-deflate-frame": + return acceptWebkitDeflate(w, ext, mode) + } + } + return nil, nil +} + +func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + } + + if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") { + continue + } + + err := xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + copts.setHeader(w.Header()) + + return copts, nil +} + +func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + // The peer must explicitly request it. + copts.serverNoContextTakeover = false + + for _, p := range ext.params { + if p == "no_context_takeover" { + copts.serverNoContextTakeover = true + continue + } + + // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead + // of ignoring it as the draft spec is unclear. It says the server can ignore it + // but the server has no way of signalling to the client it was ignored as the parameters + // are set one way. + // Thus us ignoring it would make the client think we understood it which would cause issues. + // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 + // + // Either way, we're only implementing this for webkit which never sends the max_window_bits + // parameter so we don't need to worry about it. + err := xerrors.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + s := "x-webkit-deflate-frame" + if copts.clientNoContextTakeover { + s += "; no_context_takeover" + } + w.Header().Set("Sec-WebSocket-Extensions", s) + + return copts, nil +} + +func headerContainsToken(h http.Header, key, token string) bool { + token = strings.ToLower(token) + + for _, t := range headerTokens(h, key) { + if t == token { + return true + } + } + return false +} + +type websocketExtension struct { + name string + params []string +} + +func websocketExtensions(h http.Header) []websocketExtension { + var exts []websocketExtension + extStrs := headerTokens(h, "Sec-WebSocket-Extensions") + for _, extStr := range extStrs { + if extStr == "" { + continue + } + + vals := strings.Split(extStr, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + + e := websocketExtension{ + name: vals[0], + params: vals[1:], + } + + exts = append(exts, e) + } + return exts +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.ToLower(t) + tokens = append(tokens, t) + } + } + return tokens +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func secWebSocketAccept(secWebSocketKey string) string { + h := sha1.New() + h.Write([]byte(secWebSocketKey)) + h.Write(keyGUID) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/accept_js.go b/accept_js.go new file mode 100644 index 00000000..5db12d7b --- /dev/null +++ b/accept_js.go @@ -0,0 +1,20 @@ +package websocket + +import ( + "net/http" + + "golang.org/x/xerrors" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, xerrors.New("unimplemented") +} diff --git a/handshake_test.go b/accept_test.go similarity index 52% rename from handshake_test.go rename to accept_test.go index cb09353f..53338e17 100644 --- a/handshake_test.go +++ b/accept_test.go @@ -3,12 +3,16 @@ package websocket import ( - "context" + "bufio" + "net" "net/http" "net/http/httptest" "strings" "testing" - "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/test/assert" ) func TestAccept(t *testing.T) { @@ -21,10 +25,39 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - if err == nil { - t.Fatalf("unexpected error value: %v", err) + assert.Contains(t, err, "protocol violation") + }) + + t.Run("badOrigin", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "harhar.com") + + _, err := Accept(w, r, nil) + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), } + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + _, err := Accept(w, r, nil) + assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -38,9 +71,27 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { - t.Fatalf("unexpected error value: %v", err) + assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) + }) + + t.Run("badHijack", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, xerrors.New("haha") + }, } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, nil) + assert.Contains(t, err, `failed to hijack connection`) }) } @@ -119,7 +170,6 @@ func Test_verifyClientHandshake(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - w := httptest.NewRecorder() r := httptest.NewRequest(tc.method, "/", nil) r.ProtoMajor = 1 @@ -132,9 +182,11 @@ func Test_verifyClientHandshake(t *testing.T) { r.Header.Set(k, v) } - err := verifyClientRequest(w, r) - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) + _, err := verifyClientRequest(httptest.NewRecorder(), r) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -184,9 +236,7 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - if tc.negotiated != negotiated { - t.Fatalf("expected %q but got %q", tc.negotiated, negotiated) - } + assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } @@ -240,120 +290,67 @@ func Test_authenticateOrigin(t *testing.T) { r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r) - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } } -func TestBadDials(t *testing.T) { +func Test_acceptCompression(t *testing.T) { t.Parallel() testCases := []struct { - name string - url string - opts *DialOptions + name string + mode CompressionMode + reqSecWebSocketExtensions string + respSecWebSocketExtensions string + expCopts *compressionOptions + error bool }{ { - name: "badURL", - url: "://noscheme", + name: "disabled", + mode: CompressionDisabled, + expCopts: nil, }, { - name: "badURLScheme", - url: "ftp://nhooyr.io", + name: "noClientSupport", + mode: CompressionNoContextTakeover, + expCopts: nil, }, { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", + respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, }, }, { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } - }) - } -} - -func Test_verifyServerHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - response func(w http.ResponseWriter) - success bool - }{ - { - name: "badStatus", - response: func(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - }, - success: false, - }, - { - name: "badConnection", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badUpgrade", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, + name: "permessage-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; meow", + error: true, }, { - name: "badSecWebSocketAccept", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Accept", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) + name: "x-webkit-deflate-frame", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, }, - success: false, }, { - name: "badSecWebSocketProtocol", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Protocol", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "success", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: true, + name: "x-webkit-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", + error: true, }, } @@ -362,25 +359,30 @@ func Test_verifyServerHandshake(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - w := httptest.NewRecorder() - tc.response(w) - resp := w.Result() - - r := httptest.NewRequest("GET", "/", nil) - key, err := makeSecWebSocketKey() - if err != nil { - t.Fatal(err) - } - r.Header.Set("Sec-WebSocket-Key", key) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + w := httptest.NewRecorder() + copts, err := acceptCompression(r, w, tc.mode) + if tc.error { + assert.Error(t, err) + return } - err = verifyServerResponse(r, resp) - if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) - } + assert.Success(t, err) + assert.Equal(t, "compression options", tc.expCopts, copts) + assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } + +type mockHijacker struct { + http.ResponseWriter + hijack func() (net.Conn, *bufio.ReadWriter, error) +} + +var _ http.Hijacker = mockHijacker{} + +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return mj.hijack() +} diff --git a/assert_test.go b/assert_test.go deleted file mode 100644 index 26fd1d48..00000000 --- a/assert_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package websocket_test - -import ( - "context" - "math/rand" - "strings" - "time" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/wsjson" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error { - exp := randString(n) - err := wsjson.Write(ctx, c, exp) - if err != nil { - return err - } - - var act interface{} - err = wsjson.Read(ctx, c, &act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected JSON") -} - -func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { - var act interface{} - err := wsjson.Read(ctx, c, &act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected JSON") -} - -func randBytes(n int) []byte { - b := make([]byte, n) - rand.Read(b) - return b -} - -func randString(n int) string { - s := strings.ToValidUTF8(string(randBytes(n)), "_") - if len(s) > n { - return s[:n] - } - if len(s) < n { - // Pad with = - extra := n - len(s) - return s + strings.Repeat("=", extra) - } - return s -} - -func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error { - p := randBytes(n) - err := c.Write(ctx, typ, p) - if err != nil { - return err - } - typ2, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, typ2, "unexpected data type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected payload") -} - -func assertSubprotocol(c *websocket.Conn, exp string) error { - return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol") -} diff --git a/autobahn_test.go b/autobahn_test.go new file mode 100644 index 00000000..fb24a06b --- /dev/null +++ b/autobahn_test.go @@ -0,0 +1,229 @@ +// +build !js + +package websocket_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/wstest" +) + +var excludedAutobahnCases = []string{ + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. + "6.*", "7.5.1", + + // We skip the tests related to requestMaxWindowBits as that is unimplemented due + // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 + "13.3.*", "13.4.*", "13.5.*", "13.6.*", +} + +var autobahnCases = []string{"*"} + +func TestAutobahn(t *testing.T) { + t.Parallel() + + if os.Getenv("AUTOBAHN_TEST") == "" { + t.SkipNow() + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + defer cancel() + + wstestURL, closeFn, err := wstestClientServer(ctx) + assert.Success(t, err) + defer closeFn() + + err = waitWS(ctx, wstestURL) + assert.Success(t, err) + + cases, err := wstestCaseCount(ctx, wstestURL) + assert.Success(t, err) + + t.Run("cases", func(t *testing.T) { + for i := 1; i <= cases; i++ { + i := i + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + assert.Success(t, err) + err = wstest.EchoLoop(ctx, c) + t.Logf("echoLoop: %v", err) + }) + } + }) + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + assert.Success(t, err) + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func waitWS(ctx context.Context, url string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + for ctx.Err() == nil { + c, _, err := websocket.Dial(ctx, url, nil) + if err != nil { + continue + } + c.Close(websocket.StatusNormalClosure, "") + return nil + } + + return ctx.Err() +} + +func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { + serverAddr, err := unusedListenAddr() + if err != nil { + return "", nil, err + } + + url = "ws://" + serverAddr + + specFile, err := tempJSONFile(map[string]interface{}{ + "url": url, + "outdir": "ci/out/wstestClientReports", + "cases": autobahnCases, + "exclude-cases": excludedAutobahnCases, + }) + if err != nil { + return "", nil, xerrors.Errorf("failed to write spec: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + defer func() { + if err != nil { + cancel() + } + }() + + args := []string{"--mode", "fuzzingserver", "--spec", specFile, + // Disables some server that runs as part of fuzzingserver mode. + // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 + "--webport=0", + } + wstest := exec.CommandContext(ctx, "wstest", args...) + err = wstest.Start() + if err != nil { + return "", nil, xerrors.Errorf("failed to start wstest: %w", err) + } + + return url, func() { + wstest.Process.Kill() + }, nil +} + +func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { + defer errd.Wrap(&err, "failed to get case count") + + c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) + if err != nil { + return 0, err + } + defer c.Close(websocket.StatusInternalError, "") + + _, r, err := c.Reader(ctx) + if err != nil { + return 0, err + } + b, err := ioutil.ReadAll(r) + if err != nil { + return 0, err + } + cases, err = strconv.Atoi(string(b)) + if err != nil { + return 0, err + } + + c.Close(websocket.StatusNormalClosure, "") + + return cases, nil +} + +func checkWSTestIndex(t *testing.T, path string) { + wstestOut, err := ioutil.ReadFile(path) + assert.Success(t, err) + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + } + err = json.Unmarshal(wstestOut, &indexJSON) + assert.Success(t, err) + + for _, tests := range indexJSON { + for test, result := range tests { + t.Run(test, func(t *testing.T) { + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + t.Errorf("bad close behaviour") + } + + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + t.Errorf("failed") + } + }) + } + } + + if t.Failed() { + htmlPath := strings.Replace(path, ".json", ".html", 1) + t.Errorf("detected autobahn violation, see %q", htmlPath) + } +} + +func unusedListenAddr() (_ string, err error) { + defer errd.Wrap(&err, "failed to get unused listen address") + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +func tempJSONFile(v interface{}) (string, error) { + f, err := ioutil.TempFile("", "temp.json") + if err != nil { + return "", xerrors.Errorf("temp file: %w", err) + } + defer f.Close() + + e := json.NewEncoder(f) + e.SetIndent("", "\t") + err = e.Encode(v) + if err != nil { + return "", xerrors.Errorf("json encode: %w", err) + } + + err = f.Close() + if err != nil { + return "", xerrors.Errorf("close temp file: %w", err) + } + + return f.Name(), nil +} diff --git a/ci/Dockerfile b/ci/Dockerfile deleted file mode 100644 index 0f0fc7d9..00000000 --- a/ci/Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -FROM golang:1 - -RUN apt-get update -RUN apt-get install -y chromium -RUN apt-get install -y npm -RUN apt-get install -y jq - -ENV GOPATH=/root/gopath -ENV PATH=$GOPATH/bin:$PATH -ENV GOFLAGS="-mod=readonly" -ENV PAGER=cat -ENV CI=true -ENV MAKEFLAGS="--jobs=8 --output-sync=target" - -RUN npm install -g prettier -RUN go get golang.org/x/tools/cmd/stringer -RUN go get golang.org/x/tools/cmd/goimports -RUN go get golang.org/x/lint/golint -RUN go get github.com/agnivade/wasmbrowsertest -RUN go get github.com/mattn/goveralls - -# Cache go modules and build cache. -COPY . /tmp/websocket -RUN cd /tmp/websocket && \ - CI= make && \ - rm -rf /tmp/websocket - -# GitHub actions tries to override HOME to /github/home and then -# mounts a temp directory into there. We do not want this behaviour. -# I assume it is so that $HOME is preserved between steps in a job. -ENTRYPOINT ["env", "HOME=/root"] diff --git a/ci/fmt.mk b/ci/fmt.mk index 8e61bc24..f3969721 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -22,4 +22,4 @@ prettier: prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md") gen: - go generate ./... + stringer -type=opcode,MessageType,StatusCode -output=stringer.go diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile new file mode 100644 index 00000000..88c96502 --- /dev/null +++ b/ci/image/Dockerfile @@ -0,0 +1,16 @@ +FROM golang:1 + +RUN apt-get update +RUN apt-get install -y chromium npm + +ENV GOFLAGS="-mod=readonly" +ENV PAGER=cat +ENV CI=true +ENV MAKEFLAGS="--jobs=16 --output-sync=target" + +RUN npm install -g prettier +RUN go get golang.org/x/tools/cmd/stringer +RUN go get golang.org/x/tools/cmd/goimports +RUN go get golang.org/x/lint/golint +RUN go get github.com/agnivade/wasmbrowsertest +RUN go get github.com/mattn/goveralls diff --git a/ci/lint.mk b/ci/lint.mk index a656ea8d..031f0de3 100644 --- a/ci/lint.mk +++ b/ci/lint.mk @@ -1,4 +1,4 @@ -lint: govet golint govet-wasm golint-wasm +lint: govet golint govet: go vet ./... diff --git a/ci/test.mk b/ci/test.mk index 3183552e..3d1f0ed1 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -1,4 +1,4 @@ -test: gotest ci/out/coverage.html +test: ci/out/coverage.html ifdef CI test: coveralls endif @@ -9,17 +9,9 @@ ci/out/coverage.html: gotest coveralls: gotest # https://github.com/coverallsapp/github-action/blob/master/src/run.ts echo "--- coveralls" - export GIT_BRANCH="$$GITHUB_REF" - export BUILD_NUMBER="$$GITHUB_SHA" - if [[ $$GITHUB_EVENT_NAME == pull_request ]]; then - export CI_PULL_REQUEST="$$(jq .number "$$GITHUB_EVENT_PATH")" - BUILD_NUMBER="$$BUILD_NUMBER-PR-$$CI_PULL_REQUEST" - fi - goveralls -coverprofile=ci/out/coverage.prof -service=github + goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... - sed -i '/_stringer\.go/d' ci/out/coverage.prof - sed -i '/wsecho\.go/d' ci/out/coverage.prof - sed -i '/assert\.go/d' ci/out/coverage.prof - sed -i '/wsgrace\.go/d' ci/out/coverage.prof + go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + sed -i '/stringer\.go/d' ci/out/coverage.prof + sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/close.go b/close.go new file mode 100644 index 00000000..20073233 --- /dev/null +++ b/close.go @@ -0,0 +1,77 @@ +package websocket + +import ( + "fmt" + + "golang.org/x/xerrors" +) + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's xerrors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's xerrors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if xerrors.As(err, &ce) { + return ce.Code + } + return -1 +} diff --git a/close_notjs.go b/close_notjs.go new file mode 100644 index 00000000..3367ea01 --- /dev/null +++ b/close_notjs.go @@ -0,0 +1,203 @@ +// +build !js + +package websocket + +import ( + "context" + "encoding/binary" + "log" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + err = c.writeClose(code, reason) + if err != nil && CloseStatus(err) == -1 && err != errAlreadyWroteClose { + return err + } + + err = c.waitCloseHandshake() + if CloseStatus(err) == -1 { + return err + } + return nil +} + +var errAlreadyWroteClose = xerrors.New("already wrote close") + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + closing := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if closing { + return errAlreadyWroteClose + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) + + var p []byte + var err error + if ce.Code != StatusNoStatusRcvd { + p, err = ce.bytes() + if err != nil { + log.Printf("websocket: %v", err) + } + } + + werr := c.writeControl(context.Background(), opClose, p) + if err != nil { + return err + } + return werr +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + err := c.readMu.Lock(ctx) + if err != nil { + return err + } + defer c.readMu.Unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = xerrors.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_test.go b/close_test.go new file mode 100644 index 00000000..00a48d9e --- /dev/null +++ b/close_test.go @@ -0,0 +1,207 @@ +// +build !js + +package websocket + +import ( + "io" + "math" + "strings" + "testing" + + "nhooyr.io/websocket/internal/test/assert" +) + +func TestCloseError(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + ce CloseError + success bool + }{ + { + name: "normal", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxCloseReason), + }, + success: true, + }, + { + name: "bigReason", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxCloseReason+1), + }, + success: false, + }, + { + name: "bigCode", + ce: CloseError{ + Code: math.MaxUint16, + Reason: strings.Repeat("x", maxCloseReason), + }, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := tc.ce.bytesErr() + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) + } + }) + } + + t.Run("Error", func(t *testing.T) { + exp := `status = StatusInternalError and reason = "meow"` + act := CloseError{ + Code: StatusInternalError, + Reason: "meow", + }.Error() + assert.Equal(t, "CloseError.Error()", exp, act) + }) +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if tc.success { + assert.Success(t, err) + assert.Equal(t, "close payload", tc.ce, ce) + } else { + assert.Error(t, err) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act := validWireCloseCode(tc.code) + assert.Equal(t, "wire close code", tc.valid, act) + }) + } +} + +func TestCloseStatus(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in error + exp StatusCode + }{ + { + name: "nil", + in: nil, + exp: -1, + }, + { + name: "io.EOF", + in: io.EOF, + exp: -1, + }, + { + name: "StatusInternalError", + in: CloseError{ + Code: StatusInternalError, + }, + exp: StatusInternalError, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act := CloseStatus(tc.in) + assert.Equal(t, "close status", tc.exp, act) + }) + } +} diff --git a/compress.go b/compress.go new file mode 100644 index 00000000..57446d01 --- /dev/null +++ b/compress.go @@ -0,0 +1,38 @@ +package websocket + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) diff --git a/compress_notjs.go b/compress_notjs.go new file mode 100644 index 00000000..a6911056 --- /dev/null +++ b/compress_notjs.go @@ -0,0 +1,178 @@ +// +build !js + +package websocket + +import ( + "io" + "net/http" + "sync" + + "github.com/klauspost/compress/flate" +) + +func (m CompressionMode) opts() *compressionOptions { + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + p := slidingWindowPool(n) + buf, ok := p.Get().([]byte) + if ok { + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPoolMu.Lock() + defer swPoolMu.Unlock() + + swPool[cap(sw.buf)].Put(sw.buf) + sw.buf = nil +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 00000000..2c4c896c --- /dev/null +++ b/compress_test.go @@ -0,0 +1,34 @@ +// +build !js + +package websocket + +import ( + "strings" + "testing" + + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/xrand" +) + +func Test_slidingWindow(t *testing.T) { + t.Parallel() + + const testCount = 99 + const maxWindow = 99999 + for i := 0; i < testCount; i++ { + t.Run("", func(t *testing.T) { + t.Parallel() + + input := xrand.String(maxWindow) + windowLength := xrand.Int(maxWindow) + var sw slidingWindow + sw.init(windowLength) + sw.write([]byte(input)) + + assert.Equal(t, "window length", windowLength, cap(sw.buf)) + if !strings.HasSuffix(input, string(sw.buf)) { + t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) + } + }) + } +} diff --git a/conn.go b/conn.go index 26906c79..a41808be 100644 --- a/conn.go +++ b/conn.go @@ -1,1064 +1,13 @@ -// +build !js - package websocket -import ( - "bufio" - "context" - "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "runtime" - "strconv" - "sync" - "sync/atomic" - "time" +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int - "nhooyr.io/websocket/internal/bpool" +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary ) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See the docs on Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release the associated resources. -// -// Every error from Read or Reader will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -type Conn struct { - subprotocol string - br *bufio.Reader - bw *bufio.Writer - // writeBuf is used for masking, its the buffer in bufio.Writer. - // Only used by the client for masking the bytes in the buffer. - writeBuf []byte - closer io.Closer - client bool - - closeOnce sync.Once - closeErrOnce sync.Once - closeErr error - closed chan struct{} - closing *atomicInt64 - closeReceived error - - // messageWriter state. - // writeMsgLock is acquired to write a data message. - writeMsgLock chan struct{} - // writeFrameLock is acquired to write a single frame. - // Effectively meaning whoever holds it gets to write to bw. - writeFrameLock chan struct{} - writeHeaderBuf []byte - writeHeader *header - // read limit for a message in bytes. - msgReadLimit *atomicInt64 - - // Used to ensure a previous writer is not used after being closed. - activeWriter atomic.Value - // messageWriter state. - writeMsgOpcode opcode - writeMsgCtx context.Context - readMsgLeft int64 - - // Used to ensure the previous reader is read till EOF before allowing - // a new one. - activeReader *messageReader - // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} - isReadClosed *atomicInt64 - readHeaderBuf []byte - controlPayloadBuf []byte - readLock chan struct{} - - // messageReader state. - readerMsgCtx context.Context - readerMsgHeader header - readerFrameEOF bool - readerMaskKey uint32 - - setReadTimeout chan context.Context - setWriteTimeout chan context.Context - - pingCounter *atomicInt64 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} - - logf func(format string, v ...interface{}) -} - -func (c *Conn) init() { - c.closed = make(chan struct{}) - c.closing = &atomicInt64{} - - c.msgReadLimit = &atomicInt64{} - c.msgReadLimit.Store(32768) - - c.writeMsgLock = make(chan struct{}, 1) - c.writeFrameLock = make(chan struct{}, 1) - - c.readFrameLock = make(chan struct{}, 1) - c.readLock = make(chan struct{}, 1) - - c.setReadTimeout = make(chan context.Context) - c.setWriteTimeout = make(chan context.Context) - - c.pingCounter = &atomicInt64{} - c.activePings = make(map[string]chan<- struct{}) - - c.writeHeaderBuf = makeWriteHeaderBuf() - c.writeHeader = &header{} - c.readHeaderBuf = makeReadHeaderBuf() - c.isReadClosed = &atomicInt64{} - c.controlPayloadBuf = make([]byte, maxControlFramePayload) - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) - }) - - c.logf = log.Printf - - go c.timeoutLoop() -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) - - c.setCloseErr(err) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.closer.Close() - - // See comment on bufioReaderPool in handshake.go - if c.client { - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readFramePayload and readHeader. - c.readFrameLock <- struct{}{} - returnBufioReader(c.br) - - c.writeFrameLock <- struct{}{} - returnBufioWriter(c.bw) - } - }) -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.setWriteTimeout: - case readCtx = <-c.setReadTimeout: - - case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - // Guaranteed to eventually close the connection since we can only ever send - // one close frame. - go func() { - c.exportedClose(StatusPolicyViolation, "read timed out", true) - // Ensure the connection closes, i.e if we already sent a close frame and timed out - // to read the peer's close frame. - c.close(nil) - }() - readCtx = context.Background() - case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { - select { - case <-ctx.Done(): - var err error - switch lock { - case c.writeFrameLock, c.writeMsgLock: - err = fmt.Errorf("could not acquire write lock: %v", ctx.Err()) - case c.readFrameLock, c.readLock: - err = fmt.Errorf("could not acquire read lock: %v", ctx.Err()) - default: - panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) - } - c.close(err) - return ctx.Err() - case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil - } -} - -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. - select { - case <-lock: - default: - } -} - -func (c *Conn) readTillMsg(ctx context.Context) (header, error) { - for { - h, err := c.readFrameHeader(ctx) - if err != nil { - return header{}, err - } - - if h.rsv1 || h.rsv2 || h.rsv3 { - err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - - if h.opcode.controlOp() { - err = c.handleControl(ctx, h) - if err != nil { - // Pass through CloseErrors when receiving a close frame. - if h.opcode == opClose && CloseStatus(err) != -1 { - return header{}, err - } - return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) - } - continue - } - - switch h.opcode { - case opBinary, opText, opContinuation: - return h, nil - default: - err := fmt.Errorf("received unknown opcode %v", h.opcode) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - } -} - -func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame header: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return header{}, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- ctx: - } - - h, err := readHeader(c.readHeaderBuf, c.br) - if err != nil { - select { - case <-c.closed: - return header{}, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return header{}, err - } - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return h, nil -} - -func (c *Conn) handleControl(ctx context.Context, h header) error { - if h.payloadLength > maxControlFramePayload { - err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - if !h.fin { - err := errors.New("received fragmented control frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - b := c.controlPayloadBuf[:h.payloadLength] - _, err := c.readFramePayload(ctx, b) - if err != nil { - return err - } - - if h.masked { - mask(h.maskKey, b) - } - - switch h.opcode { - case opPing: - return c.writeControl(ctx, opPong, b) - case opPong: - c.activePingsMu.Lock() - pong, ok := c.activePings[string(b)] - c.activePingsMu.Unlock() - if ok { - close(pong) - } - return nil - case opClose: - ce, err := parseClosePayload(b) - if err != nil { - err = fmt.Errorf("received invalid close payload: %w", err) - c.exportedClose(StatusProtocolError, err.Error(), false) - c.closeReceived = err - return err - } - - err = fmt.Errorf("received close: %w", ce) - c.closeReceived = err - c.writeClose(b, err, false) - - if ctx.Err() != nil { - // The above close probably has been returned by the peer in response - // to our read timing out so we have to return the read timed out error instead. - return fmt.Errorf("read timed out: %w", ctx.Err()) - } - - return err - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) - } -} - -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// Ensure you read to EOF otherwise the connection will hang. -// -// All returned errors will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -// -// You must read from the connection for control frames to be handled. -// Thus if you expect messages to take a long time to be responded to, -// you should handle such messages async to reading from the connection -// to ensure control frames are promptly handled. -// -// If you do not expect any data messages from the peer, call CloseRead. -// -// Only one Reader may be open at a time. -// -// If you need a separate timeout on the Reader call and then the message -// Read, use time.AfterFunc to cancel the context passed in early. -// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 -// Most users should not need this. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.isReadClosed.Load() == 1 { - return 0, nil, errors.New("websocket connection read closed") - } - - typ, r, err := c.reader(ctx, true) - if err != nil { - return 0, nil, fmt.Errorf("failed to get reader: %w", err) - } - return typ, r, nil -} - -func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, error) { - if lock { - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return 0, nil, err - } - defer c.releaseLock(c.readLock) - } - - if c.activeReader != nil && !c.readerFrameEOF { - // The only way we know for sure the previous reader is not yet complete is - // if there is an active frame not yet fully read. - // Otherwise, a user may have read the last byte but not the EOF if the EOF - // is in the next frame so we check for that below. - return 0, nil, errors.New("previous message not read to completion") - } - - h, err := c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - - if c.activeReader != nil && !c.activeReader.eof() { - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, fmt.Errorf("previous message not read to completion") - } - - c.activeReader = nil - - h, err = c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - } else if h.opcode == opContinuation { - err := errors.New("received continuation frame not after data or text frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - c.readerMsgCtx = ctx - c.readerMsgHeader = h - c.readerFrameEOF = false - c.readerMaskKey = h.maskKey - c.readMsgLeft = c.msgReadLimit.Load() - - r := &messageReader{ - c: c, - } - c.activeReader = r - return MessageType(h.opcode), r, nil -} - -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - c *Conn -} - -func (r *messageReader) eof() bool { - return r.c.activeReader != r -} - -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - return r.exportedRead(p, true) -} - -func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) { - n, err := r.read(p, lock) - if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as errors.Is - // isn't used widely yet. - if errors.Is(err, io.EOF) { - return n, io.EOF - } - return n, fmt.Errorf("failed to read: %w", err) - } - return n, nil -} - -func (r *messageReader) readUnlocked(p []byte) (int, error) { - return r.exportedRead(p, false) -} - -func (r *messageReader) read(p []byte, lock bool) (int, error) { - if lock { - // If we cannot acquire the read lock, then - // there is either a concurrent read or the close handshake - // is proceeding. - select { - case r.c.readLock <- struct{}{}: - defer r.c.releaseLock(r.c.readLock) - default: - if r.c.closing.Load() == 1 { - <-r.c.closed - return 0, r.c.closeErr - } - return 0, errors.New("concurrent read detected") - } - } - - if r.eof() { - return 0, errors.New("cannot use EOFed reader") - } - - if r.c.readMsgLeft <= 0 { - err := fmt.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.exportedClose(StatusMessageTooBig, err.Error(), false) - return 0, err - } - - if int64(len(p)) > r.c.readMsgLeft { - p = p[:r.c.readMsgLeft] - } - - if r.c.readerFrameEOF { - h, err := r.c.readTillMsg(r.c.readerMsgCtx) - if err != nil { - return 0, err - } - - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - r.c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, err - } - - r.c.readerMsgHeader = h - r.c.readerFrameEOF = false - r.c.readerMaskKey = h.maskKey - } - - h := r.c.readerMsgHeader - if int64(len(p)) > h.payloadLength { - p = p[:h.payloadLength] - } - - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) - - h.payloadLength -= int64(n) - r.c.readMsgLeft -= int64(n) - if h.masked { - r.c.readerMaskKey = mask(r.c.readerMaskKey, p) - } - r.c.readerMsgHeader = h - - if err != nil { - return n, err - } - - if h.payloadLength == 0 { - r.c.readerFrameEOF = true - - if h.fin { - r.c.activeReader = nil - return n, io.EOF - } - } - - return n, nil -} - -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame payload: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- ctx: - } - - n, err := io.ReadFull(c.br, p) - if err != nil { - select { - case <-c.closed: - return n, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return n, err - } - - select { - case <-c.closed: - return n, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return n, err -} - -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this method as well. -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) - if err != nil { - return 0, nil, err - } - - b, err := ioutil.ReadAll(r) - return typ, b, err -} - -// Writer returns a writer bounded by the context that will write -// a WebSocket message of type dataType to the connection. -// -// You must close the writer once you have written the entire message. -// -// Only one writer can be open at a time, multiple calls will block until the previous writer -// is closed. -func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - wc, err := c.writer(ctx, typ) - if err != nil { - return nil, fmt.Errorf("failed to get writer: %w", err) - } - return wc, nil -} - -func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return nil, err - } - c.writeMsgCtx = ctx - c.writeMsgOpcode = opcode(typ) - w := &messageWriter{ - c: c, - } - c.activeWriter.Store(w) - return w, nil -} - -// Write is a convenience method to write a message to the connection. -// -// See the Writer method if you want to stream a message. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - _, err := c.write(ctx, typ, p) - if err != nil { - return fmt.Errorf("failed to write msg: %w", err) - } - return nil -} - -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeMsgLock) - - n, err := c.writeFrame(ctx, true, opcode(typ), p) - return n, err -} - -// messageWriter enables writing to a WebSocket connection. -type messageWriter struct { - c *Conn -} - -func (w *messageWriter) closed() bool { - return w != w.c.activeWriter.Load() -} - -// Write writes the given bytes to the WebSocket connection. -func (w *messageWriter) Write(p []byte) (int, error) { - n, err := w.write(p) - if err != nil { - return n, fmt.Errorf("failed to write: %w", err) - } - return n, nil -} - -func (w *messageWriter) write(p []byte) (int, error) { - if w.closed() { - return 0, fmt.Errorf("cannot use closed writer") - } - n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) - if err != nil { - return n, fmt.Errorf("failed to write data frame: %w", err) - } - w.c.writeMsgOpcode = opContinuation - return n, nil -} - -// Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (w *messageWriter) Close() error { - err := w.close() - if err != nil { - return fmt.Errorf("failed to close writer: %w", err) - } - return nil -} - -func (w *messageWriter) close() error { - if w.closed() { - return fmt.Errorf("cannot use closed writer") - } - w.c.activeWriter.Store((*messageWriter)(nil)) - - _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) - if err != nil { - return fmt.Errorf("failed to write fin frame: %w", err) - } - - w.c.releaseLock(w.c.writeMsgLock) - return nil -} - -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - _, err := c.writeFrame(ctx, true, opcode, p) - if err != nil { - return fmt.Errorf("failed to write control frame %v: %w", opcode, err) - } - return nil -} - -// writeFrame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setWriteTimeout <- ctx: - } - - c.writeHeader.fin = fin - c.writeHeader.opcode = opcode - c.writeHeader.masked = c.client - c.writeHeader.payloadLength = int64(len(p)) - - if c.client { - err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) - if err != nil { - return 0, fmt.Errorf("failed to generate masking key: %w", err) - } - } - - n, err := c.realWriteFrame(ctx, *c.writeHeader, p) - if err != nil { - return n, err - } - - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return n, c.closeErr - case c.setWriteTimeout <- context.Background(): - } - - return n, nil -} - -func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) { - defer func() { - if err != nil { - select { - case <-c.closed: - err = c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - - err = fmt.Errorf("failed to write %v frame: %w", h.opcode, err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) - } - }() - - headerBytes := writeHeader(c.writeHeaderBuf, h) - _, err = c.bw.Write(headerBytes) - if err != nil { - return 0, err - } - - if c.client { - maskKey := h.maskKey - for len(p) > 0 { - if c.bw.Available() == 0 { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - // Start of next write in the buffer. - i := c.bw.Buffered() - - p2 := p - if len(p) > c.bw.Available() { - p2 = p[:c.bw.Available()] - } - - n2, err := c.bw.Write(p2) - if err != nil { - return n, err - } - - maskKey = mask(maskKey, c.writeBuf[i:i+n2]) - - p = p[n2:] - n += n2 - } - } else { - n, err = c.bw.Write(p) - if err != nil { - return n, err - } - } - - if h.fin { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - return n, nil -} - -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// Thus, it implements the full WebSocket close handshake. -// All data messages received from the peer during the close handshake -// will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason, true) - var ec errClosing - if errors.As(err, &ec) { - <-c.closed - // We wait until the connection closes. - // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error. - err = c.writeClose(nil, ec.ce, true) - } - if err != nil { - return fmt.Errorf("failed to close websocket connection: %w", err) - } - return nil -} - -func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) error { - ce := CloseError{ - Code: code, - Reason: reason, - } - - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - c.logf("websocket: failed to marshal close frame: %+v", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytes() - } - - return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake) -} - -type errClosing struct { - ce error -} - -func (e errClosing) Error() string { - return "already closing connection" -} - -func (c *Conn) writeClose(p []byte, ce error, handshake bool) error { - if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) - } - - if !c.closing.CAS(0, 1) { - // Normally, we would want to wait until the connection is closed, - // at least for when a user calls into Close, so we handle that case in - // the exported Close function. - // - // But for internal library usage, we always want to return early, e.g. - // if we are performing a close handshake and the peer sends their close frame, - // we do not want to block here waiting for c.closed to close because it won't, - // at least not until we return since the gorouine that will close it is this one. - return errClosing{ - ce: ce, - } - } - - // No matter what happens next, close error should be set. - c.setCloseErr(ce) - defer c.close(nil) - - err := c.writeControl(context.Background(), opClose, p) - if err != nil { - return err - } - - if handshake { - err = c.waitClose() - if CloseStatus(err) == -1 { - // waitClose exited not due to receiving a close frame. - return fmt.Errorf("failed to wait for peer close frame: %w", err) - } - } - - return nil -} - -func (c *Conn) waitClose() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return err - } - defer c.releaseLock(c.readLock) - - if c.closeReceived != nil { - // goroutine reading just received the close. - return c.closeReceived - } - - b := bpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] - defer bpool.Put(b) - - for { - if c.activeReader == nil || c.readerFrameEOF { - _, _, err := c.reader(ctx, false) - if err != nil { - return fmt.Errorf("failed to get reader: %w", err) - } - } - - r := readerFunc(c.activeReader.readUnlocked) - _, err = io.CopyBuffer(ioutil.Discard, r, buf) - if err != nil { - return err - } - } -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := c.pingCounter.Increment(1) - - err := c.ping(ctx, strconv.FormatInt(p, 10)) - if err != nil { - return fmt.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} - -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - -// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer -// and stores it in c.writeBuf. -func (c *Conn) extractBufioWriterBuf(w io.Writer) { - c.bw.Reset(writerFunc(func(p2 []byte) (int, error) { - c.writeBuf = p2[:cap(p2)] - return len(p2), nil - })) - - c.bw.WriteByte(0) - c.bw.Flush() - - c.bw.Reset(w) -} diff --git a/conn_export_test.go b/conn_export_test.go deleted file mode 100644 index d5f5aa24..00000000 --- a/conn_export_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "fmt" -) - -type ( - Addr = websocketAddr - OpCode int -) - -const ( - OpClose = OpCode(opClose) - OpBinary = OpCode(opBinary) - OpText = OpCode(opText) - OpPing = OpCode(opPing) - OpPong = OpCode(opPong) - OpContinuation = OpCode(opContinuation) -) - -func (c *Conn) SetLogf(fn func(format string, v ...interface{})) { - c.logf = fn -} - -func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { - h, err := c.readFrameHeader(ctx) - if err != nil { - return 0, nil, err - } - b := make([]byte, h.payloadLength) - _, err = c.readFramePayload(ctx, b) - if err != nil { - return 0, nil, err - } - if h.masked { - mask(h.maskKey, b) - } - return OpCode(h.opcode), b, nil -} - -func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (int, error) { - return c.writeFrame(ctx, fin, opcode(opc), p) -} - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type Header struct { - Fin bool - Rsv1 bool - Rsv2 bool - Rsv3 bool - OpCode OpCode - - PayloadLength int64 -} - -func (c *Conn) WriteHeader(ctx context.Context, h Header) error { - headerBytes := writeHeader(c.writeHeaderBuf, header{ - fin: h.Fin, - rsv1: h.Rsv1, - rsv2: h.Rsv2, - rsv3: h.Rsv3, - opcode: opcode(h.OpCode), - payloadLength: h.PayloadLength, - masked: c.client, - }) - _, err := c.bw.Write(headerBytes) - if err != nil { - return fmt.Errorf("failed to write header: %w", err) - } - if h.Fin { - err = c.Flush() - if err != nil { - return err - } - } - return nil -} - -func (c *Conn) PingWithPayload(ctx context.Context, p string) error { - return c.ping(ctx, p) -} - -func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { - return c.realWriteFrame(ctx, header{ - fin: true, - opcode: opBinary, - payloadLength: 10, - }, make([]byte, 5)) -} - -func (c *Conn) CloseUnderlyingConn() { - c.closer.Close() -} - -func (c *Conn) Flush() error { - return c.bw.Flush() -} - -func (c CloseError) Bytes() ([]byte, error) { - return c.bytes() -} - -func (c *Conn) BW() *bufio.Writer { - return c.bw -} - -func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ([]byte, error) { - b, err := CloseError{ - Code: code, - Reason: reason, - }.Bytes() - if err != nil { - return nil, err - } - _, err = c.WriteFrame(ctx, true, OpClose, b) - if err != nil { - return nil, err - } - return b, nil -} - -func ParseClosePayload(p []byte) (CloseError, error) { - return parseClosePayload(p) -} diff --git a/conn_notjs.go b/conn_notjs.go new file mode 100644 index 00000000..8598ded3 --- /dev/null +++ b/conn_notjs.go @@ -0,0 +1,258 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" + + "golang.org/x/xerrors" +) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriterState = newMsgWriterState(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriterState.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(xerrors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + c.setCloseErrLocked(err) + close(c.closed) + runtime.SetFinalizer(c, nil) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + if c.client { + c.writeFrameMu.Lock(context.Background()) + putBufioWriter(c.bw) + } + c.msgWriterState.close() + + c.msgReader.close() + if c.client { + putBufioReader(c.br) + } + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) + case <-writeCtx.Done(): + c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return xerrors.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) Lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + return nil + } +} + +func (m *mu) Unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/conn_test.go b/conn_test.go index 83f09dbf..7755048c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,2386 +5,464 @@ package websocket_test import ( "bytes" "context" - "encoding/binary" - "encoding/json" - "errors" "fmt" "io" "io/ioutil" - "math/rand" - "net" "net/http" - "net/http/cookiejar" "net/http/httptest" - "net/url" "os" "os/exec" - "reflect" - "strconv" "strings" + "sync" "testing" "time" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/timestamp" - "go.uber.org/multierr" + "github.com/golang/protobuf/ptypes/duration" + "golang.org/x/xerrors" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/internal/wsecho" - "nhooyr.io/websocket/internal/wsgrace" + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wspb" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func TestHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - client func(ctx context.Context, url string) error - server func(w http.ResponseWriter, r *http.Request) error - }{ - { - name: "badOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected error regarding bad origin") - } - return assertErrorContains(err, "not authorized") - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "http://unauthorized.com/") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected handshake failure") - } - return assertErrorContains(err, "403") - }, - }, - { - name: "acceptSecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "acceptInsecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "https://example.com/") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") - if err != nil { - return fmt.Errorf("request is missing mycookie: %w", err) - } - err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value") - if err != nil { - return err - } - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) - if err != nil { - return fmt.Errorf("failed to create cookie jar: %w", err) - } - parsedURL, err := url.Parse(u) - if err != nil { - return fmt.Errorf("failed to parse url: %w", err) - } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, - } - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPClient: hc, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, tc.server, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := tc.client(ctx, wsURL) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - }) - } -} - func TestConn(t *testing.T) { t.Parallel() - testCases := []struct { - name string - - acceptOpts *websocket.AcceptOptions - server func(ctx context.Context, c *websocket.Conn) error - - dialOpts *websocket.DialOptions - response func(resp *http.Response) error - client func(ctx context.Context, c *websocket.Conn) error - }{ - { - name: "handshake", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"myproto"}, - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"myproto"}, - }, - response: func(resp *http.Response) error { - headers := map[string]string{ - "Connection": "Upgrade", - "Upgrade": "websocket", - "Sec-WebSocket-Protocol": "myproto", - } - for h, exp := range headers { - value := resp.Header.Get(h) - err := assert.Equalf(exp, value, "unexpected value for header %v", h) - if err != nil { - return err - } - } - return nil - }, - }, - { - name: "handshake/defaultSubprotocol", - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - }, - { - name: "handshake/subprotocolPriority", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"echo", "lar"}, - }, - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"poof", "echo"}, - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - }, - { - name: "closeError", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hello") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := assertJSONRead(ctx, c, "hello") - if err != nil { - return err - } - - _, _, err = c.Reader(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "netConn", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetWriteDeadline(time.Time{}) - time.Sleep(1) - nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") - if err != nil { - return err - } - err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") - if err != nil { - return err - } - - for i := 0; i < 3; i++ { - _, err := nc.Write([]byte("hello")) - if err != nil { - return err - } - } - - return nil - }, - client: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetReadDeadline(time.Time{}) - time.Sleep(1) - nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - - for i := 0; i < 3; i++ { - err := assertNetConnRead(nc, "hello") - if err != nil { - return err - } - } - - // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err2 := assertNetConnRead(nc, "hello") - err := assert.Equalf(io.EOF, err2, "unexpected error") - if err != nil { - return err - } - - err2 = assertNetConnRead(nc, "hello") - return assert.Equalf(io.EOF, err2, "unexpected error") - }, - }, - { - name: "netConn/badReadMsgType", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err := nc.Read(make([]byte, 1)) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "meow") - if err != nil { - return err - } - - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusUnsupportedData) - }, - }, - { - name: "netConn/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err2 := nc.Read(make([]byte, 1)) - err := assertCloseStatus(err2, websocket.StatusBadGateway) - if err != nil { - return err - } - - _, err2 = nc.Write([]byte{0xff}) - return assertErrorContains(err2, "websocket closed") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusBadGateway, "") - }, - }, - { - name: "wsjson/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "meow") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertJSONRead(ctx, c, "meow") - }, - }, - { - name: "protobuf/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertProtobufRead(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "ping", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - err := c.Ping(ctx) - if err != nil { - return err - } - - err = wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - - <-ctx.Done() - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - // We read a message from the connection and then keep reading until - // the Ping completes. - pingErrc := make(chan error, 1) - go func() { - pingErrc <- c.Ping(ctx) - }() - - // Once this completes successfully, that means they sent their ping and we responded to it. - err := assertJSONRead(ctx, c, "hi") - if err != nil { - return err - } - - // Now we need to ensure we're reading for their pong from our ping. - // Need new var to not race with above goroutine. - ctx2 := c.CloseRead(ctx) - - // Now we wait for our pong. - select { - case err = <-pingErrc: - return err - case <-ctx2.Done(): - return fmt.Errorf("failed to wait for pong: %w", ctx2.Err()) - } - }, - }, - { - name: "readLimit", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertErrorContains(err2, "read limited at 32768 bytes") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) - if err != nil { - return err - } - - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusMessageTooBig) - }, - }, - { - name: "wsjson/binary", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "wsjson/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "failed to unmarshal json") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageText, []byte("notjson")) - }, - }, - { - name: "wsjson/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, fmt.Println) - return assertErrorContains(err, "failed to encode json") - }, - }, - { - name: "wspb/text", - server: func(ctx context.Context, c *websocket.Conn) error { - var v proto.Message - err := wspb.Read(ctx, c, v) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hi") - }, - }, - { - name: "wspb/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v timestamp.Timestamp - err := wspb.Read(ctx, c, &v) - return assertErrorContains(err, "failed to unmarshal protobuf") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) - }, - }, - { - name: "wspb/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wspb.Write(ctx, c, nil) - return assertErrorIs(proto.ErrNil, err) - }, - }, - { - name: "badClose", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(9999, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "pingTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Ping(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - err1 := assertErrorContains(err, "connection reset") - err2 := assertErrorIs(io.EOF, err) - if err1 != nil || err2 != nil { - return nil - } - return multierr.Combine(err1, err2) - }, - }, - { - name: "writeTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - c.Writer(ctx, websocket.MessageBinary) - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(io.EOF, err) - }, - }, - { - name: "readTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - _, _, err := c.Read(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(websocket.CloseError{ - Code: websocket.StatusPolicyViolation, - Reason: "read timed out", - }, err) - }, - }, - { - name: "badOpCode", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - }, - { - name: "noRsv", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "rsv") { - return fmt.Errorf("expected error that contains rsv: %+v", err) - } - return nil - }, - }, - { - name: "largeControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpClose, - PayloadLength: 4096, - }) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "too big") - }, - }, - { - name: "fragmentedControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpPing, []byte(strings.Repeat("x", 32))) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "fragmented") - }, - }, - { - name: "invalidClosePayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{0x17, 0x70}) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "invalid status code") - }, - }, - { - name: "doubleReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "doubleFragmentedReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "newMessageInFragmentedMessage", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - }, - { - name: "continuationFrameWithoutDataFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Reader(ctx) - return assertErrorContains(err, "received continuation frame not after data") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpContinuation, []byte(strings.Repeat("x", 10))) - return err - }, - }, - { - name: "readBeforeEOF", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - var v interface{} - d := json.NewDecoder(r) - err = d.Decode(&v) - if err != nil { - return err - } - err = assert.Equalf("hi", v, "unexpected JSON") - if err != nil { - return err - } - _, b, err := c.Read(ctx) - if err != nil { - return err - } - return assert.Equalf("hi", string(b), "unexpected JSON") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - return c.Write(ctx, websocket.MessageText, []byte("hi")) - }, - }, - { - name: "newMessageInFragmentedMessage2", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 11) - _, err = io.ReadFull(r, p) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - }, - { - name: "doubleRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - _, err = ioutil.ReadAll(r) - if err != nil { - return err - } - _, err = r.Read(make([]byte, 1)) - return assertErrorContains(err, "cannot use EOFed reader") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("hi")) - }, - }, - { - name: "eofInPayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "failed to read frame payload") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteHalfFrame(ctx) - if err != nil { - return err - } - c.CloseUnderlyingConn() - return nil - }, - }, - { - name: "closeHandshake", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - }, - { - // Issue #164 - name: "closeHandshake_concurrentRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - errc := make(chan error, 1) - go func() { - _, _, err := c.Read(ctx) - errc <- err - }() - - err := c.Close(websocket.StatusNormalClosure, "") - if err != nil { - return err - } - - err = <-errc - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, tc.acceptOpts) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - if tc.server == nil { - return nil - } - return tc.server(r.Context(), c) - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - opts := tc.dialOpts - if tls { - if opts == nil { - opts = &websocket.DialOptions{} - } - opts.HTTPClient = s.Client() - } - - c, resp, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - - if tc.response != nil { - err = tc.response(resp) - if err != nil { - t.Fatalf("response asserter failed: %+v", err) - } - } - - if tc.client != nil { - err = tc.client(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - } - - c.Close(websocket.StatusNormalClosure, "") - }) - } -} - -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := fn(w, r) - if err != nil { - tb.Errorf("server failed: %+v", err) - } - }) - if tls { - s = httptest.NewTLSServer(h) - } else { - s = httptest.NewServer(h) - } - closeFn2 := wsgrace.Grace(s.Config) - return s, func() { - err := closeFn2() - if err != nil { - tb.Fatal(err) - } - } -} - -func TestAutobahn(t *testing.T) { - t.Parallel() - - run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) { - run2 := func(t *testing.T, testingClient bool) { - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx := r.Context() - if testingClient { - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - return nil - } - - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - } - if tls { - opts.HTTPClient = s.Client() - } - - c, _, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - if testingClient { - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - c.Close(websocket.StatusNormalClosure, "") - return - } - - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - } - t.Run(name, func(t *testing.T) { - t.Parallel() - - run2(t, true) - }) - } - - // Section 1. - t.Run("echo", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 0, - 125, - 126, - 127, - 128, - 65535, - 65536, - 65536, - } - run := func(typ websocket.MessageType) { - for i, l := range lengths { - l := l - run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error { - p := randBytes(l) - if i == len(lengths)-1 { - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - for i := 0; i < l; { - j := i + 997 - if j > l { - j = l - } - _, err = w.Write(p[i:j]) - if err != nil { - return err - } - - i = j - } - - err = w.Close() - if err != nil { - return err - } - } else { - err := c.Write(ctx, typ, p) - if err != nil { - return err - } - } - actTyp, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, actTyp, "unexpected message type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected message") - }) - } - } - - run(websocket.MessageText) - run(websocket.MessageBinary) - }) - - // Section 2. - t.Run("pingPong", func(t *testing.T) { - t.Parallel() - - run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "") - }) - run(t, "smallTextPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "hi") - }) - run(t, "smallBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 16) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "largeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 125) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "tooLargeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 126) - err := c.PingWithPayload(ctx, string(p)) - return assertCloseStatus(err, websocket.StatusProtocolError) - }) - run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - t.Run("unsolicitedPong", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - pongPayload string - ping bool - }{ - { - name: "noPayload", - pongPayload: "", - }, - { - name: "payload", - pongPayload: "hi", - }, - { - name: "pongThenPing", - pongPayload: "hi", - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpPong, []byte(tc.pongPayload)) - if err != nil { - return err - } - if tc.ping { - _, err := c.WriteFrame(ctx, true, websocket.OpPing, []byte("meow")) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, []byte("meow")) - if err != nil { - return err - } - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) - run(t, "tenPings", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - for i := 0; i < 10; i++ { - err := c.Ping(ctx) - if err != nil { - return err - } - } - - _, err := c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - <-ctx.Done() - - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }) - - run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 10; i++ { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - } - - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - - // Section 3. - // We skip the per octet sending as it will add too much complexity. - t.Run("reserved", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - header websocket.Header - }{ - { - name: "rsv1", - header: websocket.Header{ - Fin: true, - Rsv1: true, - OpCode: websocket.OpClose, - PayloadLength: 0, - }, - }, - { - name: "rsv2", - header: websocket.Header{ - Fin: true, - Rsv2: true, - OpCode: websocket.OpPong, - PayloadLength: 0, - }, - }, - { - name: "rsv3", - header: websocket.Header{ - Fin: true, - Rsv3: true, - OpCode: websocket.OpBinary, - PayloadLength: 0, - }, - }, - { - name: "rsvAll", - header: websocket.Header{ - Fin: true, - Rsv1: true, - Rsv2: true, - Rsv3: true, - OpCode: websocket.OpText, - PayloadLength: 0, - }, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - err = c.WriteHeader(ctx, tc.header) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 4. - t.Run("opcodes", func(t *testing.T) { + t.Run("fuzzData", func(t *testing.T) { t.Parallel() - testCases := []struct { - name string - opcode websocket.OpCode - payload bool - echo bool - ping bool - }{ - // Section 1. - { - name: "3", - opcode: 3, - }, - { - name: "4", - opcode: 4, - payload: true, - }, - { - name: "5", - opcode: 5, - echo: true, - ping: true, - }, - { - name: "6", - opcode: 6, - payload: true, - echo: true, - ping: true, - }, - { - name: "7", - opcode: 7, - payload: true, - echo: true, - ping: true, - }, - - // Section 2. - { - name: "11", - opcode: 11, - }, - { - name: "12", - opcode: 12, - payload: true, - }, - { - name: "13", - opcode: 13, - payload: true, - echo: true, - ping: true, - }, - { - name: "14", - opcode: 14, - payload: true, - echo: true, - ping: true, - }, - { - name: "15", - opcode: 15, - payload: true, - echo: true, - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.echo { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - } - - p := []byte(nil) - if tc.payload { - p = randBytes(rand.Intn(4096) + 1) - } - _, err := c.WriteFrame(ctx, true, tc.opcode, p) - if err != nil { - return err - } - if tc.ping { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 5. - t.Run("fragmentation", func(t *testing.T) { - t.Parallel() - - // 5.1 to 5.8 - testCases := []struct { - name string - opcode websocket.OpCode - success bool - pingInBetween bool - }{ - { - name: "ping", - opcode: websocket.OpPing, - success: false, - }, - { - name: "pong", - opcode: websocket.OpPong, - success: false, - }, - { - name: "text", - opcode: websocket.OpText, - success: true, - }, - { - name: "textPing", - opcode: websocket.OpText, - success: true, - pingInBetween: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(16) - _, err := c.WriteFrame(ctx, false, tc.opcode, p1) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - if !tc.success { - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - } - - if tc.pingInBetween { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - } - - p2 := randBytes(16) - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, tc.opcode, p1) - if err != nil { - return err - } - - if tc.pingInBetween { - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - } - - return assertReadFrame(ctx, c, websocket.OpContinuation, p2) - }) + compressionMode := func() websocket.CompressionMode { + return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) } - t.Run("unexpectedContinuation", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - fin bool - textFirst bool - }{ - { - name: "fin", - fin: true, - }, - { - name: "noFin", - fin: false, - }, - { - name: "echoFirst", - fin: false, - textFirst: true, - }, - // The rest of the tests in this section get complicated and do not inspire much confidence. - } - - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.textFirst { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - p1 := randBytes(32) - _, err = w.Write(p1) - if err != nil { - return err - } - p2 := randBytes(32) - _, err = w.Write(p2) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - } - - _, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32)) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + for i := 0; i < 5; i++ { + t.Run("", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ + CompressionMode: compressionMode(), + CompressionThreshold: xrand.Int(9999), + }, &websocket.AcceptOptions{ + CompressionMode: compressionMode(), + CompressionThreshold: xrand.Int(9999), }) - } + defer tt.cleanup() - run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32)) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) + tt.goEchoLoop(c2) - run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - p2 := randBytes(32) - p3 := randBytes(32) - p4 := randBytes(32) - p5 := randBytes(32) + c1.SetReadLimit(131072) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - time.Sleep(time.Second) - - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p3) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p4) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p5) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - }) - - // Section 7 - t.Run("closeHandling", func(t *testing.T) { - t.Parallel() - - // 1.1 - 1.4 is useless. - run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error { - // 262144 bytes. - p1 := randBytes(1 << 18) - err := c.Write(ctx, websocket.MessageText, p1) - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadMessage(ctx, c, websocket.MessageText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, nil) - if err != nil { - return err - } - return assertReadFrame(ctx, c, websocket.OpClose, []byte{}) - }) - - run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1}) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) + for i := 0; i < 5; i++ { + err := wstest.Echo(tt.ctx, c1, 131072) + assert.Success(t, err) + } - run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }) + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) + } + }) - run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(16)) - }) + t.Run("badClose", func(t *testing.T) { + tt, c1, _ := newConnTest(t, nil, nil) + defer tt.cleanup() - run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(123)) - }) + err := c1.Close(-1, "") + assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") + }) - run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, - append([]byte{0x03, 0xE8}, randString(124)...), - ) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) + t.Run("ping", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() - t.Run("validCloses", func(t *testing.T) { - t.Parallel() - - codes := [...]websocket.StatusCode{ - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, - 3000, - 3999, - 4000, - 4999, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - return c.Close(code, randString(32)) - }) - } - }) + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) - t.Run("invalidCloseCodes", func(t *testing.T) { - t.Parallel() - - codes := []websocket.StatusCode{ - 0, - 999, - 1004, - 1005, - 1006, - 1016, - 1100, - 2000, - 2999, - 5000, - 65535, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - p := make([]byte, 2) - binary.BigEndian.PutUint16(p, uint16(code)) - p = append(p, randBytes(32)...) - _, err := c.WriteFrame(ctx, true, websocket.OpClose, p) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) + for i := 0; i < 10; i++ { + err := c1.Ping(tt.ctx) + assert.Success(t, err) + } + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) }) - // Section 9. - t.Run("limits", func(t *testing.T) { - t.Parallel() + t.Run("badPing", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() - t.Run("unfragmentedEcho", func(t *testing.T) { - t.Parallel() + c2.CloseRead(tt.ctx) - lengths := []int{ - 1 << 16, - 1 << 18, - // Anything higher is completely unnecessary. - } + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - return assertEcho(ctx, c, websocket.MessageBinary, l) - }) - } - }) + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + }) - t.Run("fragmentedEcho", func(t *testing.T) { - t.Parallel() + t.Run("concurrentWrite", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() - fragments := []int{ - 64, - 256, - 1 << 10, - 1 << 12, - 1 << 14, - 1 << 16, - } + tt.goDiscardLoop(c2) - for _, l := range fragments { - fragmentLength := l - run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - b := randBytes(1 << 16) - for i := 0; i < len(b); { - j := i + fragmentLength - if j > len(b) { - j = len(b) - } - - _, err = w.Write(b[i:j]) - if err != nil { - return err - } - - i = j - } - err = w.Close() - if err != nil { - return err - } - - err = assertReadMessage(ctx, c, websocket.MessageText, b) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) + msg := xrand.Bytes(xrand.Int(9999)) + const count = 100 + errs := make(chan error, count) - t.Run("latencyEcho", func(t *testing.T) { - t.Parallel() + for i := 0; i < count; i++ { + go func() { + errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg) + }() + } - lengths := []int{ - 0, - 16, - } + for i := 0; i < count; i++ { + err := <-errs + assert.Success(t, err) + } - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 1000; i++ { - err := assertEcho(ctx, c, websocket.MessageBinary, l) - if err != nil { - return err - } - } - return nil - }) - } - }) + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) }) -} -func assertCloseStatus(err error, code websocket.StatusCode) error { - var cerr websocket.CloseError - if !errors.As(err, &cerr) { - return fmt.Errorf("no websocket close error in error chain: %+v", err) - } - return assert.Equalf(code, cerr.Code, "unexpected status code") -} + t.Run("concurrentWriteError", func(t *testing.T) { + tt, c1, _ := newConnTest(t, nil, nil) + defer tt.cleanup() -func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { - expType := reflect.TypeOf(exp) - actv := reflect.New(expType.Elem()) - act := actv.Interface().(proto.Message) - err := wspb.Read(ctx, c, act) - if err != nil { - return err - } + _, err := c1.Writer(tt.ctx, websocket.MessageText) + assert.Success(t, err) - return assert.Equalf(exp, act, "unexpected protobuf") -} + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() -func assertNetConnRead(r io.Reader, exp string) error { - act := make([]byte, len(exp)) - _, err := r.Read(act) - if err != nil { - return err - } - return assert.Equalf(exp, string(act), "unexpected net conn read") -} + err = c1.Write(ctx, websocket.MessageText, []byte("x")) + assert.Equal(t, "write error", context.DeadlineExceeded, err) + }) -func assertErrorContains(err error, exp string) error { - if err == nil || !strings.Contains(err.Error(), exp) { - return fmt.Errorf("expected error that contains %q but got: %+v", exp, err) - } - return nil -} + t.Run("netConn", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() -func assertErrorIs(exp, act error) error { - if !errors.Is(act, exp) { - return fmt.Errorf("expected error %+v to be in %+v", exp, act) - } - return nil -} + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) -func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.OpCode, p []byte) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", opcode) -} + // Does not give any confidence but at least ensures no crashes. + d, _ := tt.ctx.Deadline() + n1.SetDeadline(d) + n1.SetDeadline(time.Time{}) -func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - ce, err := websocket.ParseClosePayload(actP) - if err != nil { - return fmt.Errorf("failed to parse close frame payload: %w", err) - } - return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) -} + assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) + assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) + assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) -func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpPing, - PayloadLength: int64(l), - }) - if err != nil { - return err - } - for i := 0; i < l; i++ { - err = c.BW().WriteByte(0xFE) - if err != nil { - return fmt.Errorf("failed to write byte %d: %w", i, err) - } - if i%32 == 0 { - err = c.BW().Flush() + errs := xsync.Go(func() error { + _, err := n2.Write([]byte("hello")) if err != nil { - return fmt.Errorf("failed to flush at byte %d: %w", i, err) + return err } - } - } - err = c.BW().Flush() - if err != nil { - return fmt.Errorf("failed to flush: %v", err) - } - return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l)) -} + return n2.Close() + }) -func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error { - actTyp, actP, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp) -} + b, err := ioutil.ReadAll(n1) + assert.Success(t, err) -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } + _, err = n1.Read(nil) + assert.Equal(t, "read error", err, io.EOF) - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } - }) + err = <-errs + assert.Success(t, err) - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) - } + assert.Equal(t, "read msg", []byte("hello"), b) }) -} - -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - if echo { - wsecho.Loop(r.Context(), c) - } else { - discardLoop(r.Context(), c) - } - return nil - }, false) - defer closeFn() - wsURL := strings.Replace(s.URL, "http", "ws", 1) + t.Run("netConn/BadMsg", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) + errs := xsync.Go(func() error { + _, err := n2.Write([]byte("hello")) if err != nil { - b.Fatal(err) + return err } + return nil + }) - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } + _, err := ioutil.ReadAll(n1) + assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } + err = <-errs + assert.Success(t, err) + }) - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } + t.Run("wsjson", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() - _, err = io.ReadFull(r, readBuf) - if err != nil { - b.Fatal(err) - } - } - } - b.StopTimer() + tt.goEchoLoop(c2) - c.Close(websocket.StatusNormalClosure, "") -} + c1.SetReadLimit(1 << 30) -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") + exp := xrand.String(xrand.Int(131072)) - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() + werr := xsync.Go(func() error { + return wsjson.Write(tt.ctx, c1, exp) + }) - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } + var act interface{} + err := wsjson.Read(tt.ctx, c1, &act) + assert.Success(t, err) + assert.Equal(t, "read msg", exp, act) - _, err = io.CopyBuffer(ioutil.Discard, r, b) - if err != nil { - return err - } - return nil - } + err = <-werr + assert.Success(t, err) - for { - err := echo() - if err != nil { - return - } - } -} + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) -func TestAutobahnPython(t *testing.T) { - // This test contains the old autobahn test suite tests that use the - // python binary. The approach is clunky and slow so new tests - // have been written in pure Go in websocket_test.go. - // These have been kept for correctness purposes and are occasionally ran. - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to run tests against the python autobahn test suite") - } + t.Run("wspb", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() + + tt.goEchoLoop(c2) - t.Run("server", testServerAutobahnPython) - t.Run("client", testClientAutobahnPython) + exp := ptypes.DurationProto(100) + err := wspb.Write(tt.ctx, c1, exp) + assert.Success(t, err) + + act := &duration.Duration{} + err = wspb.Read(tt.ctx, c1, act) + assert.Success(t, err) + assert.Equal(t, "read msg", exp, act) + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) } -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func testServerAutobahnPython(t *testing.T) { +func TestWasm(t *testing.T) { t.Parallel() + var wg sync.WaitGroup s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, }) if err != nil { - t.Logf("server handshake failed: %+v", err) + t.Errorf("echo server failed: %v", err) return } - wsecho.Loop(r.Context(), c) - })) - defer s.Close() - - spec := map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - // 12.* and 13.* as we do not support compression. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() + defer c.Close(websocket.StatusInternalError, "") - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } + err = wstest.EchoLoop(r.Context(), c) - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } + err = assertCloseStatus(websocket.StatusNormalClosure, err) + if err != nil { + t.Errorf("echo server failed: %v", err) + return + } + })) + defer wg.Wait() + defer s.Close() - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} - wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wstest.URL(s))) + + b, err := cmd.CombinedOutput() if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return xerrors.Errorf("expected close status %v but got ", exp, err) } - l.Close() - return l.Addr().String(), nil + return nil } -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func testClientAutobahnPython(t *testing.T) { - t.Parallel() +type connTest struct { + t testing.TB + ctx context.Context - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") - } + doneFuncs []func() +} - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) +func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { + if t, ok := t.(*testing.T); ok { + t.Parallel() } + t.Helper() - wsServerURL := "ws://" + serverAddr + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + tt = &connTest{t: t, ctx: ctx} + tt.appendDone(cancel) - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) - } - defer specFile.Close() + c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) + assert.Success(tt.t, err) + tt.appendDone(func() { + c2.Close(websocket.StatusInternalError, "") + c1.Close(websocket.StatusInternalError, "") + }) - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } + return tt, c1, c2 +} - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) +func (tt *connTest) appendDone(f func()) { + tt.doneFuncs = append(tt.doneFuncs, f) +} + +func (tt *connTest) cleanup() { + for i := len(tt.doneFuncs) - 1; i >= 0; i-- { + tt.doneFuncs[i]() } +} - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() +func (tt *connTest) goEchoLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), - // Disables some server that runs as part of fuzzingserver mode. - // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 - "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) - err = wstest.Start() - if err != nil { - t.Fatal(err) - } - defer func() { - err := wstest.Process.Kill() + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c) + return assertCloseStatus(websocket.StatusNormalClosure, err) + }) + tt.appendDone(func() { + cancel() + err := <-echoLoopErr if err != nil { - t.Error(err) + tt.t.Errorf("echo loop error: %v", err) } - }() + }) +} - // Let it come up. - time.Sleep(time.Second * 5) +func (tt *connTest) goDiscardLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } + discardLoopErr := xsync.Go(func() error { defer c.Close(websocket.StatusInternalError, "") - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) + for { + _, _, err := c.Read(ctx) + if err != nil { + return assertCloseStatus(websocket.StatusNormalClosure, err) + } } - cases, err = strconv.Atoi(string(b)) + }) + tt.appendDone(func() { + cancel() + err := <-discardLoopErr if err != nil { - t.Fatal(err) + tt.t.Errorf("discard loop error: %v", err) } + }) +} - c.Close(websocket.StatusNormalClosure, "") - }() - - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - wsecho.Loop(ctx, c) - }() +func BenchmarkConn(b *testing.B) { + var benchCases = []struct { + name string + mode websocket.CompressionMode + }{ + { + name: "disabledCompress", + mode: websocket.CompressionDisabled, + }, + { + name: "compress", + mode: websocket.CompressionContextTakeover, + }, + { + name: "compressNoContext", + mode: websocket.CompressionNoContextTakeover, + }, } + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ + CompressionMode: bc.mode, + }, &websocket.AcceptOptions{ + CompressionMode: bc.mode, + }) + defer bb.cleanup() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") + bb.goEchoLoop(c2) - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") -} + bytesWritten := c1.RecordBytesWritten() + bytesRead := c1.RecordBytesRead() -func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } + msg := []byte(strings.Repeat("1234", 128)) + readBuf := make([]byte, len(msg)) + writes := make(chan struct{}) + defer close(writes) + werrs := make(chan error) - var indexJSON map[string]map[string]struct { - Behavior string `json:"behavior"` - BehaviorClose string `json:"behaviorClose"` - } - err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + go func() { + for range writes { + werrs <- c1.Write(bb.ctx, websocket.MessageText, msg) + } + }() + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + writes <- struct{}{} - var failed bool - for _, tests := range indexJSON { - for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } - } - } + typ, r, err := c1.Reader(bb.ctx) + if err != nil { + b.Fatal(err) + } + if websocket.MessageText != typ { + assert.Equal(b, "data type", websocket.MessageText, typ) + } - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } - } -} + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } -func TestWASM(t *testing.T) { - t.Parallel() + n2, err := r.Read(readBuf) + if err != io.EOF { + assert.Equal(b, "read err", io.EOF, err) + } + if n2 != 0 { + assert.Equal(b, "n2", 0, n2) + } - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") + if !bytes.Equal(msg, readBuf) { + assert.Equal(b, "msg", msg, readBuf) + } - err = wsecho.Loop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - return err - } - return nil - }, false) - defer closeFn() + err = <-werrs + if err != nil { + b.Fatal(err) + } + } + b.StopTimer() - wsURL := strings.Replace(s.URL, "http", "ws", 1) + b.ReportMetric(float64(*bytesWritten/b.N), "written/op") + b.ReportMetric(float64(*bytesRead/b.N), "read/op") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) - defer cancel() + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(b, err) + }) + } +} - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) +func TestCompression(t *testing.T) { + t.Parallel() - b, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("wasm test binary failed: %v:\n%s", err, b) - } } diff --git a/dial.go b/dial.go new file mode 100644 index 00000000..09546ac6 --- /dev/null +++ b/dial.go @@ -0,0 +1,270 @@ +// +build !js + +package websocket + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// DialOptions represents Dial's options. +type DialOptions struct { + // HTTPClient is used for the connection. + // Its Transport must return writable bodies for WebSocket handshakes. + // http.Transport does beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. + Subprotocols []string + + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int +} + +// Dial performs a WebSocket handshake on url. +// +// The response is the WebSocket handshake response from the server. +// You never need to close resp.Body yourself. +// +// If an error occurs, the returned response may be non nil. +// However, you can only read the first 1024 bytes of the body. +// +// This function requires at least Go 1.12 as it uses a new feature +// in net/http to perform WebSocket handshakes. +// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + return dial(ctx, u, opts, nil) +} + +func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { + defer errd.Wrap(&err, "failed to WebSocket dial") + + if opts == nil { + opts = &DialOptions{} + } + + opts = &*opts + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + secWebSocketKey, err := secWebSocketKey(rand) + if err != nil { + return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + + resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) + if err != nil { + return nil, resp, err + } + respBody := resp.Body + resp.Body = nil + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(respBody, 1024) + b, _ := ioutil.ReadAll(r) + respBody.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + copts, err := verifyServerResponse(opts, secWebSocketKey, resp) + if err != nil { + return nil, resp, err + } + + rwc, ok := respBody.(io.ReadWriteCloser) + if !ok { + return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) + } + + return newConn(connConfig{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + rwc: rwc, + client: true, + copts: copts, + flateThreshold: opts.CompressionThreshold, + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + }), resp, nil +} + +func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { + if opts.HTTPClient.Timeout > 0 { + return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + + u, err := url.Parse(urls) + if err != nil { + return nil, xerrors.Errorf("failed to parse url: %w", err) + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme) + } + + req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req.Header = opts.HTTPHeader.Clone() + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + if opts.CompressionMode != CompressionDisabled { + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + } + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("failed to send handshake request: %w", err) + } + return resp, nil +} + +func secWebSocketKey(rr io.Reader) (string, error) { + if rr == nil { + rr = rand.Reader + } + b := make([]byte, 16) + _, err := io.ReadFull(rr, b) + if err != nil { + return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + } + + if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + } + + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { + return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + secWebSocketKey, + ) + } + + err := verifySubprotocol(opts.Subprotocols, resp) + if err != nil { + return nil, err + } + + return verifyServerExtensions(resp.Header) +} + +func verifySubprotocol(subprotos []string, resp *http.Response) error { + proto := resp.Header.Get("Sec-WebSocket-Protocol") + if proto == "" { + return nil + } + + for _, sp2 := range subprotos { + if strings.EqualFold(sp2, proto) { + return nil + } + } + + return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) +} + +func verifyServerExtensions(h http.Header) (*compressionOptions, error) { + exts := websocketExtensions(h) + if len(exts) == 0 { + return nil, nil + } + + ext := exts[0] + if ext.name != "permessage-deflate" || len(exts) > 1 { + return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) + } + + copts := &compressionOptions{} + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + default: + return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) + } + } + + return copts, nil +} + +var readerPool sync.Pool + +func getBufioReader(r io.Reader) *bufio.Reader { + br, ok := readerPool.Get().(*bufio.Reader) + if !ok { + return bufio.NewReader(r) + } + br.Reset(r) + return br +} + +func putBufioReader(br *bufio.Reader) { + readerPool.Put(br) +} + +var writerPool sync.Pool + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw, ok := writerPool.Get().(*bufio.Writer) + if !ok { + return bufio.NewWriter(w) + } + bw.Reset(w) + return bw +} + +func putBufioWriter(bw *bufio.Writer) { + writerPool.Put(bw) +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 00000000..06084cc5 --- /dev/null +++ b/dial_test.go @@ -0,0 +1,244 @@ +// +build !js + +package websocket + +import ( + "context" + "crypto/rand" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "nhooyr.io/websocket/internal/test/assert" +) + +func TestBadDials(t *testing.T) { + t.Parallel() + + t.Run("badReq", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + rand readerFunc + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + { + name: "badReader", + rand: func(p []byte) (int, error) { + return 0, io.EOF + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + if tc.rand == nil { + tc.rand = rand.Reader.Read + } + + _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + assert.Error(t, err) + }) + } + }) + + t.Run("badResponse", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + }), + }) + assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") + }) + + t.Run("badBody", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + } + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(rt), + }) + assert.Contains(t, err, "response body is not a io.ReadWriteCloser") + }) +} + +func Test_verifyServerHandshake(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + response func(w http.ResponseWriter) + success bool + }{ + { + name: "badStatus", + response: func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + }, + success: false, + }, + { + name: "badConnection", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badUpgrade", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketAccept", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Accept", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketProtocol", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Protocol", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "unsupportedExtension", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "unsupportedDeflateParam", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "success", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + tc.response(w) + resp := w.Result() + + r := httptest.NewRequest("GET", "/", nil) + key, err := secWebSocketKey(rand.Reader) + assert.Success(t, err) + r.Header.Set("Sec-WebSocket-Key", key) + + if resp.Header.Get("Sec-WebSocket-Accept") == "" { + resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + } + + opts := &DialOptions{ + Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), + } + _, err = verifyServerResponse(opts, key, resp) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + +func mockHTTPClient(fn roundTripperFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} diff --git a/doc.go b/doc.go index b29d2cdd..efa920e3 100644 --- a/doc.go +++ b/doc.go @@ -1,53 +1,32 @@ // +build !js -// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol. +// Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // -// Conn, Dial, and Accept are the main entrypoints into this package. Use Dial to dial -// a WebSocket server, Accept to accept a WebSocket client dial and then Conn to interact -// with the resulting WebSocket connections. +// Use Dial to dial a WebSocket server. // -// The examples are the best way to understand how to correctly use the library. +// Use Accept to accept a WebSocket client. +// +// Conn represents the resulting WebSocket connection. // -// The wsjson and wspb subpackages contain helpers for JSON and ProtoBuf messages. +// The examples are the best way to understand how to correctly use the library. // -// See https://nhooyr.io/websocket for more overview docs and a -// comparison with existing implementations. +// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. // -// Use the errors.As function new in Go 1.13 to check for websocket.CloseError. -// Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError -// See the CloseStatus example. +// More documentation at https://nhooyr.io/websocket. // // Wasm // -// The client side fully supports compiling to Wasm. +// The client side supports compiling to Wasm. // It wraps the WebSocket browser API. // // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // -// Thus the unsupported features (not compiled in) for Wasm are: -// -// - Accept and AcceptOptions -// - Conn.Ping -// - HTTPClient and HTTPHeader fields in DialOptions -// -// The *http.Response returned by Dial will always either be nil or &http.Response{} as -// we do not have access to the handshake response in the browser. -// -// The Writer method on the Conn buffers everything in memory and then sends it as a message -// when the writer is closed. -// -// The Reader method also reads the entire response and then returns a reader that -// reads from the byte slice. -// -// SetReadLimit cannot actually limit the number of bytes read from the connection so instead -// when a message beyond the limit is fully read, it throws an error. -// -// Writes are also always async so the passed context is no-op. -// -// Everything else is fully supported. This includes the wsjson and wspb helper packages. +// Some important caveats to be aware of: // -// Once https://github.com/gopherjs/gopherjs/issues/929 is closed, GopherJS should be supported -// as well. +// - Accept always errors out +// - Conn.Ping is no-op +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "nhooyr.io/websocket" diff --git a/example_echo_test.go b/example_echo_test.go index ecc9b97c..1daec8a5 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -12,6 +12,7 @@ import ( "time" "golang.org/x/time/rate" + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -77,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") - return fmt.Errorf("client does not speak echo sub protocol") + return xerrors.New("client does not speak echo sub protocol") } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) @@ -87,12 +88,12 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { return nil } if err != nil { - return fmt.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) + return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) } } } -// echo reads from the websocket connection and then writes +// echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { @@ -116,7 +117,7 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { _, err = io.Copy(w, r) if err != nil { - return fmt.Errorf("failed to io.Copy: %w", err) + return xerrors.Errorf("failed to io.Copy: %w", err) } err = w.Close() diff --git a/example_test.go b/example_test.go index bc603aff..075107b0 100644 --- a/example_test.go +++ b/example_test.go @@ -33,8 +33,6 @@ func ExampleAccept() { return } - log.Printf("received: %v", v) - c.Close(websocket.StatusNormalClosure, "") }) @@ -76,8 +74,7 @@ func ExampleCloseStatus() { _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %+v", err) - return + log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) } } diff --git a/export_test.go b/export_test.go new file mode 100644 index 00000000..88b82c9f --- /dev/null +++ b/export_test.go @@ -0,0 +1,22 @@ +// +build !js + +package websocket + +func (c *Conn) RecordBytesWritten() *int { + var bytesWritten int + c.bw.Reset(writerFunc(func(p []byte) (int, error) { + bytesWritten += len(p) + return c.rwc.Write(p) + })) + return &bytesWritten +} + +func (c *Conn) RecordBytesRead() *int { + var bytesRead int + c.br.Reset(readerFunc(func(p []byte) (int, error) { + n, err := c.rwc.Read(p) + bytesRead += n + return n, err + })) + return &bytesRead +} diff --git a/frame.go b/frame.go index e4bf931a..4acaecf4 100644 --- a/frame.go +++ b/frame.go @@ -1,20 +1,21 @@ package websocket import ( + "bufio" "encoding/binary" - "errors" - "fmt" "io" "math" "math/bits" -) -//go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) -// opcode represents a WebSocket Opcode. +// opcode represents a WebSocket opcode. type opcode int -// opcode constants. +// https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText @@ -31,35 +32,8 @@ const ( // 11-16 are reserved for further control frames. ) -func (o opcode) controlOp() bool { - switch o { - case opClose, opPing, opPong: - return true - } - return false -} - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. - MessageBinary -) - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - // header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 +// See https://tools.ietf.org/html/rfc6455#section-5.2. type header struct { fin bool rsv1 bool @@ -73,256 +47,132 @@ type header struct { maskKey uint32 } -func makeWriteHeaderBuf() []byte { - return make([]byte, maxHeaderSize) -} - -// bytes returns the bytes of the header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeHeader(b []byte, h header) []byte { - if b == nil { - b = makeWriteHeaderBuf() - } - - b = b[:2] - b[0] = 0 - - if h.fin { - b[0] |= 1 << 7 - } - if h.rsv1 { - b[0] |= 1 << 6 - } - if h.rsv2 { - b[0] |= 1 << 5 - } - if h.rsv3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.opcode) +// readFrameHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { + defer errd.Wrap(&err, "failed to read frame header") - switch { - case h.payloadLength < 0: - panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) - case h.payloadLength <= 125: - b[1] = byte(h.payloadLength) - case h.payloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) - } - - if h.masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) + b, err := r.ReadByte() + if err != nil { + return header{}, err } - return b -} + h.fin = b&(1<<7) != 0 + h.rsv1 = b&(1<<6) != 0 + h.rsv2 = b&(1<<5) != 0 + h.rsv3 = b&(1<<4) != 0 -func makeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// readHeader reads a header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(b []byte, r io.Reader) (header, error) { - if b == nil { - b = makeReadHeaderBuf() - } + h.opcode = opcode(b & 0xf) - // We read the first two bytes first so that we know - // exactly how long the header is. - b = b[:2] - _, err := io.ReadFull(r, b) + b, err = r.ReadByte() if err != nil { return header{}, err } - var h header - h.fin = b[0]&(1<<7) != 0 - h.rsv1 = b[0]&(1<<6) != 0 - h.rsv2 = b[0]&(1<<5) != 0 - h.rsv3 = b[0]&(1<<4) != 0 - - h.opcode = opcode(b[0] & 0xf) - - var extra int + h.masked = b&(1<<7) != 0 - h.masked = b[1]&(1<<7) != 0 - if h.masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) + payloadLength := b &^ (1 << 7) switch { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: - extra += 2 + _, err = io.ReadFull(r, readBuf[:2]) + h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil + _, err = io.ReadFull(r, readBuf) + h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } - - b = b[:extra] - _, err = io.ReadFull(r, b) if err != nil { return header{}, err } - switch { - case payloadLength == 126: - h.payloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.payloadLength = int64(binary.BigEndian.Uint64(b)) - if h.payloadLength < 0 { - return header{}, fmt.Errorf("header with negative payload length: %v", h.payloadLength) - } - b = b[8:] + if h.payloadLength < 0 { + return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { - h.maskKey = binary.LittleEndian.Uint32(b) + _, err = io.ReadFull(r, readBuf[:4]) + if err != nil { + return header{}, err + } + h.maskKey = binary.LittleEndian.Uint32(readBuf) } return h, nil } -// StatusCode represents a WebSocket status code. -// https://tools.ietf.org/html/rfc6455#section-7.4 -type StatusCode int +// maxControlPayload is the maximum length of a control frame payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5. +const maxControlPayload = 125 -// These codes were retrieved from: -// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// -// The defined constants only represent the status codes registered with IANA. -// The 4000-4999 range of status codes is reserved for arbitrary use by applications. -const ( - StatusNormalClosure StatusCode = 1000 - StatusGoingAway StatusCode = 1001 - StatusProtocolError StatusCode = 1002 - StatusUnsupportedData StatusCode = 1003 - - // 1004 is reserved and so not exported. - statusReserved StatusCode = 1004 - - // StatusNoStatusRcvd cannot be sent in a close message. - // It is reserved for when a close message is received without - // an explicit status. - StatusNoStatusRcvd StatusCode = 1005 - - // StatusAbnormalClosure is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. - StatusAbnormalClosure StatusCode = 1006 - - StatusInvalidFramePayloadData StatusCode = 1007 - StatusPolicyViolation StatusCode = 1008 - StatusMessageTooBig StatusCode = 1009 - StatusMandatoryExtension StatusCode = 1010 - StatusInternalError StatusCode = 1011 - StatusServiceRestart StatusCode = 1012 - StatusTryAgainLater StatusCode = 1013 - StatusBadGateway StatusCode = 1014 - - // StatusTLSHandshake is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. - StatusTLSHandshake StatusCode = 1015 -) - -// CloseError represents a WebSocket close frame. -// It is returned by Conn's methods when a WebSocket close frame is received from -// the peer. -// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, -// to check for this error. See the CloseError example. -type CloseError struct { - Code StatusCode - Reason string -} - -func (ce CloseError) Error() string { - return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) -} +// writeFrameHeader writes the bytes of the header to w. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { + defer errd.Wrap(&err, "failed to write frame header") -// CloseStatus is a convenience wrapper around errors.As to grab -// the status code from a *CloseError. If the passed error is nil -// or not a *CloseError, the returned StatusCode will be -1. -func CloseStatus(err error) StatusCode { - var ce CloseError - if errors.As(err, &ce) { - return ce.Code + var b byte + if h.fin { + b |= 1 << 7 } - return -1 -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil + if h.rsv1 { + b |= 1 << 6 } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + if h.rsv2 { + b |= 1 << 5 } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), + if h.rsv3 { + b |= 1 << 4 } - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } + b |= byte(h.opcode) - return ce, nil -} + err = w.WriteByte(b) + if err != nil { + return err + } -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false + lengthByte := byte(0) + if h.masked { + lengthByte |= 1 << 7 } - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true + switch { + case h.payloadLength > math.MaxUint16: + lengthByte |= 127 + case h.payloadLength > 125: + lengthByte |= 126 + case h.payloadLength >= 0: + lengthByte |= byte(h.payloadLength) } - if code >= 3000 && code <= 4999 { - return true + err = w.WriteByte(lengthByte) + if err != nil { + return err } - return false -} - -const maxControlFramePayload = 125 - -func (ce CloseError) bytes() ([]byte, error) { - if len(ce.Reason) > maxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) + switch { + case h.payloadLength > math.MaxUint16: + binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) + _, err = w.Write(buf) + case h.payloadLength > 125: + binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) + _, err = w.Write(buf[:2]) } - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + if err != nil { + return err + } + + if h.masked { + binary.LittleEndian.PutUint32(buf, h.maskKey) + _, err = w.Write(buf[:4]) + if err != nil { + return err + } } - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil + return nil } -// fastMask applies the WebSocket masking algorithm to p +// mask applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // diff --git a/frame_test.go b/frame_test.go index 571e68fc..76826248 100644 --- a/frame_test.go +++ b/frame_test.go @@ -3,98 +3,25 @@ package websocket import ( + "bufio" "bytes" "encoding/binary" - "io" - "math" "math/bits" "math/rand" "strconv" - "strings" "testing" "time" _ "unsafe" "github.com/gobwas/ws" - "github.com/google/go-cmp/cmp" _ "github.com/gorilla/websocket" - "nhooyr.io/websocket/internal/assert" + "nhooyr.io/websocket/internal/test/assert" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - func TestHeader(t *testing.T) { t.Parallel() - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := readHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - writeHeader(nil, header{ - payloadLength: -1, - }) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := writeHeader(nil, header{ - payloadLength: 1<<16 + 1, - }) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := readHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - t.Run("lengths", func(t *testing.T) { t.Parallel() @@ -102,12 +29,12 @@ func TestHeader(t *testing.T) { 124, 125, 126, - 4096, - 16384, + 127, + + 65534, 65535, 65536, 65537, - 131072, } for _, n := range lengths { @@ -125,20 +52,22 @@ func TestHeader(t *testing.T) { t.Run("fuzz", func(t *testing.T) { t.Parallel() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + randBool := func() bool { + return r.Intn(1) == 0 + } + for i := 0; i < 10000; i++ { h := header{ fin: randBool(), rsv1: randBool(), rsv2: randBool(), rsv3: randBool(), - opcode: opcode(rand.Intn(1 << 4)), + opcode: opcode(r.Intn(16)), masked: randBool(), - payloadLength: rand.Int63(), - } - - if h.masked { - h.maskKey = rand.Uint32() + maskKey: r.Uint32(), + payloadLength: r.Int63(), } testHeader(t, h) @@ -147,168 +76,20 @@ func TestHeader(t *testing.T) { } func testHeader(t *testing.T, h header) { - b := writeHeader(nil, h) - r := bytes.NewReader(b) - h2, err := readHeader(nil, r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) - } + b := &bytes.Buffer{} + w := bufio.NewWriter(b) + r := bufio.NewReader(b) - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) - } -} + err := writeFrameHeader(h, w, make([]byte, 8)) + assert.Success(t, err) -func TestCloseError(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - ce CloseError - success bool - }{ - { - name: "normal", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: true, - }, - { - name: "bigReason", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-1), - }, - success: false, - }, - { - name: "bigCode", - ce: CloseError{ - Code: math.MaxUint16, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: false, - }, - } + err = w.Flush() + assert.Success(t, err) - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + h2, err := readFrameHeader(r, make([]byte, 8)) + assert.Success(t, err) - _, err := tc.ce.bytes() - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) - } - }) - } -} - -func Test_parseClosePayload(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - p []byte - success bool - ce CloseError - }{ - { - name: "normal", - p: append([]byte{0x3, 0xE8}, []byte("hello")...), - success: true, - ce: CloseError{ - Code: StatusNormalClosure, - Reason: "hello", - }, - }, - { - name: "nothing", - success: true, - ce: CloseError{ - Code: StatusNoStatusRcvd, - }, - }, - { - name: "oneByte", - p: []byte{0}, - success: false, - }, - { - name: "badStatusCode", - p: []byte{0x17, 0x70}, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ce, err := parseClosePayload(tc.p) - if (err == nil) != tc.success { - t.Fatalf("unexpected expected error value: %+v", err) - } - - if tc.success && tc.ce != ce { - t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) - } - }) - } -} - -func Test_validWireCloseCode(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - code StatusCode - valid bool - }{ - { - name: "normal", - code: StatusNormalClosure, - valid: true, - }, - { - name: "noStatus", - code: StatusNoStatusRcvd, - valid: false, - }, - { - name: "3000", - code: 3000, - valid: true, - }, - { - name: "4999", - code: 4999, - valid: true, - }, - { - name: "unknown", - code: 5000, - valid: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - if valid := validWireCloseCode(tc.code); tc.valid != valid { - t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) - } - }) - } + assert.Equal(t, "read header", h, h2) } func Test_mask(t *testing.T) { @@ -319,13 +100,11 @@ func Test_mask(t *testing.T) { p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } + expP := []byte{0, 0, 0, 0x0d, 0x6} + assert.Equal(t, "p", expP, p) - if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { - t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) - } + expKey32 := bits.RotateLeft32(key32, -8) + assert.Equal(t, "key32", expKey32, gotKey32) } func basicMask(maskKey [4]byte, pos int, b []byte) int { @@ -395,11 +174,7 @@ func Benchmark_mask(b *testing.B) { }, } - var key [4]byte - _, err := rand.Read(key[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } + key := [4]byte{1, 2, 3, 4} for _, size := range sizes { p := make([]byte, size) @@ -415,43 +190,3 @@ func Benchmark_mask(b *testing.B) { }) } } - -func TestCloseStatus(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - in error - exp StatusCode - }{ - { - name: "nil", - in: nil, - exp: -1, - }, - { - name: "io.EOF", - in: io.EOF, - exp: -1, - }, - { - name: "StatusInternalError", - in: CloseError{ - Code: StatusInternalError, - }, - exp: StatusInternalError, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status") - if err != nil { - t.Fatal(err) - } - }) - } -} diff --git a/go.mod b/go.mod index e6ef0014..a10c7b1e 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,15 @@ module nhooyr.io/websocket -go 1.13 +go 1.12 require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.2 - github.com/google/go-cmp v0.3.1 + github.com/golang/protobuf v1.3.3 + github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 - github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.4.0 // indirect - go.uber.org/atomic v1.4.0 // indirect - go.uber.org/multierr v1.1.0 - golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + github.com/klauspost/compress v1.10.0 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index d2f1f0e4..e4bbd62d 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,18 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= +github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 11e46d8f..00000000 --- a/handshake.go +++ /dev/null @@ -1,422 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/textproto" - "net/url" - "strings" - "sync" -) - -// AcceptOptions represents the options available to pass to Accept. -type AcceptOptions struct { - // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. - // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to - // reject it, close the connection if c.Subprotocol() == "". - Subprotocols []string - - // InsecureSkipVerify disables Accept's origin verification - // behaviour. By default Accept only allows the handshake to - // succeed if the javascript that is initiating the handshake - // is on the same domain as the server. This is to prevent CSRF - // attacks when secure data is stored in a cookie as there is no same - // origin policy for WebSockets. In other words, javascript from - // any domain can perform a WebSocket dial on an arbitrary server. - // This dial will include cookies which means the arbitrary javascript - // can perform actions as the authenticated user. - // - // See https://stackoverflow.com/a/37837709/4283659 - // - // The only time you need this is if your javascript is running on a different domain - // than your WebSocket server. - // Think carefully about whether you really need this option before you use it. - // If you do, remember that if you store secure data in cookies, you wil need to verify the - // Origin header yourself otherwise you are exposing yourself to a CSRF attack. - InsecureSkipVerify bool -} - -func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { - if !r.ProtoAtLeast(1, 1) { - err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) - http.Error(w, err.Error(), http.StatusUpgradeRequired) - return err - } - - if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) - http.Error(w, err.Error(), http.StatusUpgradeRequired) - return err - } - - if r.Method != "GET" { - err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Version") != "13" { - w.Header().Set("Sec-WebSocket-Version", "13") - err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Key") == "" { - err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - return nil -} - -// Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to a WebSocket. -// -// Accept will reject the handshake if the Origin domain is not the same as the Host unless -// the InsecureSkipVerify option is set. In other words, by default it does not allow -// cross origin requests. -// -// If an error occurs, Accept will always write an appropriate response so you do not -// have to. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - c, err := accept(w, r, opts) - if err != nil { - return nil, fmt.Errorf("failed to accept websocket connection: %w", err) - } - return c, nil -} - -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - if opts == nil { - opts = &AcceptOptions{} - } - - err := verifyClientRequest(w, r) - if err != nil { - return nil, err - } - - if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) - if err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return nil, err - } - } - - hj, ok := w.(http.Hijacker) - if !ok { - err = errors.New("passed ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) - return nil, err - } - - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - - handleSecWebSocketKey(w, r) - - subproto := selectSubprotocol(r, opts.Subprotocols) - if subproto != "" { - w.Header().Set("Sec-WebSocket-Protocol", subproto) - } - - w.WriteHeader(http.StatusSwitchingProtocols) - - netConn, brw, err := hj.Hijack() - if err != nil { - err = fmt.Errorf("failed to hijack connection: %w", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return nil, err - } - - // https://github.com/golang/go/issues/32314 - b, _ := brw.Reader.Peek(brw.Reader.Buffered()) - brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - - c := &Conn{ - subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), - br: brw.Reader, - bw: brw.Writer, - closer: netConn, - } - c.init() - - return c, nil -} - -func headerValuesContainsToken(h http.Header, key, token string) bool { - key = textproto.CanonicalMIMEHeaderKey(key) - - for _, val2 := range h[key] { - if headerValueContainsToken(val2, token) { - return true - } - } - - return false -} - -func headerValueContainsToken(val2, token string) bool { - val2 = strings.TrimSpace(val2) - - for _, val2 := range strings.Split(val2, ",") { - val2 = strings.TrimSpace(val2) - if strings.EqualFold(val2, token) { - return true - } - } - - return false -} - -func selectSubprotocol(r *http.Request, subprotocols []string) string { - for _, sp := range subprotocols { - if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { - return sp - } - } - return "" -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { - key := r.Header.Get("Sec-WebSocket-Key") - w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) -} - -func secWebSocketAccept(secWebSocketKey string) string { - h := sha1.New() - h.Write([]byte(secWebSocketKey)) - h.Write(keyGUID) - - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func authenticateOrigin(r *http.Request) error { - origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) - } - return nil -} - -// DialOptions represents the options available to pass to Dial. -type DialOptions struct { - // HTTPClient is the http client used for the handshake. - // Its Transport must return writable bodies - // for WebSocket handshakes. - // http.Transport does this correctly beginning with Go 1.12. - HTTPClient *http.Client - - // HTTPHeader specifies the HTTP headers included in the handshake request. - HTTPHeader http.Header - - // Subprotocols lists the subprotocols to negotiate with the server. - Subprotocols []string -} - -// Dial performs a WebSocket handshake on the given url with the given options. -// The response is the WebSocket handshake response from the server. -// If an error occurs, the returned response may be non nil. However, you can only -// read the first 1024 bytes of its body. -// -// You never need to close the resp.Body yourself. -// -// This function requires at least Go 1.12 to succeed as it uses a new feature -// in net/http to perform WebSocket handshakes and get a writable body -// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - c, r, err := dial(ctx, u, opts) - if err != nil { - return nil, r, fmt.Errorf("failed to websocket dial: %w", err) - } - return c, r, nil -} - -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - if opts == nil { - opts = &DialOptions{} - } - - // Shallow copy to ensure defaults do not affect user passed options. - opts2 := *opts - opts = &opts2 - - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } - if opts.HTTPClient.Timeout > 0 { - return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - - parsedURL, err := url.Parse(u) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse url: %w", err) - } - - switch parsedURL.Scheme { - case "ws": - parsedURL.Scheme = "http" - case "wss": - parsedURL.Scheme = "https" - default: - return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) - } - - req, _ := http.NewRequest("GET", parsedURL.String(), nil) - req = req.WithContext(ctx) - req.Header = opts.HTTPHeader - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - secWebSocketKey, err := makeSecWebSocketKey() - if err != nil { - return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) - } - req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) - if len(opts.Subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) - } - - resp, err := opts.HTTPClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) - } - defer func() { - if err != nil { - // We read a bit of the body for easier debugging. - r := io.LimitReader(resp.Body, 1024) - b, _ := ioutil.ReadAll(r) - resp.Body.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - }() - - err = verifyServerResponse(req, resp) - if err != nil { - return nil, resp, err - } - - rwc, ok := resp.Body.(io.ReadWriteCloser) - if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", resp.Body) - } - - c := &Conn{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), - closer: rwc, - client: true, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil -} - -func verifyServerResponse(r *http.Request, resp *http.Response) error { - if resp.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) - } - - if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { - return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) - } - - if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { - return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) - } - - if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", - resp.Header.Get("Sec-WebSocket-Accept"), - r.Header.Get("Sec-WebSocket-Key"), - ) - } - - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) - } - - return nil -} - -// The below pools can only be used by the client because http.Hijacker will always -// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. - -var bufioReaderPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func getBufioReader(r io.Reader) *bufio.Reader { - br := bufioReaderPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func returnBufioReader(br *bufio.Reader) { - bufioReaderPool.Put(br) -} - -var bufioWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func getBufioWriter(w io.Writer) *bufio.Writer { - bw := bufioWriterPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func returnBufioWriter(bw *bufio.Writer) { - bufioWriterPool.Put(bw) -} - -func makeSecWebSocketKey() (string, error) { - b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) - if err != nil { - return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) - } - return base64.StdEncoding.EncodeToString(b), nil -} diff --git a/internal/assert/assert.go b/internal/assert/assert.go deleted file mode 100644 index e57abfd9..00000000 --- a/internal/assert/assert.go +++ /dev/null @@ -1,63 +0,0 @@ -package assert - -import ( - "fmt" - "reflect" - - "github.com/google/go-cmp/cmp" -) - -// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 -func cmpDiff(exp, act interface{}) string { - return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) -} - -func deepAllowUnexported(vs ...interface{}) cmp.Option { - m := make(map[reflect.Type]struct{}) - for _, v := range vs { - structTypes(reflect.ValueOf(v), m) - } - var typs []interface{} - for t := range m { - typs = append(typs, reflect.New(t).Elem().Interface()) - } - return cmp.AllowUnexported(typs...) -} - -func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { - if !v.IsValid() { - return - } - switch v.Kind() { - case reflect.Ptr: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Interface: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Slice, reflect.Array: - for i := 0; i < v.Len(); i++ { - structTypes(v.Index(i), m) - } - case reflect.Map: - for _, k := range v.MapKeys() { - structTypes(v.MapIndex(k), m) - } - case reflect.Struct: - m[v.Type()] = struct{}{} - for i := 0; i < v.NumField(); i++ { - structTypes(v.Field(i), m) - } - } -} - -// Equalf compares exp to act and if they are not equal, returns -// an error describing an error. -func Equalf(exp, act interface{}, f string, v ...interface{}) error { - if diff := cmpDiff(exp, act); diff != "" { - return fmt.Errorf(f+": %v", append(v, diff)...) - } - return nil -} diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index 4266c236..aa826fba 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -10,11 +10,11 @@ var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b, ok := bpool.Get().(*bytes.Buffer) - if !ok { - b = &bytes.Buffer{} + b := bpool.Get() + if b == nil { + return &bytes.Buffer{} } - return b + return b.(*bytes.Buffer) } // Put returns a buffer into the pool. diff --git a/internal/bpool/bpool_test.go b/internal/bpool/bpool_test.go deleted file mode 100644 index 5dfe56e6..00000000 --- a/internal/bpool/bpool_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package bpool - -import ( - "strconv" - "sync" - "testing" -) - -func BenchmarkSyncPool(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 64, - 128, - 256, - 512, - 4096, - 16384, - } - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("allocate", func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - buf := make([]byte, size) - _ = buf - } - }) - b.Run("pool", func(b *testing.B) { - b.ReportAllocs() - - p := sync.Pool{} - - for i := 0; i < b.N; i++ { - buf := p.Get() - if buf == nil { - buf = make([]byte, size) - } - - p.Put(buf) - } - }) - }) - } -} diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go new file mode 100644 index 00000000..ed0b7754 --- /dev/null +++ b/internal/errd/wrap.go @@ -0,0 +1,42 @@ +package errd + +import ( + "fmt" + + "golang.org/x/xerrors" +) + +type wrapError struct { + msg string + err error + frame xerrors.Frame +} + +func (e *wrapError) Error() string { + return fmt.Sprint(e) +} + +func (e *wrapError) Format(s fmt.State, v rune) { xerrors.FormatError(e, s, v) } + +func (e *wrapError) FormatError(p xerrors.Printer) (next error) { + p.Print(e.msg) + e.frame.Format(p) + return e.err +} + +func (e *wrapError) Unwrap() error { + return e.err +} + +// Wrap wraps err with xerrors.Errorf if err is non nil. +// Intended for use with defer and a named error return. +// Inspired by https://github.com/golang/go/issues/32676. +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = &wrapError{ + msg: fmt.Sprintf(f, v...), + err: *err, + frame: xerrors.Caller(1), + } + } +} diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go new file mode 100644 index 00000000..602b887e --- /dev/null +++ b/internal/test/assert/assert.go @@ -0,0 +1,46 @@ +package assert + +import ( + "fmt" + "strings" + "testing" + + "nhooyr.io/websocket/internal/test/cmp" +) + +// Equal asserts exp == act. +func Equal(t testing.TB, name string, exp, act interface{}) { + t.Helper() + + if diff := cmp.Diff(exp, act); diff != "" { + t.Fatalf("unexpected %v: %v", name, diff) + } +} + +// Success asserts err == nil. +func Success(t testing.TB, err error) { + t.Helper() + + if err != nil { + t.Fatal(err) + } +} + +// Error asserts err != nil. +func Error(t testing.TB, err error) { + t.Helper() + + if err == nil { + t.Fatal("expected error") + } +} + +// Contains asserts the fmt.Sprint(v) contains sub. +func Contains(t testing.TB, v interface{}, sub string) { + t.Helper() + + s := fmt.Sprint(v) + if !strings.Contains(s, sub) { + t.Fatalf("expected %q to contain %q", s, sub) + } +} diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go new file mode 100644 index 00000000..eadcb5d9 --- /dev/null +++ b/internal/test/cmp/cmp.go @@ -0,0 +1,16 @@ +package cmp + +import ( + "reflect" + + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +// Diff returns a human readable diff between v1 and v2 +func Diff(v1, v2 interface{}) string { + return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + }), cmp.Comparer(proto.Equal)) +} diff --git a/internal/test/doc.go b/internal/test/doc.go new file mode 100644 index 00000000..94b2e82d --- /dev/null +++ b/internal/test/doc.go @@ -0,0 +1,2 @@ +// Package test contains subpackages only used in tests. +package test diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go new file mode 100644 index 00000000..714767fc --- /dev/null +++ b/internal/test/wstest/echo.go @@ -0,0 +1,91 @@ +package wstest + +import ( + "bytes" + "context" + "io" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" +) + +// EchoLoop echos every msg received from c until an error +// occurs or the context expires. +// The read limit is set to 1 << 30. +func EchoLoop(ctx context.Context, c *websocket.Conn) error { + defer c.Close(websocket.StatusInternalError, "") + + c.SetReadLimit(1 << 30) + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32<<10) + for { + typ, r, err := c.Reader(ctx) + if err != nil { + return err + } + + w, err := c.Writer(ctx, typ) + if err != nil { + return err + } + + _, err = io.CopyBuffer(w, r, b) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + } +} + +// Echo writes a message and ensures the same is sent back on c. +func Echo(ctx context.Context, c *websocket.Conn, max int) error { + expType := websocket.MessageBinary + if xrand.Bool() { + expType = websocket.MessageText + } + + msg := randMessage(expType, xrand.Int(max)) + + writeErr := xsync.Go(func() error { + return c.Write(ctx, expType, msg) + }) + + actType, act, err := c.Read(ctx) + if err != nil { + return err + } + + err = <-writeErr + if err != nil { + return err + } + + if expType != actType { + return xerrors.Errorf("unexpected message typ (%v): %v", expType, actType) + } + + if !bytes.Equal(msg, act) { + return xerrors.Errorf("unexpected msg read: %v", cmp.Diff(msg, act)) + } + + return nil +} + +func randMessage(typ websocket.MessageType, n int) []byte { + if typ == websocket.MessageBinary { + return xrand.Bytes(n) + } + return []byte(xrand.String(n)) +} diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go new file mode 100644 index 00000000..81705a8a --- /dev/null +++ b/internal/test/wstest/pipe.go @@ -0,0 +1,85 @@ +// +build !js + +package wstest + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httptest" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/xrand" +) + +// Pipe is used to create an in memory connection +// between two websockets analogous to net.Pipe. +func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) { + defer errd.Wrap(&err, "failed to create ws pipe") + + var serverConn *websocket.Conn + var acceptErr error + tt := fakeTransport{ + h: func(w http.ResponseWriter, r *http.Request) { + serverConn, acceptErr = websocket.Accept(w, r, acceptOpts) + }, + } + + if dialOpts == nil { + dialOpts = &websocket.DialOptions{} + } + dialOpts = &*dialOpts + dialOpts.HTTPClient = &http.Client{ + Transport: tt, + } + + clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) + if err != nil { + return nil, nil, xerrors.Errorf("failed to dial with fake transport: %w", err) + } + + if serverConn == nil { + return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr) + } + + if xrand.Bool() { + return serverConn, clientConn, nil + } + return clientConn, serverConn, nil +} + +type fakeTransport struct { + h http.HandlerFunc +} + +func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { + clientConn, serverConn := net.Pipe() + + hj := testHijacker{ + ResponseRecorder: httptest.NewRecorder(), + serverConn: serverConn, + } + + t.h.ServeHTTP(hj, r) + + resp := hj.ResponseRecorder.Result() + if resp.StatusCode == http.StatusSwitchingProtocols { + resp.Body = clientConn + } + return resp, nil +} + +type testHijacker struct { + *httptest.ResponseRecorder + serverConn net.Conn +} + +var _ http.Hijacker = testHijacker{} + +func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil +} diff --git a/internal/test/wstest/url.go b/internal/test/wstest/url.go new file mode 100644 index 00000000..a11c61b4 --- /dev/null +++ b/internal/test/wstest/url.go @@ -0,0 +1,11 @@ +package wstest + +import ( + "net/http/httptest" + "strings" +) + +// URL returns the ws url for s. +func URL(s *httptest.Server) string { + return strings.Replace(s.URL, "http", "ws", 1) +} diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go new file mode 100644 index 00000000..8de1ede8 --- /dev/null +++ b/internal/test/xrand/xrand.go @@ -0,0 +1,47 @@ +package xrand + +import ( + "crypto/rand" + "fmt" + "math/big" + "strings" +) + +// Bytes generates random bytes with length n. +func Bytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Reader.Read(b) + if err != nil { + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + return b +} + +// String generates a random string with length n. +func String(n int) string { + s := strings.ToValidUTF8(string(Bytes(n)), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] + } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s +} + +// Bool returns a randomly generated boolean. +func Bool() bool { + return Int(2) == 1 +} + +// Int returns a randomly generated integer between [0, max). +func Int(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + panic(fmt.Sprintf("failed to get random int: %v", err)) + } + return int(x.Int64()) +} diff --git a/internal/wsecho/wsecho.go b/internal/wsecho/wsecho.go deleted file mode 100644 index c408f07f..00000000 --- a/internal/wsecho/wsecho.go +++ /dev/null @@ -1,55 +0,0 @@ -// +build !js - -package wsecho - -import ( - "context" - "io" - "time" - - "nhooyr.io/websocket" -) - -// Loop echos every msg received from c until an error -// occurs or the context expires. -// The read limit is set to 1 << 30. -func Loop(ctx context.Context, c *websocket.Conn) error { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32<<10) - echo := func() error { - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return nil - } - - for { - err := echo() - if err != nil { - return err - } - } -} diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go deleted file mode 100644 index 513af1fe..00000000 --- a/internal/wsgrace/wsgrace.go +++ /dev/null @@ -1,50 +0,0 @@ -package wsgrace - -import ( - "context" - "fmt" - "net/http" - "sync/atomic" - "time" -) - -// Grace wraps s.Handler to gracefully shutdown WebSocket connections. -// The returned function must be used to close the server instead of s.Close. -func Grace(s *http.Server) (closeFn func() error) { - h := s.Handler - var conns int64 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) - - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - - r = r.WithContext(ctx) - - h.ServeHTTP(w, r) - }) - - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - err := s.Shutdown(ctx) - if err != nil { - return fmt.Errorf("server shutdown failed: %v", err) - } - - t := time.NewTicker(time.Millisecond * 10) - defer t.Stop() - for { - select { - case <-t.C: - if atomic.LoadInt64(&conns) == 0 { - return nil - } - case <-ctx.Done(): - return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) - } - } - } -} diff --git a/internal/wsjs/wsjs_js.go b/internal/wsjs/wsjs_js.go index d48691d4..26ffb456 100644 --- a/internal/wsjs/wsjs_js.go +++ b/internal/wsjs/wsjs_js.go @@ -102,7 +102,7 @@ type MessageEvent struct { // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent } -// OnMessage registers a function to be called when the websocket receives a message. +// OnMessage registers a function to be called when the WebSocket receives a message. func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { return c.addEventListener("message", func(e js.Value) { var data interface{} @@ -128,7 +128,7 @@ func (c WebSocket) Subprotocol() string { return c.v.Get("protocol").String() } -// OnOpen registers a function to be called when the websocket is opened. +// OnOpen registers a function to be called when the WebSocket is opened. func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { return c.addEventListener("open", fn) } diff --git a/internal/xsync/go.go b/internal/xsync/go.go new file mode 100644 index 00000000..712739aa --- /dev/null +++ b/internal/xsync/go.go @@ -0,0 +1,25 @@ +package xsync + +import ( + "golang.org/x/xerrors" +) + +// Go allows running a function in another goroutine +// and waiting for its error. +func Go(fn func() error) <-chan error { + errs := make(chan error, 1) + go func() { + defer func() { + r := recover() + if r != nil { + select { + case errs <- xerrors.Errorf("panic in go fn: %v", r): + default: + } + } + }() + errs <- fn() + }() + + return errs +} diff --git a/internal/xsync/go_test.go b/internal/xsync/go_test.go new file mode 100644 index 00000000..dabea8a5 --- /dev/null +++ b/internal/xsync/go_test.go @@ -0,0 +1,18 @@ +package xsync + +import ( + "testing" + + "nhooyr.io/websocket/internal/test/assert" +) + +func TestGoRecover(t *testing.T) { + t.Parallel() + + errs := Go(func() error { + panic("anmol") + }) + + err := <-errs + assert.Contains(t, err, "anmol") +} diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go new file mode 100644 index 00000000..a0c40204 --- /dev/null +++ b/internal/xsync/int64.go @@ -0,0 +1,23 @@ +package xsync + +import ( + "sync/atomic" +) + +// Int64 represents an atomic int64. +type Int64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. + i atomic.Value +} + +// Load loads the int64. +func (v *Int64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +// Store stores the int64. +func (v *Int64) Store(i int64) { + v.i.Store(i) +} diff --git a/conn_common.go b/netconn.go similarity index 59% rename from conn_common.go rename to netconn.go index 1247df6e..a2d8f4f3 100644 --- a/conn_common.go +++ b/netconn.go @@ -1,17 +1,14 @@ -// This file contains *Conn symbols relevant to both -// Wasm and non Wasm builds. - package websocket import ( "context" - "fmt" "io" "math" "net" "sync" - "sync/atomic" "time" + + "golang.org/x/xerrors" ) // NetConn converts a *websocket.Conn into a net.Conn. @@ -111,7 +108,7 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + err := xerrors.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) c.c.Close(StatusUnsupportedData, err.Error()) return 0, err } @@ -168,78 +165,3 @@ func (c *netConn) SetReadDeadline(t time.Time) error { } return nil } - -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader method so that we don't get the read closed error. - c.reader(ctx, true) - // Either the connection is already closed since there was a read error - // or the context was cancelled or a message was read and we should close - // the connection. - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit.Store(n) -} - -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) -} - -// See https://github.com/nhooyr/websocket/issues/153 -type atomicInt64 struct { - v int64 -} - -func (v *atomicInt64) Load() int64 { - return atomic.LoadInt64(&v.v) -} - -func (v *atomicInt64) Store(i int64) { - atomic.StoreInt64(&v.v, i) -} - -func (v *atomicInt64) String() string { - return fmt.Sprint(v.Load()) -} - -// Increment increments the value and returns the new value. -func (v *atomicInt64) Increment(delta int64) int64 { - return atomic.AddInt64(&v.v, delta) -} - -func (v *atomicInt64) CAS(old, new int64) (swapped bool) { - return atomic.CompareAndSwapInt64(&v.v, old, new) -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/read.go b/read.go new file mode 100644 index 00000000..bbad30d1 --- /dev/null +++ b/read.go @@ -0,0 +1,468 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "io" + "io/ioutil" + "strings" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/xsync" +) + +// Reader reads from the connection until until there is a WebSocket +// data message to be read. It will handle ping, pong and close frames as appropriate. +// +// It returns the type of the message and an io.Reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// Call CloseRead if you do not expect any data messages from the peer. +// +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + return c.reader(ctx) +} + +// Read is a convenience method around Reader to read a single message +// from the connection. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + return typ, b, err +} + +// CloseRead starts a goroutine to read from the connection until it is closed +// or a data message is received. +// +// Once CloseRead is called you cannot read any messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// If a data message is received, the connection will be closed with StatusPolicyViolation. +// +// Call CloseRead when you do not expect to read any more messages. +// Since it actively reads from the connection, it will ensure that ping, pong and close +// frames are responded to. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + // We add read one more byte than the limit in case + // there is a fin frame that needs to be read. + c.msgReader.limitReader.limit.Store(n + 1) +} + +const defaultReadLimit = 32768 + +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + mr.readFunc = mr.read + + mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) + return mr +} + +func (mr *msgReader) resetFlate() { + if mr.flateContextTakeover() { + mr.dict.init(32768) + } + if mr.flateBufio == nil { + mr.flateBufio = getBufioReader(mr.readFunc) + } + + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + mr.limitReader.r = mr.flateReader + mr.flateTail.Reset(deflateMessageTail) +} + +func (mr *msgReader) putFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + +func (mr *msgReader) close() { + mr.c.readMu.Lock(context.Background()) + mr.putFlateReader() + mr.dict.close() + if mr.flateBufio != nil { + putBufioReader(mr.flateBufio) + } +} + +func (mr *msgReader) flateContextTakeover() bool { + if mr.c.client { + return !mr.c.copts.serverNoContextTakeover + } + return !mr.c.copts.clientNoContextTakeover +} + +func (c *Conn) readRSV1Illegal(h header) bool { + // If compression is disabled, rsv1 is illegal. + if !c.flate() { + return true + } + // rsv1 is only allowed on data frames beginning messages. + if h.opcode != opText && h.opcode != opBinary { + return true + } + return false +} + +func (c *Conn) readLoop(ctx context.Context) (header, error) { + for { + h, err := c.readFrameHeader(ctx) + if err != nil { + return header{}, err + } + + if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { + err := xerrors.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + c.writeError(StatusProtocolError, err) + return header{}, err + } + + if !c.client && !h.masked { + return header{}, xerrors.New("received unmasked frame from client") + } + + switch h.opcode { + case opClose, opPing, opPong: + err = c.handleControl(ctx, h) + if err != nil { + // Pass through CloseErrors when receiving a close frame. + if h.opcode == opClose && CloseStatus(err) != -1 { + return header{}, err + } + return header{}, xerrors.Errorf("failed to handle control frame %v: %w", h.opcode, err) + } + case opContinuation, opText, opBinary: + return h, nil + default: + err := xerrors.Errorf("received unknown opcode %v", h.opcode) + c.writeError(StatusProtocolError, err) + return header{}, err + } + } +} + +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- ctx: + } + + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) + if err != nil { + select { + case <-c.closed: + return header{}, c.closeErr + case <-ctx.Done(): + return header{}, ctx.Err() + default: + c.close(err) + return header{}, err + } + } + + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- context.Background(): + } + + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + select { + case <-c.closed: + return 0, c.closeErr + case c.readTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + return n, ctx.Err() + default: + err = xerrors.Errorf("failed to read frame payload: %w", err) + c.close(err) + return n, err + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.readTimeout <- context.Background(): + } + + return n, err +} + +func (c *Conn) handleControl(ctx context.Context, h header) (err error) { + if h.payloadLength < 0 || h.payloadLength > maxControlPayload { + err := xerrors.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + c.writeError(StatusProtocolError, err) + return err + } + + if !h.fin { + err := xerrors.New("received fragmented control frame") + c.writeError(StatusProtocolError, err) + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + b := c.readControlBuf[:h.payloadLength] + _, err = c.readFramePayload(ctx, b) + if err != nil { + return err + } + + if h.masked { + mask(h.maskKey, b) + } + + switch h.opcode { + case opPing: + return c.writeControl(ctx, opPong, b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + close(pong) + } + return nil + } + + defer func() { + c.readCloseFrameErr = err + }() + + ce, err := parseClosePayload(b) + if err != nil { + err = xerrors.Errorf("received invalid close payload: %w", err) + c.writeError(StatusProtocolError, err) + return err + } + + err = xerrors.Errorf("received close frame: %w", ce) + c.setCloseErr(err) + c.writeClose(ce.Code, ce.Reason) + c.close(err) + return err +} + +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + defer errd.Wrap(&err, "failed to get reader") + + err = c.readMu.Lock(ctx) + if err != nil { + return 0, nil, err + } + defer c.readMu.Unlock() + + if !c.msgReader.fin { + return 0, nil, xerrors.New("previous message not read to completion") + } + + h, err := c.readLoop(ctx) + if err != nil { + return 0, nil, err + } + + if h.opcode == opContinuation { + err := xerrors.New("received continuation frame without text or binary frame") + c.writeError(StatusProtocolError, err) + return 0, nil, err + } + + c.msgReader.reset(ctx, h) + + return MessageType(h.opcode), c.msgReader, nil +} + +type msgReader struct { + c *Conn + + ctx context.Context + flate bool + flateReader io.Reader + flateBufio *bufio.Reader + flateTail strings.Reader + limitReader *limitReader + dict slidingWindow + + fin bool + payloadLength int64 + maskKey uint32 + + // readerFunc(mr.Read) to avoid continuous allocations. + readFunc readerFunc +} + +func (mr *msgReader) reset(ctx context.Context, h header) { + mr.ctx = ctx + mr.flate = h.rsv1 + mr.limitReader.reset(mr.readFunc) + + if mr.flate { + mr.resetFlate() + } + + mr.setFrame(h) +} + +func (mr *msgReader) setFrame(h header) { + mr.fin = h.fin + mr.payloadLength = h.payloadLength + mr.maskKey = h.maskKey +} + +func (mr *msgReader) Read(p []byte) (n int, err error) { + defer func() { + if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { + err = io.EOF + } + if xerrors.Is(err, io.EOF) { + err = io.EOF + mr.putFlateReader() + return + } + errd.Wrap(&err, "failed to read") + }() + + err = mr.c.readMu.Lock(mr.ctx) + if err != nil { + return 0, err + } + defer mr.c.readMu.Unlock() + + n, err = mr.limitReader.Read(p) + if mr.flate && mr.flateContextTakeover() { + p = p[:n] + mr.dict.write(p) + } + return n, err +} + +func (mr *msgReader) read(p []byte) (int, error) { + for { + if mr.payloadLength == 0 { + if mr.fin { + if mr.flate { + return mr.flateTail.Read(p) + } + return 0, io.EOF + } + + h, err := mr.c.readLoop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := xerrors.New("received new data message without finishing the previous message") + mr.c.writeError(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + + continue + } + + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } + + n, err := mr.c.readFramePayload(mr.ctx, p) + if err != nil { + return n, err + } + + mr.payloadLength -= int64(n) + + if !mr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } + + return n, nil + } +} + +type limitReader struct { + c *Conn + r io.Reader + limit xsync.Int64 + n int64 +} + +func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { + lr := &limitReader{ + c: c, + } + lr.limit.Store(limit) + lr.reset(r) + return lr +} + +func (lr *limitReader) reset(r io.Reader) { + lr.n = lr.limit.Load() + lr.r = r +} + +func (lr *limitReader) Read(p []byte) (int, error) { + if lr.n <= 0 { + err := xerrors.Errorf("read limited at %v bytes", lr.limit.Load()) + lr.c.writeError(StatusMessageTooBig, err) + return 0, err + } + + if int64(len(p)) > lr.n { + p = p[:lr.n] + } + n, err := lr.r.Read(p) + lr.n -= int64(n) + return n, err +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/frame_stringer.go b/stringer.go similarity index 98% rename from frame_stringer.go rename to stringer.go index 72b865fc..5a66ba29 100644 --- a/frame_stringer.go +++ b/stringer.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. package websocket diff --git a/websocket_js_test.go b/websocket_js_test.go deleted file mode 100644 index 9b7bb813..00000000 --- a/websocket_js_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package websocket_test - -import ( - "context" - "net/http" - "os" - "testing" - "time" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" -) - -func TestConn(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - err = assertSubprotocol(c, "echo") - if err != nil { - t.Fatal(err) - } - - err = assert.Equalf(&http.Response{}, resp, "unexpected http response") - if err != nil { - t.Fatal(err) - } - - err = assertJSONEcho(ctx, c, 1024) - if err != nil { - t.Fatal(err) - } - - err = assertEcho(ctx, c, websocket.MessageBinary, 1024) - if err != nil { - t.Fatal(err) - } - - err = c.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatal(err) - } -} diff --git a/write.go b/write.go new file mode 100644 index 00000000..b560b44c --- /dev/null +++ b/write.go @@ -0,0 +1,350 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/binary" + "io" + "sync" + "time" + + "github.com/klauspost/compress/flate" + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// Writer returns a writer bounded by the context that will write +// a WebSocket message of type dataType to the connection. +// +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + w, err := c.writer(ctx, typ) + if err != nil { + return nil, xerrors.Errorf("failed to get writer: %w", err) + } + return w, nil +} + +// Write writes a message to the connection. +// +// See the Writer method if you want to stream a message. +// +// If compression is disabled or the threshold is not met, then it +// will write the message in a single frame. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + _, err := c.write(ctx, typ, p) + if err != nil { + return xerrors.Errorf("failed to write msg: %w", err) + } + return nil +} + +type msgWriter struct { + mw *msgWriterState + closed bool +} + +func (mw *msgWriter) Write(p []byte) (int, error) { + if mw.closed { + return 0, xerrors.New("cannot use closed writer") + } + return mw.mw.Write(p) +} + +func (mw *msgWriter) Close() error { + if mw.closed { + return xerrors.New("cannot use closed writer") + } + mw.closed = true + return mw.mw.Close() +} + +type msgWriterState struct { + c *Conn + + mu *mu + writeMu sync.Mutex + + ctx context.Context + opcode opcode + flate bool + + trimWriter *trimLastFourBytesWriter + dict slidingWindow +} + +func newMsgWriterState(c *Conn) *msgWriterState { + mw := &msgWriterState{ + c: c, + mu: newMu(c), + } + return mw +} + +func (mw *msgWriterState) ensureFlate() { + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + } + + mw.dict.init(8192) + mw.flate = true +} + +func (mw *msgWriterState) flateContextTakeover() bool { + if mw.c.client { + return !mw.c.copts.clientNoContextTakeover + } + return !mw.c.copts.serverNoContextTakeover +} + +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := c.msgWriterState.reset(ctx, typ) + if err != nil { + return nil, err + } + return &msgWriter{ + mw: c.msgWriterState, + closed: false, + }, nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + mw, err := c.writer(ctx, typ) + if err != nil { + return 0, err + } + + if !c.flate() { + defer c.msgWriterState.mu.Unlock() + return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + } + + n, err := mw.Write(p) + if err != nil { + return n, err + } + + err = mw.Close() + return n, err +} + +func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { + err := mw.mu.Lock(ctx) + if err != nil { + return err + } + + mw.ctx = ctx + mw.opcode = opcode(typ) + mw.flate = false + + mw.trimWriter.reset() + + return nil +} + +// Write writes the given bytes to the WebSocket connection. +func (mw *msgWriterState) Write(p []byte) (_ int, err error) { + defer errd.Wrap(&err, "failed to write") + + mw.writeMu.Lock() + defer mw.writeMu.Unlock() + + if mw.c.flate() { + // Only enables flate if the length crosses the + // threshold on the first frame + if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { + mw.ensureFlate() + } + } + + if mw.flate { + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err + } + mw.dict.write(p) + return len(p), nil + } + + return mw.write(p) +} + +func (mw *msgWriterState) write(p []byte) (int, error) { + n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) + if err != nil { + return n, xerrors.Errorf("failed to write data frame: %w", err) + } + mw.opcode = opContinuation + return n, nil +} + +// Close flushes the frame to the connection. +func (mw *msgWriterState) Close() (err error) { + defer errd.Wrap(&err, "failed to close writer") + + mw.writeMu.Lock() + defer mw.writeMu.Unlock() + + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) + if err != nil { + return xerrors.Errorf("failed to write fin frame: %w", err) + } + + if mw.flate && !mw.flateContextTakeover() { + mw.dict.close() + } + mw.mu.Unlock() + return nil +} + +func (mw *msgWriterState) close() { + mw.writeMu.Lock() + mw.dict.close() +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + _, err := c.writeFrame(ctx, true, false, opcode, p) + if err != nil { + return xerrors.Errorf("failed to write control frame %v: %w", opcode, err) + } + return nil +} + +// frame handles all writes to the connection. +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) { + err := c.writeFrameMu.Lock(ctx) + if err != nil { + return 0, err + } + defer c.writeFrameMu.Unlock() + + select { + case <-c.closed: + return 0, c.closeErr + case c.writeTimeout <- ctx: + } + + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.payloadLength = int64(len(p)) + + if c.client { + c.writeHeader.masked = true + _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) + if err != nil { + return 0, xerrors.Errorf("failed to generate masking key: %w", err) + } + c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) + } + + c.writeHeader.rsv1 = false + if flate && (opcode == opText || opcode == opBinary) { + c.writeHeader.rsv1 = true + } + + err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) + if err != nil { + return 0, err + } + + n, err := c.writeFramePayload(p) + if err != nil { + return n, err + } + + if c.writeHeader.fin { + err = c.bw.Flush() + if err != nil { + return n, xerrors.Errorf("failed to flush: %w", err) + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.writeTimeout <- context.Background(): + } + + return n, nil +} + +func (c *Conn) writeFramePayload(p []byte) (n int, err error) { + defer errd.Wrap(&err, "failed to write frame payload") + + if !c.writeHeader.masked { + return c.bw.Write(p) + } + + maskKey := c.writeHeader.maskKey + for len(p) > 0 { + // If the buffer is full, we need to flush. + if c.bw.Available() == 0 { + err = c.bw.Flush() + if err != nil { + return n, err + } + } + + // Start of next write in the buffer. + i := c.bw.Buffered() + + j := len(p) + if j > c.bw.Available() { + j = c.bw.Available() + } + + _, err := c.bw.Write(p[:j]) + if err != nil { + return n, err + } + + maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + + p = p[j:] + n += j + } + + return n, nil +} + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and returns it. +func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { + var writeBuf []byte + bw.Reset(writerFunc(func(p2 []byte) (int, error) { + writeBuf = p2[:cap(p2)] + return len(p2), nil + })) + + bw.WriteByte(0) + bw.Flush() + + bw.Reset(w) + + return writeBuf +} + +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.close(nil) +} diff --git a/websocket_js.go b/ws_js.go similarity index 74% rename from websocket_js.go rename to ws_js.go index d27809cf..ecf3d78c 100644 --- a/websocket_js.go +++ b/ws_js.go @@ -3,8 +3,6 @@ package websocket // import "nhooyr.io/websocket" import ( "bytes" "context" - "errors" - "fmt" "io" "net/http" "reflect" @@ -12,8 +10,11 @@ import ( "sync" "syscall/js" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" + "nhooyr.io/websocket/internal/xsync" ) // Conn provides a wrapper around the browser WebSocket API. @@ -21,10 +22,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *atomicInt64 + msgReadLimit xsync.Int64 closingMu sync.Mutex - isReadClosed *atomicInt64 + isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -44,7 +45,7 @@ func (c *Conn) close(err error, wasClean bool) { runtime.SetFinalizer(c, nil) if !wasClean { - err = fmt.Errorf("unclean connection close: %w", err) + err = xerrors.Errorf("unclean connection close: %w", err) } c.setCloseErr(err) c.closeWasClean = wasClean @@ -56,17 +57,17 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &atomicInt64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &atomicInt64{} - c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - c.close(fmt.Errorf("received close: %w", err), e.WasClean) + // We do not know if we sent or received this close as + // its possible the browser triggered it without us + // explicitly sending it. + c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -86,7 +87,7 @@ func (c *Conn) init() { }) runtime.SetFinalizer(c, func(c *Conn) { - c.setCloseErr(errors.New("connection garbage collected")) + c.setCloseErr(xerrors.New("connection garbage collected")) c.closeWithInternal() }) } @@ -99,15 +100,15 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, fmt.Errorf("websocket connection read closed") + return 0, nil, xerrors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) if err != nil { - return 0, nil, fmt.Errorf("failed to read: %w", err) + return 0, nil, xerrors.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { - err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit) + err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } @@ -151,6 +152,11 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { } } +// Ping is mocked out for Wasm. +func (c *Conn) Ping(ctx context.Context) error { + return nil +} + // Write writes a message of the given type to the connection. // Always non blocking. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { @@ -160,7 +166,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. - err := fmt.Errorf("failed to write: %w", err) + err := xerrors.Errorf("failed to write: %w", err) c.setCloseErr(err) c.closeWithInternal() return err @@ -178,18 +184,18 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { case MessageText: return c.ws.SendText(string(p)) default: - return fmt.Errorf("unexpected message type: %v", typ) + return xerrors.Errorf("unexpected message type: %v", typ) } } -// Close closes the websocket with the given code and reason. +// Close closes the WebSocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return fmt.Errorf("failed to close websocket: %w", err) + return xerrors.Errorf("failed to close WebSocket: %w", err) } return nil } @@ -198,13 +204,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() - ce := fmt.Errorf("sent close: %w", CloseError{ + ce := xerrors.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) + return xerrors.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } c.setCloseErr(ce) @@ -234,12 +240,12 @@ type DialOptions struct { // Dial creates a new WebSocket connection to the given url with the given options. // The passed context bounds the maximum time spent waiting for the connection to open. -// The returned *http.Response is always nil or the zero value. It's only in the signature +// The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { c, resp, err := dial(ctx, url, opts) if err != nil { - return nil, resp, fmt.Errorf("failed to websocket dial %q: %w", url, err) + return nil, nil, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, resp, nil } @@ -270,12 +276,12 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: + return c, &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + }, nil case <-c.closed: - return c, nil, c.closeErr + return nil, nil, c.closeErr } - - // Have to return a non nil response as the normal API does that. - return c, &http.Response{}, nil } // Reader attempts to read a message from the connection. @@ -288,11 +294,6 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, bytes.NewReader(p), nil } -// Only implemented for use by *Conn.CloseRead in conn_common.go -func (c *Conn) reader(ctx context.Context, _ bool) { - c.read(ctx) -} - // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. @@ -317,25 +318,58 @@ type writer struct { func (w writer) Write(p []byte) (int, error) { if w.closed { - return 0, errors.New("cannot write to closed writer") + return 0, xerrors.New("cannot write to closed writer") } n, err := w.b.Write(p) if err != nil { - return n, fmt.Errorf("failed to write message: %w", err) + return n, xerrors.Errorf("failed to write message: %w", err) } return n, nil } func (w writer) Close() error { if w.closed { - return errors.New("cannot close closed writer") + return xerrors.New("cannot close closed writer") } w.closed = true defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { - return fmt.Errorf("failed to close writer: %w", err) + return xerrors.Errorf("failed to close writer: %w", err) } return nil } + +// CloseRead implements *Conn.CloseRead for wasm. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.read(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit implements *Conn.SetReadLimit for wasm. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/ws_js_test.go b/ws_js_test.go new file mode 100644 index 00000000..e6be6181 --- /dev/null +++ b/ws_js_test.go @@ -0,0 +1,38 @@ +package websocket_test + +import ( + "context" + "net/http" + "os" + "testing" + "time" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/wstest" +) + +func TestWasm(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + assert.Equal(t, "subprotocol", "echo", c.Subprotocol()) + assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode) + + c.SetReadLimit(65536) + for i := 0; i < 10; i++ { + err = wstest.Echo(ctx, c, 65536) + assert.Success(t, err) + } + + err = c.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index fe935fa1..e6f06a2f 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,34 +1,34 @@ -// Package wsjson provides websocket helpers for JSON messages. +// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" - "fmt" + + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/errd" ) -// Read reads a json message from c into v. -// It will reuse buffers to avoid allocations. +// Read reads a JSON message from c into v. +// It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read json: %w", err) - } - return nil + return read(ctx, c, v) } -func read(ctx context.Context, c *websocket.Conn, v interface{}) error { +func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to read JSON message") + typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageText { - c.Close(websocket.StatusUnsupportedData, "can only accept text messages") - return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) + c.Close(websocket.StatusUnsupportedData, "expected text message") + return xerrors.Errorf("expected text message for JSON but got: %v", typ) } b := bpool.Get() @@ -42,39 +42,32 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") - return fmt.Errorf("failed to unmarshal json: %w", err) + return xerrors.Errorf("failed to unmarshal JSON: %w", err) } return nil } -// Write writes the json message v to c. -// It will reuse buffers to avoid allocations. +// Write writes the JSON message v to c. +// It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write json: %w", err) - } - return nil + return write(ctx, c, v) } -func write(ctx context.Context, c *websocket.Conn, v interface{}) error { +func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to write JSON message") + w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return err } - // We use Encode because it automatically enables buffer reuse without us - // needing to do anything. Though see https://github.com/golang/go/issues/27735 - e := json.NewEncoder(w) - err = e.Encode(v) + // json.Marshal cannot reuse buffers between calls as it has to return + // a copy of the byte slice but Encoder does as it directly writes to w. + err = json.NewEncoder(w).Encode(v) if err != nil { - return fmt.Errorf("failed to encode json: %w", err) + return xerrors.Errorf("failed to marshal JSON: %w", err) } - err = w.Close() - if err != nil { - return err - } - return nil + return w.Close() } diff --git a/wspb/wspb.go b/wspb/wspb.go index 3c9e0f76..06ac3368 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,36 +1,35 @@ -// Package wspb provides websocket helpers for protobuf messages. +// Package wspb provides helpers for reading and writing protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" - "fmt" "github.com/golang/protobuf/proto" + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/errd" ) // Read reads a protobuf message from c into v. -// It will reuse buffers to avoid allocations. +// It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read protobuf: %w", err) - } - return nil + return read(ctx, c, v) } -func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { +func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { + defer errd.Wrap(&err, "failed to read protobuf message") + typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageBinary { - c.Close(websocket.StatusUnsupportedData, "can only accept binary messages") - return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) + c.Close(websocket.StatusUnsupportedData, "expected binary message") + return xerrors.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() @@ -44,32 +43,30 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err = proto.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) + return xerrors.Errorf("failed to unmarshal protobuf: %w", err) } return nil } // Write writes the protobuf message v to c. -// It will reuse buffers to avoid allocations. +// It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write protobuf: %w", err) - } - return nil + return write(ctx, c, v) } -func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { +func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { + defer errd.Wrap(&err, "failed to write protobuf message") + b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { bpool.Put(bytes.NewBuffer(pb.Bytes())) }() - err := pb.Marshal(v) + err = pb.Marshal(v) if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) + return xerrors.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes())