Skip to content

Commit b359fb9

Browse files
spikecurtis authored and stirby committed
fix: fix goroutine leak in log streaming over websocket (#15709)
Fixes #14881. Our handlers for streaming logs don't read from the websocket. We don't allow the client to send us any data, but the websocket library we use requires reading from the websocket to properly handle pings and closing. Not doing so [can cause the websocket to hang on write](coder/websocket#405), leaking goroutines, which were noticed in #14881. This fixes the issue, and in the process refactors our log streaming into an encoder/decoder package which provides generic types for sending JSON over websocket. I'd also like for us to upgrade to the latest https://github.com/coder/websocket, but we should also upgrade our tailscale fork before doing so to avoid including two copies of the websocket library. (cherry picked from commit 148a5a3)
1 parent 54f7605 commit b359fb9

File tree

6 files changed

+134
-78
lines changed

6 files changed

+134
-78
lines changed

coderd/provisionerjobs.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"cdr.dev/slog"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819

1920
"github.com/coder/coder/v2/coderd/database"
2021
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -312,6 +313,7 @@ type logFollower struct {
312313
r *http.Request
313314
rw http.ResponseWriter
314315
conn *websocket.Conn
316+
enc *wsjson.Encoder[codersdk.ProvisionerJobLog]
315317

316318
jobID uuid.UUID
317319
after int64
@@ -391,6 +393,7 @@ func (f *logFollower) follow() {
391393
}
392394
defer f.conn.Close(websocket.StatusNormalClosure, "done")
393395
go httpapi.Heartbeat(f.ctx, f.conn)
396+
f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText)
394397

395398
// query for logs once right away, so we can get historical data from before
396399
// subscription
@@ -488,11 +491,7 @@ func (f *logFollower) query() error {
488491
return xerrors.Errorf("error fetching logs: %w", err)
489492
}
490493
for _, log := range logs {
491-
logB, err := json.Marshal(convertProvisionerJobLog(log))
492-
if err != nil {
493-
return xerrors.Errorf("error marshaling log: %w", err)
494-
}
495-
err = f.conn.Write(f.ctx, websocket.MessageText, logB)
494+
err := f.enc.Encode(convertProvisionerJobLog(log))
496495
if err != nil {
497496
return xerrors.Errorf("error writing to websocket: %w", err)
498497
}

coderd/workspaceagents.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import (
3939
"github.com/coder/coder/v2/codersdk"
4040
"github.com/coder/coder/v2/codersdk/agentsdk"
4141
"github.com/coder/coder/v2/codersdk/workspacesdk"
42+
"github.com/coder/coder/v2/codersdk/wsjson"
4243
"github.com/coder/coder/v2/tailnet"
4344
"github.com/coder/coder/v2/tailnet/proto"
4445
)
@@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
396397
}
397398
go httpapi.Heartbeat(ctx, conn)
398399

399-
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
400-
defer wsNetConn.Close() // Also closes conn.
400+
encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText)
401+
defer encoder.Close(websocket.StatusNormalClosure)
401402

402-
// The Go stdlib JSON encoder appends a newline character after message write.
403-
encoder := json.NewEncoder(wsNetConn)
404403
err = encoder.Encode(convertWorkspaceAgentLogs(logs))
405404
if err != nil {
406405
return
@@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
740739
})
741740
return
742741
}
743-
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
744-
defer nconn.Close()
745-
746-
// Slurp all packets from the connection into io.Discard so pongs get sent
747-
// by the websocket package. We don't do any reads ourselves so this is
748-
// necessary.
749-
go func() {
750-
_, _ = io.Copy(io.Discard, nconn)
751-
_ = nconn.Close()
752-
}()
742+
encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary)
743+
defer encoder.Close(websocket.StatusGoingAway)
753744

754745
go func(ctx context.Context) {
755746
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
@@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
767758
err := ws.Ping(ctx)
768759
cancel()
769760
if err != nil {
770-
_ = nconn.Close()
761+
_ = ws.Close(websocket.StatusGoingAway, "ping failed")
771762
return
772763
}
773764
}
@@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
780771
for {
781772
derpMap := api.DERPMap()
782773
if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) {
783-
err := json.NewEncoder(nconn).Encode(derpMap)
774+
err := encoder.Encode(derpMap)
784775
if err != nil {
785-
_ = nconn.Close()
786776
return
787777
}
788778
lastDERPMap = derpMap

codersdk/provisionerdaemons.go

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/coder/coder/v2/buildinfo"
2121
"github.com/coder/coder/v2/codersdk/drpc"
22+
"github.com/coder/coder/v2/codersdk/wsjson"
2223
"github.com/coder/coder/v2/provisionerd/proto"
2324
"github.com/coder/coder/v2/provisionerd/runner"
2425
)
@@ -161,36 +162,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
161162
}
162163
return nil, nil, ReadBodyAsError(res)
163164
}
164-
logs := make(chan ProvisionerJobLog)
165-
closed := make(chan struct{})
166-
go func() {
167-
defer close(closed)
168-
defer close(logs)
169-
defer conn.Close(websocket.StatusGoingAway, "")
170-
var log ProvisionerJobLog
171-
for {
172-
msgType, msg, err := conn.Read(ctx)
173-
if err != nil {
174-
return
175-
}
176-
if msgType != websocket.MessageText {
177-
return
178-
}
179-
err = json.Unmarshal(msg, &log)
180-
if err != nil {
181-
return
182-
}
183-
select {
184-
case <-ctx.Done():
185-
return
186-
case logs <- log:
187-
}
188-
}
189-
}()
190-
return logs, closeFunc(func() error {
191-
<-closed
192-
return nil
193-
}), nil
165+
d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger)
166+
return d.Chan(), d, nil
194167
}
195168

