Skip to content

Commit 1200707

Browse files
committed
Ensure connection is closed at all error points
Closes #191
1 parent 43c4dc0 commit 1200707

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

read.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
304304
defer c.readMu.unlock()
305305

306306
if !c.msgReader.fin {
307-
return 0, nil, errors.New("previous message not read to completion")
307+
err = errors.New("previous message not read to completion")
308+
c.close(fmt.Errorf("failed to get reader: %w", err))
309+
return 0, nil, err
308310
}
309311

310312
h, err := c.readLoop(ctx)
@@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
361363
}
362364

363365
func (mr *msgReader) Read(p []byte) (n int, err error) {
364-
defer func() {
365-
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
366-
err = io.EOF
367-
}
368-
if errors.Is(err, io.EOF) {
369-
err = io.EOF
370-
mr.putFlateReader()
371-
return
372-
}
373-
errd.Wrap(&err, "failed to read")
374-
}()
375-
376366
err = mr.c.readMu.lock(mr.ctx)
377367
if err != nil {
378-
return 0, err
368+
return 0, fmt.Errorf("failed to read: %w", err)
379369
}
380370
defer mr.c.readMu.unlock()
381371

@@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
384374
p = p[:n]
385375
mr.dict.write(p)
386376
}
377+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
378+
mr.putFlateReader()
379+
return n, io.EOF
380+
}
381+
if err != nil {
382+
err = fmt.Errorf("failed to read: %w", err)
383+
mr.c.close(err)
384+
}
387385
return n, err
388386
}
389387

write.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,16 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155

156156
// Write writes the given bytes to the WebSocket connection.
157157
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
158-
defer errd.Wrap(&err, "failed to write")
159-
160158
mw.writeMu.Lock()
161159
defer mw.writeMu.Unlock()
162160

161+
defer func() {
162+
err = fmt.Errorf("failed to write: %w", err)
163+
if err != nil {
164+
mw.c.close(err)
165+
}
166+
}()
167+
163168
if mw.c.flate() {
164169
// Only enables flate if the length crosses the
165170
// threshold on the first frame
@@ -230,8 +235,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
230235
}
231236

232237
// frame handles all writes to the connection.
233-
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
234-
err := c.writeFrameMu.lock(ctx)
238+
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
239+
err = c.writeFrameMu.lock(ctx)
235240
if err != nil {
236241
return 0, err
237242
}
@@ -243,6 +248,12 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
243248
case c.writeTimeout <- ctx:
244249
}
245250

251+
defer func() {
252+
if err != nil {
253+
c.close(fmt.Errorf("failed to write frame: %w", err))
254+
}
255+
}()
256+
246257
c.writeHeader.fin = fin
247258
c.writeHeader.opcode = opcode
248259
c.writeHeader.payloadLength = int64(len(p))

0 commit comments

Comments
 (0)