Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 16ca4a6

Browse files
authored
feat: Add DialWebsocket func to enable Dial through net.Conn (#335)
1 parent e140b59 commit 16ca4a6

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

internal/cmd/tunnel.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (c *tunnneler) start(ctx context.Context) error {
127127
}
128128

129129
c.log.Info(ctx, "Connecting to workspace...")
130-
wd, err := wsnet.Dial(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server})
130+
wd, err := wsnet.DialWebsocket(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server})
131131
if err != nil {
132132
return xerrors.Errorf("creating workspace dialer: %w", err)
133133
}

wsnet/dial.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,8 @@ import (
1616
"cdr.dev/coder-cli/coder-sdk"
1717
)
1818

19-
// Dial connects to the broker and negotiates a connection to a listener.
20-
func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) {
21-
if iceServers == nil {
22-
iceServers = []webrtc.ICEServer{}
23-
}
24-
19+
// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
20+
func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) {
2521
conn, resp, err := websocket.Dial(ctx, broker, nil)
2622
if err != nil {
2723
if resp != nil {
@@ -40,13 +36,21 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D
4036
// We should close the socket intentionally.
4137
_ = conn.Close(websocket.StatusInternalError, "an error occurred")
4238
}()
39+
return Dial(ctx, nconn, iceServers)
40+
}
41+
42+
// Dial negotiates a connection to a listener.
43+
func Dial(ctx context.Context, conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
44+
if iceServers == nil {
45+
iceServers = []webrtc.ICEServer{}
46+
}
4347

4448
rtc, err := newPeerConnection(iceServers)
4549
if err != nil {
4650
return nil, fmt.Errorf("create peer connection: %w", err)
4751
}
4852

49-
flushCandidates := proxyICECandidates(rtc, nconn)
53+
flushCandidates := proxyICECandidates(rtc, conn)
5054

5155
ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{
5256
Protocol: stringPtr(controlChannel),
@@ -72,34 +76,34 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D
7276
if err != nil {
7377
return nil, fmt.Errorf("marshal offer message: %w", err)
7478
}
75-
_, err = nconn.Write(offerMessage)
79+
_, err = conn.Write(offerMessage)
7680
if err != nil {
7781
return nil, fmt.Errorf("write offer: %w", err)
7882
}
7983
flushCandidates()
8084

8185
dialer := &Dialer{
82-
ws: conn,
86+
conn: conn,
8387
ctrl: ctrl,
8488
rtc: rtc,
8589
}
8690

87-
return dialer, dialer.negotiate(nconn)
91+
return dialer, dialer.negotiate()
8892
}
8993

9094
// Dialer enables arbitrary dialing to any network and address
9195
// inside a workspace. The opposing end of the WebSocket messages
9296
// should be proxied with a Listener.
9397
type Dialer struct {
94-
ws *websocket.Conn
98+
conn net.Conn
9599
ctrl *webrtc.DataChannel
96100
ctrlrw datachannel.ReadWriteCloser
97101
rtc *webrtc.PeerConnection
98102
}
99103

100-
func (d *Dialer) negotiate(nconn net.Conn) (err error) {
104+
func (d *Dialer) negotiate() (err error) {
101105
var (
102-
decoder = json.NewDecoder(nconn)
106+
decoder = json.NewDecoder(d.conn)
103107
errCh = make(chan error)
104108
// If candidates are sent before an offer, we place them here.
105109
// We currently have no assurances to ensure this can't happen,
@@ -111,15 +115,15 @@ func (d *Dialer) negotiate(nconn net.Conn) (err error) {
111115
defer close(errCh)
112116
err := waitForDataChannelOpen(context.Background(), d.ctrl)
113117
if err != nil {
114-
_ = d.ws.Close(websocket.StatusAbnormalClosure, "timeout")
118+
_ = d.conn.Close()
115119
errCh <- err
116120
return
117121
}
118122
d.ctrlrw, err = d.ctrl.Detach()
119123
if err != nil {
120124
errCh <- err
121125
}
122-
_ = d.ws.Close(websocket.StatusNormalClosure, "connected")
126+
_ = d.conn.Close()
123127
}()
124128

125129
for {

wsnet/dial_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func ExampleDial_basic() {
3030
}
3131
}
3232

33-
dialer, err := Dial(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers)
33+
dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers)
3434
if err != nil {
3535
// Do something...
3636
}
@@ -49,7 +49,7 @@ func TestDial(t *testing.T) {
4949
if err != nil {
5050
t.Error(err)
5151
}
52-
dialer, err := Dial(context.Background(), connectAddr, nil)
52+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
5353
if err != nil {
5454
t.Error(err)
5555
}
@@ -65,7 +65,7 @@ func TestDial(t *testing.T) {
6565
if err != nil {
6666
t.Error(err)
6767
}
68-
dialer, err := Dial(context.Background(), connectAddr, nil)
68+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
6969
if err != nil {
7070
t.Error(err)
7171
}
@@ -101,7 +101,7 @@ func TestDial(t *testing.T) {
101101
if err != nil {
102102
t.Error(err)
103103
}
104-
dialer, err := Dial(context.Background(), connectAddr, nil)
104+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
105105
if err != nil {
106106
t.Error(err)
107107
}
@@ -133,7 +133,7 @@ func TestDial(t *testing.T) {
133133
if err != nil {
134134
t.Error(err)
135135
}
136-
dialer, err := Dial(context.Background(), connectAddr, nil)
136+
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
137137
if err != nil {
138138
t.Error(err)
139139
}

0 commit comments

Comments
 (0)