196169
// ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with

codersdk/workspaceagents.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"github.com/coder/coder/v2/coderd/tracing"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819
)
1920

2021
type WorkspaceAgentStatus string
@@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
454455
}
455456
return nil, nil, ReadBodyAsError(res)
456457
}
457-
logChunks := make(chan []WorkspaceAgentLog, 1)
458-
closed := make(chan struct{})
459-
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
460-
decoder := json.NewDecoder(wsNetConn)
461-
go func() {
462-
defer close(closed)
463-
defer close(logChunks)
464-
defer conn.Close(websocket.StatusGoingAway, "")
465-
for {
466-
var logs []WorkspaceAgentLog
467-
err = decoder.Decode(&logs)
468-
if err != nil {
469-
return
470-
}
471-
select {
472-
case <-ctx.Done():
473-
return
474-
case logChunks <- logs:
475-
}
476-
}
477-
}()
478-
return logChunks, closeFunc(func() error {
479-
_ = wsNetConn.Close()
480-
<-closed
481-
return nil
482-
}), nil
458+
d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger)
459+
return d.Chan(), d, nil
483460
}

codersdk/wsjson/decoder.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"sync/atomic"
7+
8+
"nhooyr.io/websocket"
9+
10+
"cdr.dev/slog"
11+
)
12+
13+
type Decoder[T any] struct {
14+
conn *websocket.Conn
15+
typ websocket.MessageType
16+
ctx context.Context
17+
cancel context.CancelFunc
18+
chanCalled atomic.Bool
19+
logger slog.Logger
20+
}
21+
22+
// Chan starts the decoder reading from the websocket and returns a channel for reading the
23+
// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an
24+
// error. We also close the underlying websocket if we encounter an error reading or decoding.
25+
func (d *Decoder[T]) Chan() <-chan T {
26+
if !d.chanCalled.CompareAndSwap(false, true) {
27+
panic("chan called more than once")
28+
}
29+
values := make(chan T, 1)
30+
go func() {
31+
defer close(values)
32+
defer d.conn.Close(websocket.StatusGoingAway, "")
33+
for {
34+
// we don't use d.ctx here because it only gets canceled after closing the connection
35+
// and a "connection closed" type error is more clear than context canceled.
36+
typ, b, err := d.conn.Read(context.Background())
37+
if err != nil {
38+
// might be benign like EOF, so just log at debug
39+
d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err))
40+
return
41+
}
42+
if typ != d.typ {
43+
d.logger.Error(d.ctx, "websocket type mismatch while decoding")
44+
return
45+
}
46+
var value T
47+
err = json.Unmarshal(b, &value)
48+
if err != nil {
49+
d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err))
50+
return
51+
}
52+
select {
53+
case values <- value:
54+
// OK
55+
case <-d.ctx.Done():
56+
return
57+
}
58+
}
59+
}()
60+
return values
61+
}
62+
63+
// nolint: revive // complains that Encoder has the same function name
64+
func (d *Decoder[T]) Close() error {
65+
err := d.conn.Close(websocket.StatusNormalClosure, "")
66+
d.cancel()
67+
return err
68+
}
69+
70+
// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from
71+
// JSON.
72+
func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] {
73+
ctx, cancel := context.WithCancel(context.Background())
74+
return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger}
75+
}

codersdk/wsjson/encoder.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"golang.org/x/xerrors"
8+
"nhooyr.io/websocket"
9+
)
10+
11+
type Encoder[T any] struct {
12+
conn *websocket.Conn
13+
typ websocket.MessageType
14+
}
15+
16+
func (e *Encoder[T]) Encode(v T) error {
17+
w, err := e.conn.Writer(context.Background(), e.typ)
18+
if err != nil {
19+
return xerrors.Errorf("get websocket writer: %w", err)
20+
}
21+
defer w.Close()
22+
j := json.NewEncoder(w)
23+
err = j.Encode(v)
24+
if err != nil {
25+
return xerrors.Errorf("encode json: %w", err)
26+
}
27+
return nil
28+
}
29+
30+
func (e *Encoder[T]) Close(c websocket.StatusCode) error {
31+
return e.conn.Close(c, "")
32+
}
33+
34+
// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable.
35+
// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the
36+
// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects.
37+
func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] {
38+
// Here we close the websocket for reading, so that the websocket library will handle pings and
39+
// close frames.
40+
_ = conn.CloseRead(context.Background())
41+
return &Encoder[T]{conn: conn, typ: typ}
42+
}

0 commit comments

Comments
 (0)