Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 11bda98

Browse files
FrauElstermafredri
andauthored
fix: avoid writing messages after close and improve handshake (#476)
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
1 parent 1253b77 commit 11bda98

File tree

5 files changed

+252
-65
lines changed

5 files changed

+252
-65
lines changed

‎close.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
100100
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101101
defer errd.Wrap(&err, "failed to close WebSocket")
102102

103-
if !c.casClosing() {
103+
if c.casClosing() {
104104
err = c.waitGoroutines()
105105
if err != nil {
106106
return err
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
133133
func (c *Conn) CloseNow() (err error) {
134134
defer errd.Wrap(&err, "failed to immediately close WebSocket")
135135

136-
if !c.casClosing() {
136+
if c.casClosing() {
137137
err = c.waitGoroutines()
138138
if err != nil {
139139
return err
@@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
329329
}
330330

331331
func (c *Conn) casClosing() bool {
332-
c.closeMu.Lock()
333-
defer c.closeMu.Unlock()
334-
if !c.closing {
335-
c.closing = true
336-
return true
337-
}
338-
return false
332+
return c.closing.Swap(true)
339333
}
340334

341335
func (c *Conn) isClosed() bool {

‎conn.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ type Conn struct {
6969
writeHeaderBuf [8]byte
7070
writeHeader header
7171

72+
// Close handshake state.
73+
closeStateMu sync.RWMutex
74+
closeReceivedErr error
75+
closeSentErr error
76+
77+
// CloseRead state.
7278
closeReadMu sync.Mutex
7379
closeReadCtx context.Context
7480
closeReadDone chan struct{}
7581

82+
closing atomic.Bool
83+
closeMu sync.Mutex // Protects following.
7684
closed chan struct{}
77-
closeMu sync.Mutex
78-
closing bool
7985

8086
pingCounter atomic.Int64
8187
activePingsMu sync.Mutex

‎conn_test.go

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/http/httptest"
1314
"os"
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
460461
}
461462

462463
func BenchmarkConn(b *testing.B) {
463-
varbenchCases = []struct {
464+
benchCases := []struct {
464465
name string
465466
mode websocket.CompressionMode
466467
}{
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
625626
}()
626627
}
627628
}
629+
630+
func TestConnClosePropagation(t *testing.T) {
631+
t.Parallel()
632+
633+
want := []byte("hello")
634+
keepWriting := func(c *websocket.Conn) <-chan error {
635+
return xsync.Go(func() error {
636+
for {
637+
err := c.Write(context.Background(), websocket.MessageText, want)
638+
if err != nil {
639+
return err
640+
}
641+
}
642+
})
643+
}
644+
keepReading := func(c *websocket.Conn) <-chan error {
645+
return xsync.Go(func() error {
646+
for {
647+
_, got, err := c.Read(context.Background())
648+
if err != nil {
649+
return err
650+
}
651+
if !bytes.Equal(want, got) {
652+
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
653+
}
654+
}
655+
})
656+
}
657+
checkReadErr := func(t *testing.T, err error) {
658+
// Check read error (output depends on when read is called in relation to connection closure).
659+
var ce websocket.CloseError
660+
if errors.As(err, &ce) {
661+
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
662+
} else {
663+
assert.ErrorIs(t, net.ErrClosed, err)
664+
}
665+
}
666+
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
667+
for _, c := range conn {
668+
// Check write error.
669+
err := c.Write(context.Background(), websocket.MessageText, want)
670+
assert.ErrorIs(t, net.ErrClosed, err)
671+
672+
_, _, err = c.Read(context.Background())
673+
checkReadErr(t, err)
674+
}
675+
}
676+
677+
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
678+
tt, this, other := newConnTest(t, nil, nil)
679+
680+
_ = this.CloseRead(tt.ctx)
681+
thisWriteErr := keepWriting(this)
682+
683+
_, got, err := other.Read(tt.ctx)
684+
assert.Success(t, err)
685+
assert.Equal(t, "msg", want, got)
686+
687+
err = other.Close(websocket.StatusNormalClosure, "")
688+
assert.Success(t, err)
689+
690+
select {
691+
case err := <-thisWriteErr:
692+
assert.ErrorIs(t, net.ErrClosed, err)
693+
case <-tt.ctx.Done():
694+
t.Fatal(tt.ctx.Err())
695+
}
696+
697+
checkConnErrs(t, this, other)
698+
})
699+
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
700+
tt, this, other := newConnTest(t, nil, nil)
701+
702+
_ = this.CloseRead(tt.ctx)
703+
thisWriteErr := keepWriting(this)
704+
otherReadErr := keepReading(other)
705+
706+
err := this.Close(websocket.StatusNormalClosure, "")
707+
assert.Success(t, err)
708+
709+
select {
710+
case err := <-thisWriteErr:
711+
assert.ErrorIs(t, net.ErrClosed, err)
712+
case <-tt.ctx.Done():
713+
t.Fatal(tt.ctx.Err())
714+
}
715+
716+
select {
717+
case err := <-otherReadErr:
718+
checkReadErr(t, err)
719+
case <-tt.ctx.Done():
720+
t.Fatal(tt.ctx.Err())
721+
}
722+
723+
checkConnErrs(t, this, other)
724+
})
725+
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
726+
tt, this, other := newConnTest(t, nil, nil)
727+
728+
_ = other.CloseRead(tt.ctx)
729+
errs := keepReading(this)
730+
731+
err := other.Write(tt.ctx, websocket.MessageText, want)
732+
assert.Success(t, err)
733+
734+
err = other.Close(websocket.StatusNormalClosure, "")
735+
assert.Success(t, err)
736+
737+
select {
738+
case err := <-errs:
739+
checkReadErr(t, err)
740+
case <-tt.ctx.Done():
741+
t.Fatal(tt.ctx.Err())
742+
}
743+
744+
checkConnErrs(t, this, other)
745+
})
746+
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
747+
tt, this, other := newConnTest(t, nil, nil)
748+
749+
thisReadErr := keepReading(this)
750+
otherReadErr := keepReading(other)
751+
752+
err := other.Write(tt.ctx, websocket.MessageText, want)
753+
assert.Success(t, err)
754+
755+
err = this.Close(websocket.StatusNormalClosure, "")
756+
assert.Success(t, err)
757+
758+
select {
759+
case err := <-thisReadErr:
760+
checkReadErr(t, err)
761+
case <-tt.ctx.Done():
762+
t.Fatal(tt.ctx.Err())
763+
}
764+
765+
select {
766+
case err := <-otherReadErr:
767+
checkReadErr(t, err)
768+
case <-tt.ctx.Done():
769+
t.Fatal(tt.ctx.Err())
770+
}
771+
772+
checkConnErrs(t, this, other)
773+
})
774+
}

