Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 50933f6

Browse files
authored
fix: Loop Dial for reconnects (#337)
* fix: Loop Dial for reconnects * Use base64 for password credential * Fix unused comments * Disable cognit for func * Fix slow TURN dials
1 parent 45cae5a commit 50933f6

File tree

4 files changed

+82
-20
lines changed

4 files changed

+82
-20
lines changed

wsnet/auth.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package wsnet
22

33
import (
44
"crypto/sha256"
5+
"encoding/base64"
56
"errors"
67
"strings"
78
)
@@ -16,6 +17,6 @@ func TURNCredentials(token string) (username, password string, err error) {
1617
}
1718
username = str[0]
1819
hash := sha256.Sum256([]byte(str[1]))
19-
password = string(hash[:])
20+
password = base64.StdEncoding.EncodeToString(hash[:])
2021
return
2122
}

wsnet/dial_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ func TestDial(t *testing.T) {
4848
_, err := Listen(context.Background(), listenAddr)
4949
if err != nil {
5050
t.Error(err)
51+
return
5152
}
5253
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
5354
if err != nil {
5455
t.Error(err)
56+
return
5557
}
5658
err = dialer.Ping(context.Background())
5759
if err != nil {
@@ -64,6 +66,7 @@ func TestDial(t *testing.T) {
6466
_, err := Listen(context.Background(), listenAddr)
6567
if err != nil {
6668
t.Error(err)
69+
return
6770
}
6871
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
6972
if err != nil {
@@ -100,10 +103,12 @@ func TestDial(t *testing.T) {
100103
_, err = Listen(context.Background(), listenAddr)
101104
if err != nil {
102105
t.Error(err)
106+
return
103107
}
104108
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
105109
if err != nil {
106110
t.Error(err)
111+
return
107112
}
108113
conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
109114
if err != nil {
@@ -132,6 +137,7 @@ func TestDial(t *testing.T) {
132137
srv, err := Listen(context.Background(), listenAddr)
133138
if err != nil {
134139
t.Error(err)
140+
return
135141
}
136142
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
137143
if err != nil {

wsnet/listen.go

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net"
1010
"strings"
1111
"sync"
12+
"time"
1213

1314
"github.com/hashicorp/yamux"
1415
"github.com/pion/webrtc/v3"
@@ -20,7 +21,51 @@ import (
2021
// Listen connects to the broker proxies connections to the local net.
2122
// Close will end all RTC connections.
2223
func Listen(ctx context.Context, broker string) (io.Closer, error) {
23-
conn, resp, err := websocket.Dial(ctx, broker, nil)
24+
l := &listener{
25+
broker: broker,
26+
connClosers: make([]io.Closer, 0),
27+
}
28+
// We do a one-off dial outside of the loop to ensure the initial
29+
// connection is successful. If not, there's likely an error the
30+
// user needs to act on.
31+
ch, err := l.dial(ctx)
32+
if err != nil {
33+
return nil, err
34+
}
35+
go func() {
36+
for {
37+
err := <-ch
38+
if errors.Is(err, io.EOF) {
39+
// If we hit an EOF, then the connection to the broker
40+
// was interrupted. We'll take a short break then dial
41+
// again.
42+
time.Sleep(time.Second)
43+
ch, err = l.dial(ctx)
44+
}
45+
if err != nil {
46+
l.acceptError = err
47+
_ = l.Close()
48+
break
49+
}
50+
}
51+
}()
52+
return l, nil
53+
}
54+
55+
type listener struct {
56+
broker string
57+
58+
acceptError error
59+
ws *websocket.Conn
60+
connClosers []io.Closer
61+
connClosersMut sync.Mutex
62+
}
63+
64+
func (l *listener) dial(ctx context.Context) (<-chan error, error) {
65+
if l.ws != nil {
66+
_ = l.ws.Close(websocket.StatusNormalClosure, "new connection inbound")
67+
}
68+
conn, resp, err := websocket.Dial(ctx, l.broker, nil)
2469
if err != nil {
2570
if resp != nil {
2671
return nil, &coder.HTTPError{
@@ -29,40 +74,31 @@ func Listen(ctx context.Context, broker string) (io.Closer, error) {
2974
}
3075
return nil, err
3176
}
77+
l.ws = conn
3278
nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
3379
session, err := yamux.Server(nconn, nil)
3480
if err != nil {
3581
return nil, fmt.Errorf("create multiplex: %w", err)
3682
}
37-
l := &listener{
38-
ws: conn,
39-
connClosers: make([]io.Closer, 0),
40-
}
83+
errCh := make(chan error)
4184
go func() {
85+
defer close(errCh)
4286
for {
4387
conn, err := session.Accept()
4488
if err != nil {
45-
if errors.Is(err, io.EOF) {
46-
continue
47-
}
48-
l.acceptError = err
49-
l.Close()
50-
return
89+
errCh <- err
90+
break
5191
}
5292
go l.negotiate(conn)
5393
}
5494
}()
55-
return l, nil
56-
}
57-
58-
type listener struct {
59-
acceptError error
60-
ws *websocket.Conn
61-
connClosers []io.Closer
62-
connClosersMut sync.Mutex
95+
return errCh, nil
6396
}
6497

6598
// Negotiates the handshake protocol over the connection provided.
99+
// This functions control-flow is important to readability,
100+
// so the cognitive overload linter has been disabled.
101+
// nolint:gocognit
66102
func (l *listener) negotiate(conn net.Conn) {
67103
var (
68104
err error
@@ -119,6 +155,13 @@ func (l *listener) negotiate(conn net.Conn) {
119155
closeError(fmt.Errorf("ICEServers must be provided"))
120156
return
121157
}
158+
for _, server := range msg.Servers {
159+
err = DialICE(server, nil)
160+
if err != nil {
161+
closeError(fmt.Errorf("dial server %+v: %w", server.URLs, err))
162+
return
163+
}
164+
}
122165
rtc, err = newPeerConnection(msg.Servers)
123166
if err != nil {
124167
closeError(err)

wsnet/rtc.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro
157157
se := webrtc.SettingEngine{}
158158
se.DetachDataChannels()
159159
se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2)
160+
161+
// If one server is provided and we know it's TURN, we can set the
162+
// relay acceptable so the connection starts immediately.
163+
if len(servers) == 1 {
164+
server := servers[0]
165+
if server.Credential != nil && len(server.URLs) == 1 {
166+
url, err := ice.ParseURL(server.URLs[0])
167+
if err == nil && url.Proto == ice.ProtoTypeTCP {
168+
se.SetRelayAcceptanceMinWait(0)
169+
}
170+
}
171+
}
160172
api := webrtc.NewAPI(webrtc.WithSettingEngine(se))
161173

162174
return api.NewPeerConnection(webrtc.Configuration{

0 commit comments

Comments
 (0)