Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit aef2996

Browse files
committed
Centralize webrtc dial logic into xwebrtc
1 parent 43edc2f commit aef2996

File tree

9 files changed

+445
-266
lines changed

9 files changed

+445
-266
lines changed

agent/stream.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import (
77
"io"
88
"net"
99

10+
"cdr.dev/coder-cli/xwebrtc"
11+
1012
"cdr.dev/slog"
1113
"github.com/hashicorp/yamux"
1214
"github.com/pion/webrtc/v3"
1315
"golang.org/x/xerrors"
1416

15-
"cdr.dev/coder-cli/internal/x/xwebrtc"
1617
"cdr.dev/coder-cli/pkg/proto"
1718
)
1819

@@ -128,6 +129,10 @@ func (s *stream) processMessage(msg proto.Message) {
128129
}
129130

130131
func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
132+
if channel.Protocol() == "control" {
133+
return
134+
}
135+
131136
if channel.Protocol() == "ping" {
132137
channel.OnOpen(func() {
133138
rw, err := channel.Detach()
@@ -149,7 +154,7 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
149154
return
150155
}
151156

152-
prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
157+
prto, addr, err := xwebrtc.ParseProxyDataChannel(channel)
153158
if err != nil {
154159
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
155160
return
@@ -159,14 +164,14 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
159164
return
160165
}
161166

162-
conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
167+
conn, err := net.Dial(prto, addr)
163168
if err != nil {
164-
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
169+
s.fatal(fmt.Errorf("failed to dial client addr: %s", addr))
165170
return
166171
}
167172

168173
channel.OnOpen(func() {
169-
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
174+
s.logger.Debug(context.Background(), "proxying data channel", slog.F("addr", addr))
170175
rw, err := channel.Detach()
171176
if err != nil {
172177
_ = channel.Close()

internal/cmd/tunnel.go

Lines changed: 35 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,20 @@ package cmd
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"io"
87
"net"
9-
"net/url"
108
"os"
119
"strconv"
12-
"time"
1310

1411
"cdr.dev/slog"
1512
"cdr.dev/slog/sloggers/sloghuman"
16-
"github.com/pion/webrtc/v3"
1713
"github.com/spf13/cobra"
1814
"golang.org/x/xerrors"
19-
"nhooyr.io/websocket"
2015

2116
"cdr.dev/coder-cli/coder-sdk"
2217
"cdr.dev/coder-cli/internal/x/xcobra"
23-
"cdr.dev/coder-cli/internal/x/xwebrtc"
24-
"cdr.dev/coder-cli/pkg/proto"
18+
"cdr.dev/coder-cli/xwebrtc"
2519
)
2620

2721
func tunnelCmd() *cobra.Command {
@@ -41,26 +35,26 @@ coder tunnel my-dev 3000 3000
4135

4236
remotePort, err := strconv.ParseUint(args[1], 10, 16)
4337
if err != nil {
44-
log.Fatal(ctx, "parse remote port", slog.Error(err))
38+
return xerrors.Errorf("parse remote port: %w", err)
4539
}
4640

4741
var localPort uint64
4842
if args[2] != "stdio" {
4943
localPort, err = strconv.ParseUint(args[2], 10, 16)
5044
if err != nil {
51-
log.Fatal(ctx, "parse local port", slog.Error(err))
45+
return xerrors.Errorf("parse local port: %w", err)
5246
}
5347
}
5448

5549
sdk, err := newClient(ctx)
5650
if err != nil {
57-
return err
51+
return xerrors.Errorf("getting coder client: %w", err)
5852
}
5953
baseURL := sdk.BaseURL()
6054

6155
envs, err := getEnvs(ctx, sdk, coder.Me)
6256
if err != nil {
63-
return err
57+
return xerrors.Errorf("get workspaces: %w", err)
6458
}
6559

6660
var envID string
@@ -74,20 +68,17 @@ coder tunnel my-dev 3000 3000
7468
return xerrors.Errorf("No workspace found by name '%s'", args[0])
7569
}
7670

77-
c := &client{
78-
id: envID,
79-
stdio: args[2] == "stdio",
80-
localPort: uint16(localPort),
81-
remotePort: uint16(remotePort),
82-
ctx: context.Background(),
83-
logger: log.Leveled(slog.LevelDebug),
84-
brokerAddr: baseURL,
85-
token: sdk.Token(),
71+
c := &tunnneler{
72+
wsClient: xwebrtc.NewWorkspaceClient(log.Leveled(slog.LevelDebug), &baseURL, sdk.Token()),
73+
workspaceID: envID,
74+
stdio: args[2] == "stdio",
75+
localPort: uint16(localPort),
76+
remotePort: uint16(remotePort),
8677
}
8778

88-
err = c.start()
79+
err = c.start(ctx)
8980
if err != nil {
90-
log.Fatal(ctx, err.Error())
81+
return xerrors.Errorf("running tunnel: %w", err)
9182
}
9283

9384
return nil
@@ -97,197 +88,56 @@ coder tunnel my-dev 3000 3000
9788
return cmd
9889
}
9990

100-
type client struct {
101-
ctx context.Context
102-
brokerAddr url.URL
103-
token string
104-
logger slog.Logger
105-
id string
106-
remotePort uint16
107-
localPort uint16
108-
stdio bool
91+
type tunnneler struct {
92+
wsClient *xwebrtc.WorkspaceClient
93+
workspaceID string
94+
remotePort uint16
95+
localPort uint16
96+
stdio bool
10997
}
11098

111-
func (c *client) start() error {
112-
url := fmt.Sprintf("%s%s%s%s%s", c.brokerAddr.String(), "/api/private/envagent/", c.id, "/connect?session_token=", c.token)
113-
turnScheme := "turns"
114-
if c.brokerAddr.Scheme == "http" {
115-
turnScheme = "turn"
116-
}
117-
tcpProxy := fmt.Sprintf("%s:%s:5349?transport=tcp", turnScheme, c.brokerAddr.Host)
118-
c.logger.Info(c.ctx, "connecting to broker", slog.F("url", url), slog.F("tcp-proxy", tcpProxy))
119-
conn, resp, err := websocket.Dial(c.ctx, url, nil)
120-
if err != nil && resp == nil {
121-
return fmt.Errorf("dial: %w", err)
122-
}
123-
if err != nil && resp != nil {
124-
return &coder.HTTPError{
125-
Response: resp,
126-
}
127-
}
128-
nconn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
129-
130-
// Only enabled under a private feature flag for now,
131-
// so insecure connections are entirely fine to allow.
132-
servers := []webrtc.ICEServer{{
133-
URLs: []string{tcpProxy},
134-
Username: "insecure",
135-
Credential: "pass",
136-
CredentialType: webrtc.ICECredentialTypePassword,
137-
}}
138-
rtc, err := xwebrtc.NewPeerConnection(servers)
139-
if err != nil {
140-
return fmt.Errorf("create connection: %w", err)
141-
}
142-
143-
rtc.OnNegotiationNeeded(func() {
144-
c.logger.Debug(context.Background(), "negotiation needed...")
145-
})
146-
147-
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
148-
c.logger.Info(context.Background(), "connection state changed", slog.F("state", pcs))
149-
})
150-
151-
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
152-
if err != nil {
153-
return fmt.Errorf("create data channel: %w", err)
154-
}
155-
flushCandidates := proto.ProxyICECandidates(rtc, nconn)
156-
157-
localDesc, err := rtc.CreateOffer(&webrtc.OfferOptions{})
158-
if err != nil {
159-
return fmt.Errorf("create offer: %w", err)
160-
}
161-
162-
err = rtc.SetLocalDescription(localDesc)
163-
if err != nil {
164-
return fmt.Errorf("set local desc: %w", err)
165-
}
166-
167-
c.logger.Debug(context.Background(), "writing offer")
168-
b, _ := json.Marshal(&proto.Message{
169-
Offer: &localDesc,
170-
Servers: servers,
171-
})
172-
_, err = nconn.Write(b)
99+
func (c *tunnneler) start(ctx context.Context) error {
100+
wd, err := c.wsClient.DialWorkspace(ctx, c.workspaceID)
173101
if err != nil {
174-
return fmt.Errorf("write offer: %w", err)
175-
}
176-
flushCandidates()
177-
178-
go func() {
179-
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
180-
if err != nil {
181-
c.logger.Fatal(context.Background(), "waiting for data channel open", slog.Error(err))
182-
}
183-
_ = conn.Close(websocket.StatusNormalClosure, "rtc connected")
184-
}()
185-
186-
decoder := json.NewDecoder(nconn)
187-
for {
188-
var msg proto.Message
189-
err = decoder.Decode(&msg)
190-
if err == io.EOF {
191-
break
192-
}
193-
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
194-
break
195-
}
196-
if err != nil {
197-
return fmt.Errorf("read msg: %w", err)
198-
}
199-
if msg.Candidate != "" {
200-
c.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
201-
err = proto.AcceptICECandidate(rtc, &msg)
202-
if err != nil {
203-
return fmt.Errorf("accept ice: %w", err)
204-
}
205-
}
206-
if msg.Answer != nil {
207-
c.logger.Debug(context.Background(), "got answer", slog.F("answer", msg.Answer))
208-
err = rtc.SetRemoteDescription(*msg.Answer)
209-
if err != nil {
210-
return fmt.Errorf("set remote: %w", err)
211-
}
212-
}
102+
return xerrors.Errorf("create workspace dialer: %w", err)
213103
}
214-
215-
// Once we're open... let's test out the ping.
216-
pingProto := "ping"
217-
pingChannel, err := rtc.CreateDataChannel("pinger", &webrtc.DataChannelInit{
218-
Protocol: &pingProto,
219-
})
104+
nc, err := wd.DialContext(ctx, xwebrtc.NetworkTCP, fmt.Sprintf("localhost:%d", c.remotePort))
220105
if err != nil {
221-
return fmt.Errorf("create ping channel")
106+
return xerrors.Errorf("dial: %w", err)
222107
}
223-
pingChannel.OnOpen(func() {
224-
defer func() {
225-
_ = pingChannel.Close()
226-
}()
227-
t1 := time.Now()
228-
rw, _ := pingChannel.Detach()
229-
defer func() {
230-
_ = rw.Close()
231-
}()
232-
_, _ = rw.Write([]byte("hello"))
233-
b := make([]byte, 64)
234-
_, _ = rw.Read(b)
235-
c.logger.Info(c.ctx, "your latency directly to the agent", slog.F("ms", time.Since(t1).Milliseconds()))
236-
})
237108

109+
// proxy via stdio
238110
if c.stdio {
239-
// At this point the RTC is connected and data channel is opened...
240-
rw, err := channel.Detach()
241-
if err != nil {
242-
return fmt.Errorf("detach channel: %w", err)
243-
}
244111
go func() {
245-
_, _ = io.Copy(rw, os.Stdin)
112+
_, _ = io.Copy(nc, os.Stdin)
246113
}()
247-
_, err = io.Copy(os.Stdout, rw)
114+
_, err = io.Copy(os.Stdout, nc)
248115
if err != nil {
249-
return fmt.Errorf("copy: %w", err)
116+
return xerrors.Errorf("copy: %w", err)
250117
}
251118
return nil
252119
}
253120

121+
// proxy via tcp listener
254122
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", c.localPort))
255123
if err != nil {
256-
return fmt.Errorf("listen: %w", err)
124+
return xerrors.Errorf("listen: %w", err)
257125
}
258126

259127
for {
260-
conn, err := listener.Accept()
128+
lc, err := listener.Accept()
261129
if err != nil {
262-
return fmt.Errorf("accept: %w", err)
130+
return xerrors.Errorf("accept: %w", err)
263131
}
264132
go func() {
265133
defer func() {
266-
_ = conn.Close()
267-
}()
268-
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
269-
if err != nil {
270-
c.logger.Warn(context.Background(), "create data channel for proxying", slog.Error(err))
271-
return
272-
}
273-
defer func() {
274-
_ = channel.Close()
134+
_ = lc.Close()
275135
}()
276-
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
277-
if err != nil {
278-
c.logger.Warn(context.Background(), "wait for data channel open", slog.Error(err))
279-
return
280-
}
281-
rw, err := channel.Detach()
282-
if err != nil {
283-
c.logger.Warn(context.Background(), "detach channel", slog.Error(err))
284-
return
285-
}
286136

287137
go func() {
288-
_, _ = io.Copy(conn, rw)
138+
_, _ = io.Copy(lc, nc)
289139
}()
290-
_, _ = io.Copy(rw, conn)
140+
_, _ = io.Copy(nc, lc)
291141
}()
292142
}
293143
}

0 commit comments

Comments
 (0)