From 7fabce4be2489efb092ce212a8a650b45d490612 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 7 Sep 2021 14:47:58 -0500 Subject: [PATCH] chore: fix dial test assertions --- wsnet/dial.go | 4 ++-- wsnet/dial_test.go | 33 ++++++++++++++++++--------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/wsnet/dial.go b/wsnet/dial.go index 550735b5..c0f850b4 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -290,8 +290,8 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { return <-errCh } -// ActiveConnections returns the amount of active connections. -// DialContext opens a connection, and close will end it. +// ActiveConnections returns the amount of active connections. DialContext +// opens a connection, and close will end it. func (d *Dialer) activeConnections() int { stats, ok := d.rtc.GetStats().GetConnectionStats(d.rtc) if !ok { diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index a9b09417..4aa6f195 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -11,11 +11,12 @@ import ( "testing" "time" - "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" ) func ExampleDial_basic() { @@ -260,33 +261,35 @@ func TestDial(t *testing.T) { log := slogtest.Make(t, nil) listener, err := net.Listen("tcp", "0.0.0.0:0") - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) + go func() { _, _ = listener.Accept() }() + connectAddr, listenAddr := createDumbBroker(t) _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ Log: &log, }, nil) - if err != nil { - t.Error(err) - } - conn, _ := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + require.NoError(t, err) + + conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + require.NoError(t, err) assert.Equal(t, 1, dialer.activeConnections()) + _ = conn.Close() assert.Equal(t, 0, dialer.activeConnections()) - _, _ = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) - conn, _ = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + + _, err = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + require.NoError(t, err) + + conn, err = dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) + require.NoError(t, err) assert.Equal(t, 2, dialer.activeConnections()) + _ = conn.Close() assert.Equal(t, 1, dialer.activeConnections()) })