Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 2869801

Browse files
committed
Add coder tunnel command
1 parent 739c8e4 commit 2869801

File tree

3 files changed

+275
-2
lines changed

3 files changed

+275
-2
lines changed

internal/cmd/agent.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ coder agent start https://my-coder.com
8686
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
8787
defer cancelFunc()
8888
log.Info(ctx, "connecting to broker", slog.F("url", u.String()))
89-
conn, res, err := websocket.Dial(ctx, u.String(), nil)
89+
// nolint: bodyclose
90+
conn, _, err := websocket.Dial(ctx, u.String(), nil)
9091
if err != nil {
9192
return fmt.Errorf("dial: %w", err)
9293
}
93-
_ = res.Body.Close()
9494
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
9595
session, err := yamux.Server(nc, nil)
9696
if err != nil {

internal/cmd/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func Make() *cobra.Command {
3838
providersCmd(),
3939
genDocsCmd(app),
4040
agentCmd(),
41+
tunnelCmd(),
4142
)
4243
app.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "show verbose output")
4344
return app

internal/cmd/tunnel.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net"
9+
"os"
10+
"strconv"
11+
"time"
12+
13+
"cdr.dev/slog"
14+
"cdr.dev/slog/sloggers/sloghuman"
15+
"github.com/pion/webrtc/v3"
16+
"github.com/spf13/cobra"
17+
"golang.org/x/xerrors"
18+
"nhooyr.io/websocket"
19+
20+
"cdr.dev/coder-cli/internal/x/xcobra"
21+
"cdr.dev/coder-cli/internal/x/xwebrtc"
22+
"cdr.dev/coder-cli/pkg/proto"
23+
)
24+
25+
func tunnelCmd() *cobra.Command {
26+
cmd := &cobra.Command{
27+
Use: "tunnel [workspace_name] [workspace_port] [localhost_port]",
28+
Args: xcobra.ExactArgs(3),
29+
Short: "proxies a port on the workspace to localhost",
30+
Long: "proxies a port on the workspace to localhost",
31+
Example: `# run a tcp tunnel from the workspace on port 3000 to localhost:3000
32+
33+
coder tunnel my-dev 3000 3000
34+
`,
35+
RunE: func(cmd *cobra.Command, args []string) error {
36+
ctx := context.Background()
37+
log := slog.Make(sloghuman.Sink(os.Stderr))
38+
39+
remotePort, err := strconv.ParseUint(args[1], 10, 16)
40+
if err != nil {
41+
log.Fatal(ctx, "parse remote port", slog.Error(err))
42+
}
43+
44+
var localPort uint64
45+
if args[2] != "stdio" {
46+
localPort, err = strconv.ParseUint(args[2], 10, 16)
47+
if err != nil {
48+
log.Fatal(ctx, "parse local port", slog.Error(err))
49+
}
50+
}
51+
52+
sdk, err := newClient(ctx)
53+
if err != nil {
54+
return err
55+
}
56+
baseURL := sdk.BaseURL()
57+
58+
envs, err := sdk.Environments(ctx)
59+
if err != nil {
60+
return err
61+
}
62+
63+
var envID string
64+
for _, env := range envs {
65+
if env.Name == args[0] {
66+
envID = env.ID
67+
break
68+
}
69+
}
70+
if envID == "" {
71+
return xerrors.Errorf("No workspace found by name '%s'", args[0])
72+
}
73+
74+
c := &client{
75+
id: envID,
76+
stdio: args[2] == "stdio",
77+
localPort: uint16(localPort),
78+
remotePort: uint16(remotePort),
79+
ctx: context.Background(),
80+
logger: log,
81+
brokerAddr: baseURL.String(),
82+
token: sdk.Token(),
83+
}
84+
85+
err = c.start()
86+
if err != nil {
87+
log.Fatal(ctx, err.Error())
88+
}
89+
90+
return nil
91+
},
92+
}
93+
94+
return cmd
95+
}
96+
97+
type client struct {
98+
ctx context.Context
99+
brokerAddr string
100+
token string
101+
logger slog.Logger
102+
id string
103+
remotePort uint16
104+
localPort uint16
105+
stdio bool
106+
}
107+
108+
func (c *client) start() error {
109+
url := fmt.Sprintf("%s%s%s%s%s", c.brokerAddr, "/api/private/envagent/", c.id, "/connect?session_token=", c.token)
110+
c.logger.Info(c.ctx, "connecting to broker", slog.F("url", url))
111+
112+
conn, _, err := websocket.Dial(c.ctx, url, nil)
113+
if err != nil {
114+
return fmt.Errorf("dial: %w", err)
115+
}
116+
nconn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
117+
118+
rtc, err := xwebrtc.NewPeerConnection()
119+
if err != nil {
120+
return fmt.Errorf("create connection: %w", err)
121+
}
122+
123+
rtc.OnNegotiationNeeded(func() {
124+
c.logger.Debug(context.Background(), "negotiation needed...")
125+
})
126+
127+
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
128+
c.logger.Info(context.Background(), "connection state changed", slog.F("state", pcs))
129+
})
130+
131+
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
132+
if err != nil {
133+
return fmt.Errorf("create data channel: %w", err)
134+
}
135+
flushCandidates := proto.ProxyICECandidates(rtc, nconn)
136+
137+
localDesc, err := rtc.CreateOffer(&webrtc.OfferOptions{})
138+
if err != nil {
139+
return fmt.Errorf("create offer: %w", err)
140+
}
141+
142+
err = rtc.SetLocalDescription(localDesc)
143+
if err != nil {
144+
return fmt.Errorf("set local desc: %w", err)
145+
}
146+
flushCandidates()
147+
148+
c.logger.Debug(context.Background(), "writing offer")
149+
b, _ := json.Marshal(&proto.Message{
150+
Offer: &localDesc,
151+
})
152+
_, err = nconn.Write(b)
153+
if err != nil {
154+
return fmt.Errorf("write offer: %w", err)
155+
}
156+
157+
go func() {
158+
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
159+
if err != nil {
160+
c.logger.Fatal(context.Background(), "waiting for data channel open", slog.Error(err))
161+
}
162+
_ = conn.Close(websocket.StatusNormalClosure, "rtc connected")
163+
}()
164+
165+
decoder := json.NewDecoder(nconn)
166+
for {
167+
var msg proto.Message
168+
err = decoder.Decode(&msg)
169+
if err == io.EOF {
170+
break
171+
}
172+
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
173+
break
174+
}
175+
if err != nil {
176+
return fmt.Errorf("read msg: %w", err)
177+
}
178+
if msg.Candidate != "" {
179+
c.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
180+
err = proto.AcceptICECandidate(rtc, &msg)
181+
if err != nil {
182+
return fmt.Errorf("accept ice: %w", err)
183+
}
184+
}
185+
if msg.Answer != nil {
186+
c.logger.Debug(context.Background(), "got answer", slog.F("answer", msg.Answer))
187+
err = rtc.SetRemoteDescription(*msg.Answer)
188+
if err != nil {
189+
return fmt.Errorf("set remote: %w", err)
190+
}
191+
}
192+
}
193+
194+
// Once we're open... let's test out the ping.
195+
pingProto := "ping"
196+
pingChannel, err := rtc.CreateDataChannel("pinger", &webrtc.DataChannelInit{
197+
Protocol: &pingProto,
198+
})
199+
if err != nil {
200+
return fmt.Errorf("create ping channel")
201+
}
202+
pingChannel.OnOpen(func() {
203+
defer func() {
204+
_ = pingChannel.Close()
205+
}()
206+
t1 := time.Now()
207+
rw, _ := pingChannel.Detach()
208+
defer func() {
209+
_ = rw.Close()
210+
}()
211+
_, _ = rw.Write([]byte("hello"))
212+
b := make([]byte, 64)
213+
_, _ = rw.Read(b)
214+
c.logger.Info(c.ctx, "your latency directly to the agent", slog.F("ms", time.Since(t1).Milliseconds()))
215+
})
216+
217+
if c.stdio {
218+
// At this point the RTC is connected and data channel is opened...
219+
rw, err := channel.Detach()
220+
if err != nil {
221+
return fmt.Errorf("detach channel: %w", err)
222+
}
223+
go func() {
224+
_, _ = io.Copy(rw, os.Stdin)
225+
}()
226+
_, err = io.Copy(os.Stdout, rw)
227+
if err != nil {
228+
return fmt.Errorf("copy: %w", err)
229+
}
230+
return nil
231+
}
232+
233+
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", c.localPort))
234+
if err != nil {
235+
return fmt.Errorf("listen: %w", err)
236+
}
237+
238+
for {
239+
conn, err := listener.Accept()
240+
if err != nil {
241+
return fmt.Errorf("accept: %w", err)
242+
}
243+
go func() {
244+
defer func() {
245+
_ = conn.Close()
246+
}()
247+
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
248+
if err != nil {
249+
c.logger.Warn(context.Background(), "create data channel for proxying", slog.Error(err))
250+
return
251+
}
252+
defer func() {
253+
_ = channel.Close()
254+
}()
255+
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
256+
if err != nil {
257+
c.logger.Warn(context.Background(), "wait for data channel open", slog.Error(err))
258+
return
259+
}
260+
rw, err := channel.Detach()
261+
if err != nil {
262+
c.logger.Warn(context.Background(), "detach channel", slog.Error(err))
263+
return
264+
}
265+
266+
go func() {
267+
_, _ = io.Copy(conn, rw)
268+
}()
269+
_, _ = io.Copy(rw, conn)
270+
}()
271+
}
272+
}

0 commit comments

Comments
 (0)