‎read.go

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
217217
}
218218
}
219219

220-
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
220+
// prepareRead sets the readTimeout context and returns a done function
221+
// to be called after the read is done. It also returns an error if the
222+
// connection is closed. The reference to the error is used to assign
223+
// an error depending on if the connection closed or the context timed
224+
// out during use. Typically the referenced error is a named return
225+
// variable of the function calling this method.
226+
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
221227
select {
222228
case <-c.closed:
223-
return header{}, net.ErrClosed
229+
return nil, net.ErrClosed
224230
case c.readTimeout <- ctx:
225231
}
226232

227-
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
228-
if err != nil {
233+
done := func() {
229234
select {
230235
case <-c.closed:
231-
return header{}, net.ErrClosed
232-
case <-ctx.Done():
233-
return header{}, ctx.Err()
234-
default:
235-
return header{}, err
236+
if *err != nil {
237+
*err = net.ErrClosed
238+
}
239+
case c.readTimeout <- context.Background():
240+
}
241+
if *err != nil && ctx.Err() != nil {
242+
*err = ctx.Err()
236243
}
237244
}
238245

239-
select {
240-
case <-c.closed:
241-
return header{}, net.ErrClosed
242-
case c.readTimeout <- context.Background():
246+
c.closeStateMu.Lock()
247+
closeReceivedErr := c.closeReceivedErr
248+
c.closeStateMu.Unlock()
249+
if closeReceivedErr != nil {
250+
defer done()
251+
return nil, closeReceivedErr
243252
}
244253

245-
return h, nil
254+
return done, nil
246255
}
247256

248-
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
249-
select {
250-
case <-c.closed:
251-
return 0, net.ErrClosed
252-
case c.readTimeout <- ctx:
257+
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
258+
readDone, err := c.prepareRead(ctx, &err)
259+
if err != nil {
260+
return header{}, err
253261
}
262+
defer readDone()
254263

255-
n, err := io.ReadFull(c.br, p)
264+
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
256265
if err != nil {
257-
select {
258-
case <-c.closed:
259-
return n, net.ErrClosed
260-
case <-ctx.Done():
261-
return n, ctx.Err()
262-
default:
263-
return n, fmt.Errorf("failed to read frame payload: %w", err)
264-
}
266+
return header{}, err
265267
}
266268

267-
select {
268-
case <-c.closed:
269-
return n, net.ErrClosed
270-
case c.readTimeout <- context.Background():
269+
return h, nil
270+
}
271+
272+
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
273+
readDone, err := c.prepareRead(ctx, &err)
274+
if err != nil {
275+
return 0, err
276+
}
277+
defer readDone()
278+
279+
n, err := io.ReadFull(c.br, p)
280+
if err != nil {
281+
return n, fmt.Errorf("failed to read frame payload: %w", err)
271282
}
272283

273284
return n, err
@@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
325336
}
326337

327338
err = fmt.Errorf("received close frame: %w", ce)
328-
c.writeClose(ce.Code, ce.Reason)
329-
c.readMu.unlock()
330-
c.close()
339+
c.closeStateMu.Lock()
340+
c.closeReceivedErr = err
341+
closeSent := c.closeSentErr != nil
342+
c.closeStateMu.Unlock()
343+
344+
// Only unlock readMu if this connection is being closed becaue
345+
// c.close will try to acquire the readMu lock. We unlock for
346+
// writeClose as well because it may also call c.close.
347+
if !closeSent {
348+
c.readMu.unlock()
349+
_ = c.writeClose(ce.Code, ce.Reason)
350+
}
351+
if !c.casClosing() {
352+
c.readMu.unlock()
353+
_ = c.close()
354+
}
331355
return err
332356
}
333357

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /