diff --git a/wsnet/auth.go b/wsnet/auth.go index 94ffa59d..a5daf45e 100644 --- a/wsnet/auth.go +++ b/wsnet/auth.go @@ -2,6 +2,7 @@ package wsnet import ( "crypto/sha256" + "encoding/base64" "errors" "strings" ) @@ -16,6 +17,6 @@ func TURNCredentials(token string) (username, password string, err error) { } username = str[0] hash := sha256.Sum256([]byte(str[1])) - password = string(hash[:]) + password = base64.StdEncoding.EncodeToString(hash[:]) return } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 84466c36..3d2e1f1c 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -48,10 +48,12 @@ func TestDial(t *testing.T) { _, err := Listen(context.Background(), listenAddr) if err != nil { t.Error(err) + return } dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) + return } err = dialer.Ping(context.Background()) if err != nil { @@ -64,6 +66,7 @@ func TestDial(t *testing.T) { _, err := Listen(context.Background(), listenAddr) if err != nil { t.Error(err) + return } dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { @@ -100,10 +103,12 @@ func TestDial(t *testing.T) { _, err = Listen(context.Background(), listenAddr) if err != nil { t.Error(err) + return } dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) + return } conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) if err != nil { @@ -132,6 +137,7 @@ func TestDial(t *testing.T) { srv, err := Listen(context.Background(), listenAddr) if err != nil { t.Error(err) + return } dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { diff --git a/wsnet/listen.go b/wsnet/listen.go index 6ce569b4..4382f503 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -9,6 +9,7 @@ import ( "net" "strings" "sync" + "time" "github.com/hashicorp/yamux" "github.com/pion/webrtc/v3" @@ -20,7 +21,51 @@ import ( // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. func Listen(ctx context.Context, broker string) (io.Closer, error) { - conn, resp, err := websocket.Dial(ctx, broker, nil) + l := &listener{ + broker: broker, + connClosers: make([]io.Closer, 0), + } + // We do a one-off dial outside of the loop to ensure the initial + // connection is successful. If not, there's likely an error the + // user needs to act on. + ch, err := l.dial(ctx) + if err != nil { + return nil, err + } + go func() { + for { + err := <-ch + if errors.Is(err, io.EOF) { + // If we hit an EOF, then the connection to the broker + // was interrupted. We'll take a short break then dial + // again. + time.Sleep(time.Second) + ch, err = l.dial(ctx) + } + if err != nil { + l.acceptError = err + _ = l.Close() + break + } + } + }() + return l, nil +} + +type listener struct { + broker string + + acceptError error + ws *websocket.Conn + connClosers []io.Closer + connClosersMut sync.Mutex +} + +func (l *listener) dial(ctx context.Context) (<-chan error, error) { + if l.ws != nil { + _ = l.ws.Close(websocket.StatusNormalClosure, "new connection inbound") + } + conn, resp, err := websocket.Dial(ctx, l.broker, nil) if err != nil { if resp != nil { return nil, &coder.HTTPError{ @@ -29,40 +74,31 @@ func Listen(ctx context.Context, broker string) (io.Closer, error) { } return nil, err } + l.ws = conn nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary) session, err := yamux.Server(nconn, nil) if err != nil { return nil, fmt.Errorf("create multiplex: %w", err) } - l := &listener{ - ws: conn, - connClosers: make([]io.Closer, 0), - } + errCh := make(chan error) go func() { + defer close(errCh) for { conn, err := session.Accept() if err != nil { - if errors.Is(err, io.EOF) { - continue - } - l.acceptError = err - l.Close() - return + errCh <- err + break } go l.negotiate(conn) } }() - return l, nil -} - -type listener struct { - acceptError error - ws *websocket.Conn - connClosers []io.Closer - connClosersMut sync.Mutex + return errCh, nil } // Negotiates the handshake protocol over the connection provided. +// This functions control-flow is important to readability, +// so the cognitive overload linter has been disabled. +// nolint:gocognit func (l *listener) negotiate(conn net.Conn) { var ( err error @@ -119,6 +155,13 @@ func (l *listener) negotiate(conn net.Conn) { closeError(fmt.Errorf("ICEServers must be provided")) return } + for _, server := range msg.Servers { + err = DialICE(server, nil) + if err != nil { + closeError(fmt.Errorf("dial server %+v: %w", server.URLs, err)) + return + } + } rtc, err = newPeerConnection(msg.Servers) if err != nil { closeError(err) diff --git a/wsnet/rtc.go b/wsnet/rtc.go index ce70e557..bd08baf0 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -157,6 +157,18 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro se := webrtc.SettingEngine{} se.DetachDataChannels() se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2) + + // If one server is provided and we know it's TURN, we can set the + // relay acceptable so the connection starts immediately. + if len(servers) == 1 { + server := servers[0] + if server.Credential != nil && len(server.URLs) == 1 { + url, err := ice.ParseURL(server.URLs[0]) + if err == nil && url.Proto == ice.ProtoTypeTCP { + se.SetRelayAcceptanceMinWait(0) + } + } + } api := webrtc.NewAPI(webrtc.WithSettingEngine(se)) return api.NewPeerConnection(webrtc.Configuration{