diff --git a/go.mod b/go.mod index b76ff7931d..4d8f58967e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module golang.org/x/net go 1.17 require ( - golang.org/x/sys v0.3.0 - golang.org/x/term v0.3.0 - golang.org/x/text v0.5.0 + golang.org/x/sys v0.5.0 + golang.org/x/term v0.5.0 + golang.org/x/text v0.7.0 ) diff --git a/go.sum b/go.sum index 1077b4d1d3..bcd80060dd 100644 --- a/go.sum +++ b/go.sum @@ -12,17 +12,17 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/html/comment_test.go b/html/comment_test.go new file mode 100644 index 0000000000..2c80bc748c --- /dev/null +++ b/html/comment_test.go @@ -0,0 +1,270 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package html + +import ( + "bytes" + "testing" +) + +// TestComments exhaustively tests every 'interesting' N-byte string is +// correctly parsed as a comment. N ranges from 4+1 to 4+suffixLen inclusive, +// where 4 is the length of the "") return + } else if c == '-' { + dashCount = 1 + beginning = false + continue } } } @@ -645,6 +649,35 @@ func (z *Tokenizer) readComment() { } } +func (z *Tokenizer) calculateAbruptCommentDataEnd() int { + raw := z.Raw() + const prefixLen = len("", }, - // Comments. + // Comments. See also func TestComments. { "comment0", "abcdef", @@ -376,6 +376,41 @@ var tokenTests = []tokenTest{ "az", "a$$z", }, + { + "comment16", + "az", + "a$$z", + }, + { + "comment17", + "a", + }, + { + "comment18", + "az", + "a$$z", + }, + { + "comment19", + "a", + }, + { + "comment20", + "az", + "a$$z", + }, + { + "comment21", + "az", + "a$$z", + }, + { + "comment22", + "az", + "a$$z", + }, // An attribute with a backslash. { "backslash", diff --git a/http2/flow.go b/http2/flow.go index b51f0e0cf1..b7dbd18695 100644 --- a/http2/flow.go +++ b/http2/flow.go @@ -6,23 +6,91 @@ package http2 -// flow is the flow control window's size. -type flow struct { +// inflowMinRefresh is the minimum number of bytes we'll send for a +// flow control window update. +const inflowMinRefresh = 4 << 10 + +// inflow accounts for an inbound flow control window. +// It tracks both the latest window sent to the peer (used for enforcement) +// and the accumulated unsent window. +type inflow struct { + avail int32 + unsent int32 +} + +// init sets the initial window. +func (f *inflow) init(n int32) { + f.avail = n +} + +// add adds n bytes to the window, with a maximum window size of max, +// indicating that the peer can now send us more data. +// For example, the user read from a {Request,Response} body and consumed +// some of the buffered data, so the peer can now send more. +// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer. +// Window updates are accumulated and sent when the unsent capacity +// is at least inflowMinRefresh or will at least double the peer's available window. +func (f *inflow) add(n int) (connAdd int32) { + if n < 0 { + panic("negative update") + } + unsent := int64(f.unsent) + int64(n) + // "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets." + // RFC 7540 Section 6.9.1. + const maxWindow = 1<<31 - 1 + if unsent+int64(f.avail) > maxWindow { + panic("flow control update exceeds maximum window size") + } + f.unsent = int32(unsent) + if f.unsent < inflowMinRefresh && f.unsent < f.avail { + // If there aren't at least inflowMinRefresh bytes of window to send, + // and this update won't at least double the window, buffer the update for later. + return 0 + } + f.avail += f.unsent + f.unsent = 0 + return int32(unsent) +} + +// take attempts to take n bytes from the peer's flow control window. +// It reports whether the window has available capacity. +func (f *inflow) take(n uint32) bool { + if n > uint32(f.avail) { + return false + } + f.avail -= int32(n) + return true +} + +// takeInflows attempts to take n bytes from two inflows, +// typically connection-level and stream-level flows. +// It reports whether both windows have available capacity. +func takeInflows(f1, f2 *inflow, n uint32) bool { + if n > uint32(f1.avail) || n > uint32(f2.avail) { + return false + } + f1.avail -= int32(n) + f2.avail -= int32(n) + return true +} + +// outflow is the outbound flow control window's size. +type outflow struct { _ incomparable // n is the number of DATA bytes we're allowed to send. - // A flow is kept both on a conn and a per-stream. + // An outflow is kept both on a conn and a per-stream. n int32 - // conn points to the shared connection-level flow that is - // shared by all streams on that conn. It is nil for the flow + // conn points to the shared connection-level outflow that is + // shared by all streams on that conn. It is nil for the outflow // that's on the conn directly. - conn *flow + conn *outflow } -func (f *flow) setConnFlow(cf *flow) { f.conn = cf } +func (f *outflow) setConnFlow(cf *outflow) { f.conn = cf } -func (f *flow) available() int32 { +func (f *outflow) available() int32 { n := f.n if f.conn != nil && f.conn.n < n { n = f.conn.n @@ -30,7 +98,7 @@ func (f *flow) available() int32 { return n } -func (f *flow) take(n int32) { +func (f *outflow) take(n int32) { if n > f.available() { panic("internal error: took too much") } @@ -42,7 +110,7 @@ func (f *flow) take(n int32) { // add adds n bytes (positive or negative) to the flow control window. // It returns false if the sum would exceed 2^31-1. -func (f *flow) add(n int32) bool { +func (f *outflow) add(n int32) bool { sum := f.n + n if (sum > n) == (f.n > 0) { f.n = sum diff --git a/http2/flow_test.go b/http2/flow_test.go index 7ae82c7817..cae4f38c0c 100644 --- a/http2/flow_test.go +++ b/http2/flow_test.go @@ -6,9 +6,61 @@ package http2 import "testing" -func TestFlow(t *testing.T) { - var st flow - var conn flow +func TestInFlowTake(t *testing.T) { + var f inflow + f.init(100) + if !f.take(40) { + t.Fatalf("f.take(40) from 100: got false, want true") + } + if !f.take(40) { + t.Fatalf("f.take(40) from 60: got false, want true") + } + if f.take(40) { + t.Fatalf("f.take(40) from 20: got true, want false") + } + if !f.take(20) { + t.Fatalf("f.take(20) from 20: got false, want true") + } +} + +func TestInflowAddSmall(t *testing.T) { + var f inflow + f.init(0) + // Adding even a small amount when there is no flow causes an immediate send. + if got, want := f.add(1), int32(1); got != want { + t.Fatalf("f.add(1) to 1 = %v, want %v", got, want) + } +} + +func TestInflowAdd(t *testing.T) { + var f inflow + f.init(10 * inflowMinRefresh) + if got, want := f.add(inflowMinRefresh-1), int32(0); got != want { + t.Fatalf("f.add(minRefresh - 1) = %v, want %v", got, want) + } + if got, want := f.add(1), int32(inflowMinRefresh); got != want { + t.Fatalf("f.add(minRefresh) = %v, want %v", got, want) + } +} + +func TestTakeInflows(t *testing.T) { + var a, b inflow + a.init(10) + b.init(20) + if !takeInflows(&a, &b, 5) { + t.Fatalf("takeInflows(a, b, 5) from 10, 20: got false, want true") + } + if takeInflows(&a, &b, 6) { + t.Fatalf("takeInflows(a, b, 6) from 5, 15: got true, want false") + } + if !takeInflows(&a, &b, 5) { + t.Fatalf("takeInflows(a, b, 5) from 5, 15: got false, want true") + } +} + +func TestOutFlow(t *testing.T) { + var st outflow + var conn outflow st.add(3) conn.add(2) @@ -29,8 +81,8 @@ func TestFlow(t *testing.T) { } } -func TestFlowAdd(t *testing.T) { - var f flow +func TestOutFlowAdd(t *testing.T) { + var f outflow if !f.add(1) { t.Fatal("failed to add 1") } @@ -51,8 +103,8 @@ func TestFlowAdd(t *testing.T) { } } -func TestFlowAddOverflow(t *testing.T) { - var f flow +func TestOutFlowAddOverflow(t *testing.T) { + var f outflow if !f.add(0) { t.Fatal("failed to add 0") } diff --git a/http2/frame.go b/http2/frame.go index 184ac45feb..c1f6b90dc3 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -662,6 +662,15 @@ func (f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error { // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if err := f.startWriteDataPadded(streamID, endStream, data, pad); err != nil { + return err + } + return f.endWrite() +} + +// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer. +// The caller should call endWrite to flush the frame to the underlying writer. +func (f *Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !validStreamID(streamID) && !f.AllowIllegalWrites { return errStreamID } @@ -691,7 +700,7 @@ func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []by } f.wbuf = append(f.wbuf, data...) f.wbuf = append(f.wbuf, pad...) - return f.endWrite() + return nil } // A SettingsFrame conveys configuration parameters that affect how diff --git a/http2/hpack/hpack.go b/http2/hpack/hpack.go index ebdfbee964..7a1d976696 100644 --- a/http2/hpack/hpack.go +++ b/http2/hpack/hpack.go @@ -211,7 +211,7 @@ func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) { return dt.ents[dt.len()-(int(i)-staticTable.len())], true } -// Decode decodes an entire block. +// DecodeFull decodes an entire block. // // TODO: remove this method and make it incremental later? This is // easier for debugging now. @@ -359,6 +359,7 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { var hf HeaderField wantStr := d.emitEnabled || it.indexed() + var undecodedName undecodedString if nameIdx > 0 { ihf, ok := d.at(nameIdx) if !ok { @@ -366,15 +367,27 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { } hf.Name = ihf.Name } else { - hf.Name, buf, err = d.readString(buf, wantStr) + undecodedName, buf, err = d.readString(buf) if err != nil { return err } } - hf.Value, buf, err = d.readString(buf, wantStr) + undecodedValue, buf, err := d.readString(buf) if err != nil { return err } + if wantStr { + if nameIdx <= 0 { + hf.Name, err = d.decodeString(undecodedName) + if err != nil { + return err + } + } + hf.Value, err = d.decodeString(undecodedValue) + if err != nil { + return err + } + } d.buf = buf if it.indexed() { d.dynTab.add(hf) @@ -459,46 +472,52 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) { return 0, origP, errNeedMore } -// readString decodes an hpack string from p. +// readString reads an hpack string from p. // -// wantStr is whether s will be used. If false, decompression and -// []byte->string garbage are skipped if s will be ignored -// anyway. This does mean that huffman decoding errors for non-indexed -// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server -// is returning an error anyway, and because they're not indexed, the error -// won't affect the decoding state. -func (d *Decoder) readString(p []byte, wantStr bool) (s string, remain []byte, err error) { +// It returns a reference to the encoded string data to permit deferring decode costs +// until after the caller verifies all data is present. +func (d *Decoder) readString(p []byte) (u undecodedString, remain []byte, err error) { if len(p) == 0 { - return "", p, errNeedMore + return u, p, errNeedMore } isHuff := p[0]&128 != 0 strLen, p, err := readVarInt(7, p) if err != nil { - return "", p, err + return u, p, err } if d.maxStrLen != 0 && strLen > uint64(d.maxStrLen) { - return "", nil, ErrStringLength + // Returning an error here means Huffman decoding errors + // for non-indexed strings past the maximum string length + // are ignored, but the server is returning an error anyway + // and because the string is not indexed the error will not + // affect the decoding state. + return u, nil, ErrStringLength } if uint64(len(p)) < strLen { - return "", p, errNeedMore - } - if !isHuff { - if wantStr { - s = string(p[:strLen]) - } - return s, p[strLen:], nil + return u, p, errNeedMore } + u.isHuff = isHuff + u.b = p[:strLen] + return u, p[strLen:], nil +} - if wantStr { - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() // don't trust others - defer bufPool.Put(buf) - if err := huffmanDecode(buf, d.maxStrLen, p[:strLen]); err != nil { - buf.Reset() - return "", nil, err - } +type undecodedString struct { + isHuff bool + b []byte +} + +func (d *Decoder) decodeString(u undecodedString) (string, error) { + if !u.isHuff { + return string(u.b), nil + } + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() // don't trust others + var s string + err := huffmanDecode(buf, d.maxStrLen, u.b) + if err == nil { s = buf.String() - buf.Reset() // be nice to GC } - return s, p[strLen:], nil + buf.Reset() // be nice to GC + bufPool.Put(buf) + return s, err } diff --git a/http2/hpack/hpack_test.go b/http2/hpack/hpack_test.go index f9092e8bb9..b4b2a5d666 100644 --- a/http2/hpack/hpack_test.go +++ b/http2/hpack/hpack_test.go @@ -728,6 +728,36 @@ func TestEmitEnabled(t *testing.T) { } } +func TestSlowIncrementalDecode(t *testing.T) { + // TODO(dneil): Fix for -race mode. + t.Skip("too slow in -race mode") + + var buf bytes.Buffer + enc := NewEncoder(&buf) + hf := HeaderField{ + Name: strings.Repeat("k", 1<<20), + Value: strings.Repeat("v", 1<<20), + } + enc.WriteField(hf) + hbuf := buf.Bytes() + count := 0 + dec := NewDecoder(initialHeaderTableSize, func(got HeaderField) { + count++ + if count != 1 { + t.Errorf("decoded %v fields, want 1", count) + } + if got.Name != hf.Name { + t.Errorf("decoded Name does not match input") + } + if got.Value != hf.Value { + t.Errorf("decoded Value does not match input") + } + }) + for i := 0; i < len(hbuf); i++ { + dec.Write(hbuf[i : i+1]) + } +} + func TestSaveBufLimit(t *testing.T) { const maxStr = 1 << 10 var got []HeaderField diff --git a/http2/server.go b/http2/server.go index 4eb7617fa0..8cb14f3c97 100644 --- a/http2/server.go +++ b/http2/server.go @@ -448,7 +448,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // configured value for inflow, that will be updated when we send a // WINDOW_UPDATE shortly after sending SETTINGS. sc.flow.add(initialWindowSize) - sc.inflow.add(initialWindowSize) + sc.inflow.init(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) @@ -563,8 +563,8 @@ type serverConn struct { wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan bodyReadMsg // from handlers -> serve serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow flow // conn-wide (not stream-specific) outbound flow control - inflow flow // conn-wide inbound flow control + flow outflow // conn-wide (not stream-specific) outbound flow control + inflow inflow // conn-wide inbound flow control tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string writeSched WriteScheduler @@ -641,10 +641,10 @@ type stream struct { cancelCtx func() // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow flow // limits writing from Handler to client - inflow flow // what the client is allowed to POST/etc to us + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow outflow // limits writing from Handler to client + inflow inflow // what the client is allowed to POST/etc to us state streamState resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen @@ -843,8 +843,13 @@ type frameWriteResult struct { // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest) { - err := wr.write.writeFrame(sc) +func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { + var err error + if wd == nil { + err = wr.write.writeFrame(sc) + } else { + err = sc.framer.endWrite() + } sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err} } @@ -1251,9 +1256,16 @@ func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) { sc.writingFrameAsync = false err := wr.write.writeFrame(sc) sc.wroteFrame(frameWriteResult{wr: wr, err: err}) + } else if wd, ok := wr.write.(*writeData); ok { + // Encode the frame in the serve goroutine, to ensure we don't have + // any lingering asynchronous references to data passed to Write. + // See https://go.dev/issue/58446. + sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil) + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr, wd) } else { sc.writingFrameAsync = true - go sc.writeFrameAsync(wr) + go sc.writeFrameAsync(wr, nil) } } @@ -1503,7 +1515,7 @@ func (sc *serverConn) processFrame(f Frame) error { if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) { if f, ok := f.(*DataFrame); ok { - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl)) } sc.sendWindowUpdate(nil, int(f.Length)) // conn-level @@ -1775,14 +1787,9 @@ func (sc *serverConn) processData(f *DataFrame) error { // But still enforce their connection-level flow control, // and return any flow control bytes since we're not going // to consume them. - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) } - // Deduct the flow control from inflow, since we're - // going to immediately add it back in - // sendWindowUpdate, which also schedules sending the - // frames. - sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level if st != nil && st.resetQueued { @@ -1797,10 +1804,9 @@ func (sc *serverConn) processData(f *DataFrame) error { // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) } - sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) @@ -1811,10 +1817,9 @@ func (sc *serverConn) processData(f *DataFrame) error { } if f.Length > 0 { // Check whether the client has flow control quota. - if st.inflow.available() < int32(f.Length) { + if !takeInflows(&sc.inflow, &st.inflow, f.Length) { return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl)) } - st.inflow.take(int32(f.Length)) if len(data) > 0 { wrote, err := st.body.Write(data) @@ -1830,10 +1835,12 @@ func (sc *serverConn) processData(f *DataFrame) error { // Return any padded flow control now, since we won't // refund it later on body reads. - if pad := int32(f.Length) - int32(len(data)); pad > 0 { - sc.sendWindowUpdate32(nil, pad) - sc.sendWindowUpdate32(st, pad) - } + // Call sendWindowUpdate even if there is no padding, + // to return buffered flow control credit if the sent + // window has shrunk. + pad := int32(f.Length) - int32(len(data)) + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) } if f.StreamEnded() { st.endStream() @@ -2105,8 +2112,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.cw.Init() st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) - st.inflow.conn = &sc.inflow // link to conn-level counter - st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + st.inflow.init(sc.srv.initialStreamRecvWindowSize()) if sc.hs.WriteTimeout != 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } @@ -2198,7 +2204,7 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*r tlsState = sc.tlsState } - needsContinue := rp.header.Get("Expect") == "100-continue" + needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue") if needsContinue { rp.header.Del("Expect") } @@ -2388,47 +2394,28 @@ func (sc *serverConn) noteBodyRead(st *stream, n int) { } // st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate(st *stream, n int) { - sc.serveG.check() - // "The legal range for the increment to the flow control - // window is 1 to 2^31-1 (2,147,483,647) octets." - // A Go Read call on 64-bit machines could in theory read - // a larger Read than this. Very unlikely, but we handle it here - // rather than elsewhere for now. - const maxUint31 = 1<<31 - 1 - for n > maxUint31 { - sc.sendWindowUpdate32(st, maxUint31) - n -= maxUint31 - } - sc.sendWindowUpdate32(st, int32(n)) +func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { + sc.sendWindowUpdate(st, int(n)) } // st may be nil for conn-level -func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { +func (sc *serverConn) sendWindowUpdate(st *stream, n int) { sc.serveG.check() - if n == 0 { - return - } - if n < 0 { - panic("negative update") - } var streamID uint32 - if st != nil { + var send int32 + if st == nil { + send = sc.inflow.add(n) + } else { streamID = st.id + send = st.inflow.add(n) + } + if send == 0 { + return } sc.writeFrame(FrameWriteRequest{ - write: writeWindowUpdate{streamID: streamID, n: uint32(n)}, + write: writeWindowUpdate{streamID: streamID, n: uint32(send)}, stream: st, }) - var ok bool - if st == nil { - ok = sc.inflow.add(n) - } else { - ok = st.inflow.add(n) - } - if !ok { - panic("internal error; sent too many window updates without decrements?") - } } // requestBody is the Handler's Request.Body type. diff --git a/http2/server_test.go b/http2/server_test.go index a1e086c193..d32b2d85bd 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -482,6 +482,22 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p } } +// writeReadPing sends a PING and immediately reads the PING ACK. +// It will fail if any other unread data was pending on the connection. +func (st *serverTester) writeReadPing() { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := st.fr.WritePing(false, data); err != nil { + st.t.Fatalf("Error writing PING: %v", err) + } + p := st.wantPing() + if p.Flags&FlagPingAck == 0 { + st.t.Fatalf("got a PING, want a PING ACK") + } + if p.Data != data { + st.t.Fatalf("got PING data = %x, want %x", p.Data, data) + } +} + func (st *serverTester) readFrame() (Frame, error) { return st.fr.ReadFrame() } @@ -592,6 +608,28 @@ func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { } } +func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { + var initial int32 + if streamID == 0 { + initial = st.sc.srv.initialConnRecvWindowSize() + } else { + initial = st.sc.srv.initialStreamRecvWindowSize() + } + donec := make(chan struct{}) + st.sc.sendServeMsg(func(sc *serverConn) { + defer close(donec) + var avail int32 + if streamID == 0 { + avail = sc.inflow.avail + sc.inflow.unsent + } else { + } + if got, want := initial-avail, consumed; got != want { + st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want) + } + }) + <-donec +} + func (st *serverTester) wantSettingsAck() { f, err := st.readFrame() if err != nil { @@ -811,7 +849,8 @@ func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { st.writeData(1, true, []byte("12345")) // Return flow control bytes back, since the data handler closed // the stream. - st.wantWindowUpdate(0, 5) + st.wantRSTStream(1, ErrCodeProtocol) + st.wantFlowControlConsumed(0, 0) }) } @@ -1238,69 +1277,89 @@ func TestServer_RejectsLargeFrames(t *testing.T) { } func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { + // Need to set this to at least twice the initial window size, + // or st.greet gets stuck waiting for a WINDOW_UPDATE. + // + // This also needs to be less than MAX_FRAME_SIZE. + const windowSize = 65535 * 2 puppet := newHandlerPuppet() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { puppet.act(w, r) + }, func(s *Server) { + s.MaxUploadBufferPerConnection = windowSize + s.MaxUploadBufferPerStream = windowSize }) defer st.Close() defer puppet.done() st.greet() - st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // data coming EndHeaders: true, }) - st.writeData(1, false, []byte("abcdef")) - puppet.do(readBodyHandler(t, "abc")) - st.wantWindowUpdate(0, 3) - st.wantWindowUpdate(1, 3) - - puppet.do(readBodyHandler(t, "def")) - st.wantWindowUpdate(0, 3) - st.wantWindowUpdate(1, 3) - - st.writeData(1, true, []byte("ghijkl")) // END_STREAM here - puppet.do(readBodyHandler(t, "ghi")) - puppet.do(readBodyHandler(t, "jkl")) - st.wantWindowUpdate(0, 3) - st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM + st.writeReadPing() + + // Write less than half the max window of data and consume it. + // The server doesn't return flow control yet, buffering the 1024 bytes to + // combine with a future update. + data := make([]byte, windowSize) + st.writeData(1, false, data[:1024]) + puppet.do(readBodyHandler(t, string(data[:1024]))) + st.writeReadPing() + + // Write up to the window limit. + // The server returns the buffered credit. + st.writeData(1, false, data[1024:]) + st.wantWindowUpdate(0, 1024) + st.wantWindowUpdate(1, 1024) + st.writeReadPing() + + // The handler consumes the data and the server returns credit. + puppet.do(readBodyHandler(t, string(data[1024:]))) + st.wantWindowUpdate(0, windowSize-1024) + st.wantWindowUpdate(1, windowSize-1024) + st.writeReadPing() } // the version of the TestServer_Handler_Sends_WindowUpdate with padding. // See golang.org/issue/16556 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { + const windowSize = 65535 * 2 puppet := newHandlerPuppet() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { puppet.act(w, r) + }, func(s *Server) { + s.MaxUploadBufferPerConnection = windowSize + s.MaxUploadBufferPerStream = windowSize }) defer st.Close() defer puppet.done() st.greet() - st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, EndHeaders: true, }) - st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0}) + st.writeReadPing() - // Expect to immediately get our 5 bytes of padding back for - // both the connection and stream (4 bytes of padding + 1 byte of length) - st.wantWindowUpdate(0, 5) - st.wantWindowUpdate(1, 5) + // Write half a window of data, with some padding. + // The server doesn't return the padding yet, buffering the 5 bytes to combine + // with a future update. + data := make([]byte, windowSize/2) + pad := make([]byte, 4) + st.writeDataPadded(1, false, data, pad) + st.writeReadPing() - puppet.do(readBodyHandler(t, "abc")) - st.wantWindowUpdate(0, 3) - st.wantWindowUpdate(1, 3) - - puppet.do(readBodyHandler(t, "def")) - st.wantWindowUpdate(0, 3) - st.wantWindowUpdate(1, 3) + // The handler consumes the body. + // The server returns flow control for the body and padding + // (4 bytes of padding + 1 byte of length). + puppet.do(readBodyHandler(t, string(data))) + st.wantWindowUpdate(0, uint32(len(data)+1+len(pad))) + st.wantWindowUpdate(1, uint32(len(data)+1+len(pad))) } func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { @@ -1645,7 +1704,7 @@ func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { }) } -// No PRIORTY frame with a self-dependence. +// No PRIORITY frame with a self-dependence. func TestServer_Rejects_PrioritySelfDependence(t *testing.T) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true @@ -2273,7 +2332,7 @@ func TestServer_Response_Automatic100Continue(t *testing.T) { }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers - BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"), + BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-Continue"), EndStream: false, EndHeaders: true, }) @@ -2296,8 +2355,6 @@ func TestServer_Response_Automatic100Continue(t *testing.T) { // gigantic and/or sensitive "foo" payload now. st.writeData(1, true, []byte(msg)) - st.wantWindowUpdate(0, uint32(len(msg))) - hf = st.wantHeaders() if hf.StreamEnded() { t.Fatal("expected data to follow") @@ -2485,15 +2542,16 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { // it did before. st.writeData(1, true, []byte("foo")) - // Get our flow control bytes back, since the handler didn't get them. - st.wantWindowUpdate(0, uint32(len("foo"))) - // Sent after a peer sends data anyway (admittedly the // previous RST_STREAM might've still been in-flight), // but they'll get the more friendly 'cancel' code // first. st.wantRSTStream(1, ErrCodeStreamClosed) + // We should have our flow control bytes back, + // since the handler didn't get them. + st.wantFlowControlConsumed(0, 0) + // Set up a bunch of machinery to record the panic we saw // previously. var ( @@ -3967,8 +4025,8 @@ func TestServer_Rejects_TooSmall(t *testing.T) { EndHeaders: true, }) st.writeData(1, true, []byte("12345")) - st.wantWindowUpdate(0, 5) st.wantRSTStream(1, ErrCodeProtocol) + st.wantFlowControlConsumed(0, 0) }) } @@ -4258,7 +4316,8 @@ func TestContentEncodingNoSniffing(t *testing.T) { } func TestServerWindowUpdateOnBodyClose(t *testing.T) { - const content = "12345678" + const windowSize = 65535 * 2 + content := make([]byte, windowSize) blockCh := make(chan bool) errc := make(chan error, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -4275,6 +4334,9 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { blockCh <- true <-blockCh errc <- nil + }, func(s *Server) { + s.MaxUploadBufferPerConnection = windowSize + s.MaxUploadBufferPerStream = windowSize }) defer st.Close() @@ -4288,13 +4350,13 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { EndStream: false, // to say DATA frames are coming EndHeaders: true, }) - st.writeData(1, false, []byte(content[:5])) + st.writeData(1, false, content[:windowSize/2]) <-blockCh st.stream(1).body.CloseWithError(io.EOF) - st.writeData(1, false, []byte(content[5:])) blockCh <- true - increments := len(content) + // Wait for flow control credit for the portion of the request written so far. + increments := windowSize / 2 for { f, err := st.readFrame() if err == io.EOF { @@ -4311,6 +4373,10 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { } } + // Writing data after the stream is reset immediately returns flow control credit. + st.writeData(1, false, content[windowSize/2:]) + st.wantWindowUpdate(0, windowSize/2) + if err := <-errc; err != nil { t.Error(err) } @@ -4465,11 +4531,7 @@ func TestProtocolErrorAfterGoAway(t *testing.T) { EndHeaders: true, }) st.writeData(1, false, []byte(content[:5])) - - _, err := st.readFrame() - if err != nil { - st.t.Fatal(err) - } + st.writeReadPing() // Send a GOAWAY with ErrCodeNo, followed by a bogus window update. // The server should close the connection. @@ -4547,10 +4609,11 @@ func TestServerInitialFlowControlWindow(t *testing.T) { // TestCanonicalHeaderCacheGrowth verifies that the canonical header cache // size is capped to a reasonable level. func TestCanonicalHeaderCacheGrowth(t *testing.T) { - defer disableGoroutineTracking()() for _, size := range []int{1, (1 << 20) - 10} { base := strings.Repeat("X", size) - sc := &serverConn{} + sc := &serverConn{ + serveG: newGoroutineLock(), + } const count = 1000 for i := 0; i < count; i++ { h := fmt.Sprintf("%v-%v", base, i) @@ -4568,3 +4631,78 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { } } } + +// TestServerWriteDoesNotRetainBufferAfterStreamClose checks for access to +// the slice passed to ResponseWriter.Write after Write returns. +// +// Terminating the request stream on the client causes Write to return. +// We should not access the slice after this point. +func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { + donec := make(chan struct{}) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(donec) + buf := make([]byte, 1<<20) + var i byte + for { + i++ + _, err := w.Write(buf) + for j := range buf { + buf[j] = byte(i) // trigger race detector + } + if err != nil { + return + } + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + <-donec +} + +// TestServerWriteDoesNotRetainBufferAfterServerClose checks for access to +// the slice passed to ResponseWriter.Write after Write returns. +// +// Shutting down the Server causes Write to return. +// We should not access the slice after this point. +func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { + donec := make(chan struct{}, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + donec <- struct{}{} + defer close(donec) + buf := make([]byte, 1<<20) + var i byte + for { + i++ + _, err := w.Write(buf) + for j := range buf { + buf[j] = byte(i) + } + if err != nil { + return + } + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + <-donec + st.ts.Config.Close() + <-donec +} diff --git a/http2/transport.go b/http2/transport.go index 30f706e6cb..05ba23d3d9 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -47,10 +47,6 @@ const ( // we buffer per stream. transportDefaultStreamFlow = 4 << 20 - // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send - // a stream-level WINDOW_UPDATE for at a time. - transportDefaultStreamMinRefresh = 4 << 10 - defaultUserAgent = "Go-http-client/2.0" // initialMaxConcurrentStreams is a connections maxConcurrentStreams until @@ -310,8 +306,8 @@ type ClientConn struct { mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow flow // our conn-level flow control quota (cs.flow is per stream) - inflow flow // peer's conn-level flow control + flow outflow // our conn-level flow control quota (cs.outflow is per stream) + inflow inflow // peer's conn-level flow control doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool @@ -376,10 +372,10 @@ type clientStream struct { respHeaderRecv chan struct{} // closed when headers are received res *http.Response // set if respHeaderRecv is closed - flow flow // guarded by cc.mu - inflow flow // guarded by cc.mu - bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read - readErr error // sticky read error; owned by transportResponseBody.Read + flow outflow // guarded by cc.mu + inflow inflow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser reqBodyContentLength int64 // -1 means unknown @@ -811,7 +807,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow) - cc.inflow.add(transportDefaultConnFlow + initialWindowSize) + cc.inflow.init(transportDefaultConnFlow + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() @@ -1573,7 +1569,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) { close(cs.donec) } -// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. +// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { for { @@ -2073,8 +2069,7 @@ type resAndError struct { func (cc *ClientConn) addStreamLocked(cs *clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) - cs.inflow.add(transportDefaultStreamFlow) - cs.inflow.setConnFlow(&cc.inflow) + cs.inflow.init(transportDefaultStreamFlow) cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs @@ -2533,21 +2528,10 @@ func (b transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - var connAdd, streamAdd int32 - // Check the conn-level first, before the stream-level. - if v := cc.inflow.available(); v < transportDefaultConnFlow/2 { - connAdd = transportDefaultConnFlow - v - cc.inflow.add(connAdd) - } + connAdd := cc.inflow.add(n) + var streamAdd int32 if err == nil { // No need to refresh if the stream is over or failed. - // Consider any buffered body data (read from the conn but not - // consumed by the client) when computing flow control for this - // stream. - v := int(cs.inflow.available()) + cs.bufPipe.Len() - if v < transportDefaultStreamFlow-transportDefaultStreamMinRefresh { - streamAdd = int32(transportDefaultStreamFlow - v) - cs.inflow.add(streamAdd) - } + streamAdd = cs.inflow.add(n) } cc.mu.Unlock() @@ -2575,17 +2559,15 @@ func (b transportResponseBody) Close() error { if unread > 0 { cc.mu.Lock() // Return connection-level flow control. - if unread > 0 { - cc.inflow.add(int32(unread)) - } + connAdd := cc.inflow.add(unread) cc.mu.Unlock() // TODO(dneil): Acquiring this mutex can block indefinitely. // Move flow control return to a goroutine? cc.wmu.Lock() // Return connection-level flow control. - if unread > 0 { - cc.fr.WriteWindowUpdate(0, uint32(unread)) + if connAdd > 0 { + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) } cc.bw.Flush() cc.wmu.Unlock() @@ -2628,13 +2610,18 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { // But at least return their flow control: if f.Length > 0 { cc.mu.Lock() - cc.inflow.add(int32(f.Length)) + ok := cc.inflow.take(f.Length) + connAdd := cc.inflow.add(int(f.Length)) cc.mu.Unlock() - - cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(f.Length)) - cc.bw.Flush() - cc.wmu.Unlock() + if !ok { + return ConnectionError(ErrCodeFlowControl) + } + if connAdd > 0 { + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) + cc.bw.Flush() + cc.wmu.Unlock() + } } return nil } @@ -2665,9 +2652,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } // Check connection-level flow control. cc.mu.Lock() - if cs.inflow.available() >= int32(f.Length) { - cs.inflow.take(int32(f.Length)) - } else { + if !takeInflows(&cc.inflow, &cs.inflow, f.Length) { cc.mu.Unlock() return ConnectionError(ErrCodeFlowControl) } @@ -2689,19 +2674,20 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } } - if refund > 0 { - cc.inflow.add(int32(refund)) - if !didReset { - cs.inflow.add(int32(refund)) - } + sendConn := cc.inflow.add(refund) + var sendStream int32 + if !didReset { + sendStream = cs.inflow.add(refund) } cc.mu.Unlock() - if refund > 0 { + if sendConn > 0 || sendStream > 0 { cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(refund)) - if !didReset { - cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + if sendConn > 0 { + cc.fr.WriteWindowUpdate(0, uint32(sendConn)) + } + if sendStream > 0 { + cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream)) } cc.bw.Flush() cc.wmu.Unlock() diff --git a/http2/transport_test.go b/http2/transport_test.go index 00776adfdb..5adef42922 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -835,6 +835,55 @@ func (ct *clientTester) readNonSettingsFrame() (Frame, error) { } } +// writeReadPing sends a PING and immediately reads the PING ACK. +// It will fail if any other unread data was pending on the connection, +// aside from SETTINGS frames. +func (ct *clientTester) writeReadPing() error { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := ct.fr.WritePing(false, data); err != nil { + return fmt.Errorf("Error writing PING: %v", err) + } + f, err := ct.readNonSettingsFrame() + if err != nil { + return err + } + p, ok := f.(*PingFrame) + if !ok { + return fmt.Errorf("got a %v, want a PING ACK", f) + } + if p.Flags&FlagPingAck == 0 { + return fmt.Errorf("got a PING, want a PING ACK") + } + if p.Data != data { + return fmt.Errorf("got PING data = %x, want %x", p.Data, data) + } + return nil +} + +func (ct *clientTester) inflowWindow(streamID uint32) int32 { + pool := ct.tr.connPoolOrDef.(*clientConnPool) + pool.mu.Lock() + defer pool.mu.Unlock() + if n := len(pool.keys); n != 1 { + ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) + return -1 + } + for cc := range pool.keys { + cc.mu.Lock() + defer cc.mu.Unlock() + if streamID == 0 { + return cc.inflow.avail + cc.inflow.unsent + } + cs := cc.streams[streamID] + if cs == nil { + ct.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent + } + return -1 +} + func (ct *clientTester) cleanup() { ct.tr.CloseIdleConnections() @@ -2905,22 +2954,17 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { ct := newClientTester(t) - clientClosed := make(chan struct{}) - serverWroteFirstByte := make(chan struct{}) - ct.client = func() error { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := ct.tr.RoundTrip(req) if err != nil { return err } - <-serverWroteFirstByte if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) } res.Body.Close() // leaving 4999 bytes unread - close(clientClosed) return nil } @@ -2955,6 +2999,7 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { EndStream: false, BlockFragment: buf.Bytes(), }) + initialInflow := ct.inflowWindow(0) // Two cases: // - Send one DATA frame with 5000 bytes. @@ -2963,50 +3008,63 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { // In both cases, the client should consume one byte of data, // refund that byte, then refund the following 4999 bytes. // - // In the second case, the server waits for the client connection to - // close before seconding the second DATA frame. This tests the case + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case // where the client receives a DATA frame after it has reset the stream. if oneDataFrame { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - close(serverWroteFirstByte) - <-clientClosed } else { ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - close(serverWroteFirstByte) - <-clientClosed - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) } - waitingFor := "RSTStreamFrame" - sawRST := false - sawWUF := false - for !sawRST && !sawWUF { - f, err := ct.fr.ReadFrame() + wantRST := true + wantWUF := true + if !oneDataFrame { + wantWUF = false // flow control update is small, and will not be sent + } + for wantRST || wantWUF { + f, err := ct.readNonSettingsFrame() if err != nil { - return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) + return err } switch f := f.(type) { - case *SettingsFrame: case *RSTStreamFrame: - if sawRST { - return fmt.Errorf("saw second RSTStreamFrame: %v", summarizeFrame(f)) + if !wantRST { + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } if f.ErrCode != ErrCodeCancel { return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - sawRST = true + wantRST = false case *WindowUpdateFrame: - if sawWUF { - return fmt.Errorf("saw second WindowUpdateFrame: %v", summarizeFrame(f)) + if !wantWUF { + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } - if f.Increment != 4999 { + if f.Increment != 5000 { return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - sawWUF = true + wantWUF = false default: return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } } + if !oneDataFrame { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) + f, err := ct.readNonSettingsFrame() + if err != nil { + return err + } + wuf, ok := f.(*WindowUpdateFrame) + if !ok || wuf.Increment != 5000 { + return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + } + } + if err := ct.writeReadPing(); err != nil { + return err + } + if got, want := ct.inflowWindow(0), initialInflow; got != want { + return fmt.Errorf("connection flow tokens = %v, want %v", got, want) + } return nil } ct.run() @@ -3133,6 +3191,8 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { break } + initialConnWindow := ct.inflowWindow(0) + var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) @@ -3143,24 +3203,18 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { EndStream: false, BlockFragment: buf.Bytes(), }) + initialStreamWindow := ct.inflowWindow(hf.StreamID) pad := make([]byte, 5) ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - - f, err := ct.readNonSettingsFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) - } - wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { - return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + if err := ct.writeReadPing(); err != nil { + return err } - - f, err = ct.readNonSettingsFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) + // Padding flow control should have been returned. + if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) } - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { - return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } unblockClient <- true return nil diff --git a/icmp/multipart.go b/icmp/multipart.go index 5f36675594..c7b72bf3dd 100644 --- a/icmp/multipart.go +++ b/icmp/multipart.go @@ -33,7 +33,7 @@ func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts [ } // multipartMessageOrigDatagramLen takes b as an original datagram, -// and returns a required length for a padded orignal datagram in wire +// and returns a required length for a padded original datagram in wire // format. func multipartMessageOrigDatagramLen(proto int, b []byte) int { roundup := func(b []byte, align int) int { diff --git a/internal/sockstest/server.go b/internal/sockstest/server.go index dc2fa67c5e..c25a82f12a 100644 --- a/internal/sockstest/server.go +++ b/internal/sockstest/server.go @@ -46,7 +46,7 @@ func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) { return []byte{byte(ver), byte(m)}, nil } -// A CmdRequest repesents a command request. +// A CmdRequest represents a command request. type CmdRequest struct { Version int Cmd socks.Command @@ -120,12 +120,12 @@ func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) return b, nil } -// A Server repesents a server for handshake testing. +// A Server represents a server for handshake testing. type Server struct { ln net.Listener } -// Addr rerurns a server address. +// Addr returns a server address. func (s *Server) Addr() net.Addr { return s.ln.Addr() } diff --git a/ipv4/multicastlistener_test.go b/ipv4/multicastlistener_test.go index 534ded6793..77bad6676c 100644 --- a/ipv4/multicastlistener_test.go +++ b/ipv4/multicastlistener_test.go @@ -142,7 +142,7 @@ func TestUDPPerInterfaceSinglePacketConnWithSingleGroupListener(t *testing.T) { } c, err := net.ListenPacket("udp4", net.JoinHostPort(ip.String(), port)) // unicast address with non-reusable port if err != nil { - // The listen may fail when the serivce is + // The listen may fail when the service is // already in use, but it's fine because the // purpose of this is not to test the // bookkeeping of IP control block inside the diff --git a/ipv6/bpf_test.go b/ipv6/bpf_test.go index e249e1c923..c43ddd02ec 100644 --- a/ipv6/bpf_test.go +++ b/ipv6/bpf_test.go @@ -19,8 +19,8 @@ func TestBPF(t *testing.T) { if runtime.GOOS != "linux" { t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } l, err := net.ListenPacket("udp6", "[::1]:0") diff --git a/ipv6/dgramopt.go b/ipv6/dgramopt.go index 1f422e71dc..846f0e1f9c 100644 --- a/ipv6/dgramopt.go +++ b/ipv6/dgramopt.go @@ -245,7 +245,7 @@ func (c *dgramOpt) Checksum() (on bool, offset int, err error) { return true, offset, nil } -// SetChecksum enables the kernel checksum processing. If on is ture, +// SetChecksum enables the kernel checksum processing. If on is true, // the offset should be an offset in bytes into the data of where the // checksum field is located. func (c *dgramOpt) SetChecksum(on bool, offset int) error { diff --git a/ipv6/multicastlistener_test.go b/ipv6/multicastlistener_test.go index 353327e017..a4dc86342e 100644 --- a/ipv6/multicastlistener_test.go +++ b/ipv6/multicastlistener_test.go @@ -142,7 +142,7 @@ func TestUDPPerInterfaceSinglePacketConnWithSingleGroupListener(t *testing.T) { } c, err := net.ListenPacket("udp6", net.JoinHostPort(ip.String()+"%"+ifi.Name, port)) // unicast address with non-reusable port if err != nil { - // The listen may fail when the serivce is + // The listen may fail when the service is // already in use, but it's fine because the // purpose of this is not to test the // bookkeeping of IP control block inside the diff --git a/ipv6/readwrite_test.go b/ipv6/readwrite_test.go index e8db1199e1..131b1904c5 100644 --- a/ipv6/readwrite_test.go +++ b/ipv6/readwrite_test.go @@ -223,10 +223,10 @@ func TestPacketConnConcurrentReadWriteUnicastUDP(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + ifi, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback) + if err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } - c, err := nettest.NewLocalPacketListener("udp6") if err != nil { t.Fatal(err) @@ -236,7 +236,6 @@ func TestPacketConnConcurrentReadWriteUnicastUDP(t *testing.T) { defer p.Close() dst := c.LocalAddr() - ifi, _ := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback) cf := ipv6.FlagTrafficClass | ipv6.FlagHopLimit | ipv6.FlagSrc | ipv6.FlagDst | ipv6.FlagInterface | ipv6.FlagPathMTU wb := []byte("HELLO-R-U-THERE") diff --git a/ipv6/sockopt_test.go b/ipv6/sockopt_test.go index 3305cfc114..ab0d2e4e51 100644 --- a/ipv6/sockopt_test.go +++ b/ipv6/sockopt_test.go @@ -20,8 +20,9 @@ func TestConnInitiatorPathMTU(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows", "zos": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } ln, err := net.Listen("tcp6", "[::1]:0") @@ -53,8 +54,8 @@ func TestConnResponderPathMTU(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows", "zos": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } ln, err := net.Listen("tcp6", "[::1]:0") diff --git a/ipv6/unicast_test.go b/ipv6/unicast_test.go index fe1d44dfa7..e03c2cd336 100644 --- a/ipv6/unicast_test.go +++ b/ipv6/unicast_test.go @@ -23,8 +23,8 @@ func TestPacketConnReadWriteUnicastUDP(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } c, err := nettest.NewLocalPacketListener("udp6") diff --git a/ipv6/unicastsockopt_test.go b/ipv6/unicastsockopt_test.go index ac0daf2856..c3abe2d14d 100644 --- a/ipv6/unicastsockopt_test.go +++ b/ipv6/unicastsockopt_test.go @@ -19,8 +19,8 @@ func TestConnUnicastSocketOptions(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } ln, err := net.Listen("tcp6", "[::1]:0") @@ -64,8 +64,8 @@ func TestPacketConnUnicastSocketOptions(t *testing.T) { case "fuchsia", "hurd", "js", "nacl", "plan9", "windows": t.Skipf("not supported on %s", runtime.GOOS) } - if !nettest.SupportsIPv6() { - t.Skip("ipv6 is not supported") + if _, err := nettest.RoutedInterface("ip6", net.FlagUp|net.FlagLoopback); err != nil { + t.Skip("ipv6 is not enabled for loopback interface") } ok := nettest.SupportsRawSocket() diff --git a/nettest/nettest.go b/nettest/nettest.go index 6918f2c362..510555ac28 100644 --- a/nettest/nettest.go +++ b/nettest/nettest.go @@ -20,11 +20,13 @@ import ( ) var ( - stackOnce sync.Once - ipv4Enabled bool - ipv6Enabled bool - unStrmDgramEnabled bool - rawSocketSess bool + stackOnce sync.Once + ipv4Enabled bool + canListenTCP4OnLoopback bool + ipv6Enabled bool + canListenTCP6OnLoopback bool + unStrmDgramEnabled bool + rawSocketSess bool aLongTimeAgo = time.Unix(233431200, 0) neverTimeout = time.Time{} @@ -34,13 +36,19 @@ var ( ) func probeStack() { + if _, err := RoutedInterface("ip4", net.FlagUp); err == nil { + ipv4Enabled = true + } if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil { ln.Close() - ipv4Enabled = true + canListenTCP4OnLoopback = true + } + if _, err := RoutedInterface("ip6", net.FlagUp); err == nil { + ipv6Enabled = true } if ln, err := net.Listen("tcp6", "[::1]:0"); err == nil { ln.Close() - ipv6Enabled = true + canListenTCP6OnLoopback = true } rawSocketSess = supportsRawSocket() switch runtime.GOOS { @@ -154,22 +162,23 @@ func TestableAddress(network, address string) bool { // The provided network must be "tcp", "tcp4", "tcp6", "unix" or // "unixpacket". func NewLocalListener(network string) (net.Listener, error) { + stackOnce.Do(probeStack) switch network { case "tcp": - if SupportsIPv4() { + if canListenTCP4OnLoopback { if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil { return ln, nil } } - if SupportsIPv6() { + if canListenTCP6OnLoopback { return net.Listen("tcp6", "[::1]:0") } case "tcp4": - if SupportsIPv4() { + if canListenTCP4OnLoopback { return net.Listen("tcp4", "127.0.0.1:0") } case "tcp6": - if SupportsIPv6() { + if canListenTCP6OnLoopback { return net.Listen("tcp6", "[::1]:0") } case "unix", "unixpacket": @@ -187,22 +196,23 @@ func NewLocalListener(network string) (net.Listener, error) { // // The provided network must be "udp", "udp4", "udp6" or "unixgram". func NewLocalPacketListener(network string) (net.PacketConn, error) { + stackOnce.Do(probeStack) switch network { case "udp": - if SupportsIPv4() { + if canListenTCP4OnLoopback { if c, err := net.ListenPacket("udp4", "127.0.0.1:0"); err == nil { return c, nil } } - if SupportsIPv6() { + if canListenTCP6OnLoopback { return net.ListenPacket("udp6", "[::1]:0") } case "udp4": - if SupportsIPv4() { + if canListenTCP4OnLoopback { return net.ListenPacket("udp4", "127.0.0.1:0") } case "udp6": - if SupportsIPv6() { + if canListenTCP6OnLoopback { return net.ListenPacket("udp6", "[::1]:0") } case "unixgram": diff --git a/netutil/listen.go b/netutil/listen.go index d5dfbab24f..f8b779ea27 100644 --- a/netutil/listen.go +++ b/netutil/listen.go @@ -29,7 +29,7 @@ type limitListener struct { } // acquire acquires the limiting semaphore. Returns true if successfully -// accquired, false if the listener is closed and the semaphore is not +// acquired, false if the listener is closed and the semaphore is not // acquired. func (l *limitListener) acquire() bool { select { diff --git a/trace/histogram.go b/trace/histogram.go index 9bf4286c79..d6c71101e4 100644 --- a/trace/histogram.go +++ b/trace/histogram.go @@ -32,7 +32,7 @@ type histogram struct { valueCount int64 // number of values recorded for single value } -// AddMeasurement records a value measurement observation to the histogram. +// addMeasurement records a value measurement observation to the histogram. func (h *histogram) addMeasurement(value int64) { // TODO: assert invariant h.sum += value diff --git a/webdav/webdav.go b/webdav/webdav.go index 8d0f1b2aed..add2bcd67c 100644 --- a/webdav/webdav.go +++ b/webdav/webdav.go @@ -655,7 +655,7 @@ func handlePropfindError(err error, info os.FileInfo) error { // We need to be careful with other errors: there is no way to abort the xml stream // part way through while returning a valid PROPFIND response. Returning only half // the data would be misleading, but so would be returning results tainted by errors. - // The curent behaviour by returning an error here leads to the stream being aborted, + // The current behaviour by returning an error here leads to the stream being aborted, // and the parent http server complaining about writing a spurious header. We should // consider further enhancing this error handling to more gracefully fail, or perhaps // buffer the entire response until we've walked the tree. diff --git a/websocket/hybi.go b/websocket/hybi.go index 8cffdd16c9..48a069e190 100644 --- a/websocket/hybi.go +++ b/websocket/hybi.go @@ -369,7 +369,7 @@ func generateNonce() (nonce []byte) { return } -// removeZone removes IPv6 zone identifer from host. +// removeZone removes IPv6 zone identifier from host. // E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" func removeZone(host string) string { if !strings.HasPrefix(host, "[") { diff --git a/websocket/websocket.go b/websocket/websocket.go index ea422e110d..90a2257cd5 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -5,11 +5,10 @@ // Package websocket implements a client and server for the WebSocket protocol // as specified in RFC 6455. // -// This package currently lacks some features found in alternative -// and more actively maintained WebSocket packages: +// This package currently lacks some features found in an alternative +// and more actively maintained WebSocket package: // -// https://godoc.org/github.com/gorilla/websocket -// https://godoc.org/nhooyr.io/websocket +// https://pkg.go.dev/nhooyr.io/websocket package websocket // import "golang.org/x/net/websocket" import (