diff --git a/internal/cmd/tunnel.go b/internal/cmd/tunnel.go index fc59ddaa..254e6df9 100644 --- a/internal/cmd/tunnel.go +++ b/internal/cmd/tunnel.go @@ -127,7 +127,7 @@ func (c *tunnneler) start(ctx context.Context) error { } c.log.Info(ctx, "Connecting to workspace...") - wd, err := wsnet.Dial(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server}) + wd, err := wsnet.DialWebsocket(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server}) if err != nil { return xerrors.Errorf("creating workspace dialer: %w", err) } diff --git a/wsnet/dial.go b/wsnet/dial.go index ce92390f..6af7293a 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -16,12 +16,8 @@ import ( "cdr.dev/coder-cli/coder-sdk" ) -// Dial connects to the broker and negotiates a connection to a listener. -func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) { - if iceServers == nil { - iceServers = []webrtc.ICEServer{} - } - +// DialWebsocket dials the broker with a WebSocket and negotiates a connection. +func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) { conn, resp, err := websocket.Dial(ctx, broker, nil) if err != nil { if resp != nil { @@ -40,13 +36,21 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D // We should close the socket intentionally. _ = conn.Close(websocket.StatusInternalError, "an error occurred") }() + return Dial(ctx, nconn, iceServers) +} + +// Dial negotiates a connection to a listener. +func Dial(ctx context.Context, conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { + if iceServers == nil { + iceServers = []webrtc.ICEServer{} + } rtc, err := newPeerConnection(iceServers) if err != nil { return nil, fmt.Errorf("create peer connection: %w", err) } - flushCandidates := proxyICECandidates(rtc, nconn) + flushCandidates := proxyICECandidates(rtc, conn) ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{ Protocol: stringPtr(controlChannel), @@ -72,34 +76,34 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D if err != nil { return nil, fmt.Errorf("marshal offer message: %w", err) } - _, err = nconn.Write(offerMessage) + _, err = conn.Write(offerMessage) if err != nil { return nil, fmt.Errorf("write offer: %w", err) } flushCandidates() dialer := &Dialer{ - ws: conn, + conn: conn, ctrl: ctrl, rtc: rtc, } - return dialer, dialer.negotiate(nconn) + return dialer, dialer.negotiate() } // Dialer enables arbitrary dialing to any network and address // inside a workspace. The opposing end of the WebSocket messages // should be proxied with a Listener. type Dialer struct { - ws *websocket.Conn + conn net.Conn ctrl *webrtc.DataChannel ctrlrw datachannel.ReadWriteCloser rtc *webrtc.PeerConnection } -func (d *Dialer) negotiate(nconn net.Conn) (err error) { +func (d *Dialer) negotiate() (err error) { var ( - decoder = json.NewDecoder(nconn) + decoder = json.NewDecoder(d.conn) errCh = make(chan error) // If candidates are sent before an offer, we place them here. // We currently have no assurances to ensure this can't happen, @@ -111,7 +115,7 @@ func (d *Dialer) negotiate(nconn net.Conn) (err error) { defer close(errCh) err := waitForDataChannelOpen(context.Background(), d.ctrl) if err != nil { - _ = d.ws.Close(websocket.StatusAbnormalClosure, "timeout") + _ = d.conn.Close() errCh <- err return } @@ -119,7 +123,7 @@ func (d *Dialer) negotiate(nconn net.Conn) (err error) { if err != nil { errCh <- err } - _ = d.ws.Close(websocket.StatusNormalClosure, "connected") + _ = d.conn.Close() }() for { diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index fab01069..84466c36 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -30,7 +30,7 @@ func ExampleDial_basic() { } } - dialer, err := Dial(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers) + dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers) if err != nil { // Do something... } @@ -49,7 +49,7 @@ func TestDial(t *testing.T) { if err != nil { t.Error(err) } - dialer, err := Dial(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) } @@ -65,7 +65,7 @@ func TestDial(t *testing.T) { if err != nil { t.Error(err) } - dialer, err := Dial(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) } @@ -101,7 +101,7 @@ func TestDial(t *testing.T) { if err != nil { t.Error(err) } - dialer, err := Dial(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) } @@ -133,7 +133,7 @@ func TestDial(t *testing.T) { if err != nil { t.Error(err) } - dialer, err := Dial(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) }