Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 3039497

Browse files
committed
Add coder tunnel command
1 parent 739c8e4 commit 3039497

File tree

4 files changed

+298
-8
lines changed

4 files changed

+298
-8
lines changed

internal/cmd/agent.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ coder agent start https://my-coder.com
8686
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
8787
defer cancelFunc()
8888
log.Info(ctx, "connecting to broker", slog.F("url", u.String()))
89-
conn, res, err := websocket.Dial(ctx, u.String(), nil)
89+
// nolint: bodyclose
90+
conn, _, err := websocket.Dial(ctx, u.String(), nil)
9091
if err != nil {
9192
return fmt.Errorf("dial: %w", err)
9293
}
93-
_ = res.Body.Close()
9494
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
9595
session, err := yamux.Server(nc, nil)
9696
if err != nil {

internal/cmd/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func Make() *cobra.Command {
3838
providersCmd(),
3939
genDocsCmd(app),
4040
agentCmd(),
41+
tunnelCmd(),
4142
)
4243
app.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "show verbose output")
4344
return app

internal/cmd/configssh.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,23 @@ func configSSHCmd() *cobra.Command {
3535
var (
3636
configpath string
3737
remove = false
38+
p2p = false
3839
)
3940

4041
cmd := &cobra.Command{
4142
Use: "config-ssh",
4243
Short: "Configure SSH to access Coder environments",
4344
Long: "Inject the proper OpenSSH configuration into your local SSH config file.",
44-
RunE: configSSH(&configpath, &remove),
45+
RunE: configSSH(&configpath, &remove, &p2p),
4546
}
4647
cmd.Flags().StringVar(&configpath, "filepath", filepath.Join("~", ".ssh", "config"), "override the default path of your ssh config file")
4748
cmd.Flags().BoolVar(&remove, "remove", false, "remove the auto-generated Coder ssh config")
49+
cmd.Flags().BoolVar(&p2p, "p2p", false, "(experimental) uses coder tunnel to proxy ssh connection")
4850

4951
return cmd
5052
}
5153

52-
func configSSH(configpath *string, remove *bool) func(cmd *cobra.Command, _ []string) error {
54+
func configSSH(configpath *string, remove *bool, p2p *bool) func(cmd *cobra.Command, _ []string) error {
5355
return func(cmd *cobra.Command, _ []string) error {
5456
ctx := cmd.Context()
5557
usr, err := user.Current()
@@ -113,7 +115,7 @@ func configSSH(configpath *string, remove *bool) func(cmd *cobra.Command, _ []st
113115
return xerrors.New("SSH is disabled or not available for any environments in your Coder deployment.")
114116
}
115117

116-
newConfig := makeNewConfigs(user.Username, envsWithProviders, privateKeyFilepath)
118+
newConfig := makeNewConfigs(user.Username, envsWithProviders, privateKeyFilepath, *p2p)
117119

118120
err = os.MkdirAll(filepath.Dir(*configpath), os.ModePerm)
119121
if err != nil {
@@ -174,7 +176,7 @@ func writeSSHKey(ctx context.Context, client coder.Client, privateKeyPath string
174176
return ioutil.WriteFile(privateKeyPath, []byte(key.PrivateKey), 0600)
175177
}
176178

177-
func makeNewConfigs(userName string, envs []coderutil.EnvWithWorkspaceProvider, privateKeyFilepath string) string {
179+
func makeNewConfigs(userName string, envs []coderutil.EnvWithWorkspaceProvider, privateKeyFilepath string, p2p bool) string {
178180
newConfig := fmt.Sprintf("\n%s\n%s\n\n", sshStartToken, sshStartMessage)
179181

180182
sort.Slice(envs, func(i, j int) bool { return envs[i].Env.Name < envs[j].Env.Name })
@@ -192,14 +194,28 @@ func makeNewConfigs(userName string, envs []coderutil.EnvWithWorkspaceProvider,
192194
clog.LogWarn("invalid access url", clog.Causef("malformed url: %q", env.WorkspaceProvider.EnvproxyAccessURL))
193195
continue
194196
}
195-
newConfig += makeSSHConfig(u.Host, userName, env.Env.Name, privateKeyFilepath)
197+
newConfig += makeSSHConfig(u.Host, userName, env.Env.Name, privateKeyFilepath, p2p)
196198
}
197199
newConfig += fmt.Sprintf("\n%s\n", sshEndToken)
198200

199201
return newConfig
200202
}
201203

202-
func makeSSHConfig(host, userName, envName, privateKeyFilepath string) string {
204+
func makeSSHConfig(host, userName, envName, privateKeyFilepath string, p2p bool) string {
205+
if p2p {
206+
return fmt.Sprintf(
207+
`Host coder.%s
208+
HostName localhost
209+
User %s-%s
210+
ProxyCommand go run cmd/coder/main.go tunnel %s 22 stdio
211+
StrictHostKeyChecking no
212+
ConnectTimeout=0
213+
IdentitiesOnly yes
214+
IdentityFile="%s"
215+
ServerAliveInterval 60
216+
ServerAliveCountMax 3
217+
`, envName, userName, envName, envName, privateKeyFilepath)
218+
}
203219
return fmt.Sprintf(
204220
`Host coder.%s
205221
HostName %s

internal/cmd/tunnel.go

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net"
9+
"os"
10+
"strconv"
11+
"time"
12+
13+
"cdr.dev/slog"
14+
"cdr.dev/slog/sloggers/sloghuman"
15+
"github.com/pion/webrtc/v3"
16+
"github.com/spf13/cobra"
17+
"golang.org/x/xerrors"
18+
"nhooyr.io/websocket"
19+
20+
"cdr.dev/coder-cli/internal/x/xcobra"
21+
"cdr.dev/coder-cli/internal/x/xwebrtc"
22+
"cdr.dev/coder-cli/pkg/proto"
23+
)
24+
25+
func tunnelCmd() *cobra.Command {
26+
cmd := &cobra.Command{
27+
Use: "tunnel [workspace_name] [workspace_port] [localhost_port]",
28+
Args: xcobra.ExactArgs(3),
29+
Short: "proxies a port on the workspace to localhost",
30+
Long: "proxies a port on the workspace to localhost",
31+
Example: `# run a tcp tunnel from the workspace on port 3000 to localhost:3000
32+
33+
coder tunnel my-dev 3000 3000
34+
`,
35+
Hidden: true,
36+
RunE: func(cmd *cobra.Command, args []string) error {
37+
ctx := context.Background()
38+
log := slog.Make(sloghuman.Sink(os.Stderr))
39+
40+
remotePort, err := strconv.ParseUint(args[1], 10, 16)
41+
if err != nil {
42+
log.Fatal(ctx, "parse remote port", slog.Error(err))
43+
}
44+
45+
var localPort uint64
46+
if args[2] != "stdio" {
47+
localPort, err = strconv.ParseUint(args[2], 10, 16)
48+
if err != nil {
49+
log.Fatal(ctx, "parse local port", slog.Error(err))
50+
}
51+
}
52+
53+
sdk, err := newClient(ctx)
54+
if err != nil {
55+
return err
56+
}
57+
baseURL := sdk.BaseURL()
58+
59+
envs, err := sdk.Environments(ctx)
60+
if err != nil {
61+
return err
62+
}
63+
64+
var envID string
65+
for _, env := range envs {
66+
if env.Name == args[0] {
67+
envID = env.ID
68+
break
69+
}
70+
}
71+
if envID == "" {
72+
return xerrors.Errorf("No workspace found by name '%s'", args[0])
73+
}
74+
75+
c := &client{
76+
id: envID,
77+
stdio: args[2] == "stdio",
78+
localPort: uint16(localPort),
79+
remotePort: uint16(remotePort),
80+
ctx: context.Background(),
81+
logger: log,
82+
brokerAddr: baseURL.String(),
83+
token: sdk.Token(),
84+
}
85+
86+
err = c.start()
87+
if err != nil {
88+
log.Fatal(ctx, err.Error())
89+
}
90+
91+
return nil
92+
},
93+
}
94+
95+
return cmd
96+
}
97+
98+
type client struct {
99+
ctx context.Context
100+
brokerAddr string
101+
token string
102+
logger slog.Logger
103+
id string
104+
remotePort uint16
105+
localPort uint16
106+
stdio bool
107+
}
108+
109+
func (c *client) start() error {
110+
url := fmt.Sprintf("%s%s%s%s%s", c.brokerAddr, "/api/private/envagent/", c.id, "/connect?session_token=", c.token)
111+
c.logger.Info(c.ctx, "connecting to broker", slog.F("url", url))
112+
113+
conn, _, err := websocket.Dial(c.ctx, url, nil)
114+
if err != nil {
115+
return fmt.Errorf("dial: %w", err)
116+
}
117+
nconn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
118+
119+
rtc, err := xwebrtc.NewPeerConnection()
120+
if err != nil {
121+
return fmt.Errorf("create connection: %w", err)
122+
}
123+
124+
rtc.OnNegotiationNeeded(func() {
125+
c.logger.Debug(context.Background(), "negotiation needed...")
126+
})
127+
128+
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
129+
c.logger.Info(context.Background(), "connection state changed", slog.F("state", pcs))
130+
})
131+
132+
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
133+
if err != nil {
134+
return fmt.Errorf("create data channel: %w", err)
135+
}
136+
flushCandidates := proto.ProxyICECandidates(rtc, nconn)
137+
138+
localDesc, err := rtc.CreateOffer(&webrtc.OfferOptions{})
139+
if err != nil {
140+
return fmt.Errorf("create offer: %w", err)
141+
}
142+
143+
err = rtc.SetLocalDescription(localDesc)
144+
if err != nil {
145+
return fmt.Errorf("set local desc: %w", err)
146+
}
147+
flushCandidates()
148+
149+
c.logger.Debug(context.Background(), "writing offer")
150+
b, _ := json.Marshal(&proto.Message{
151+
Offer: &localDesc,
152+
})
153+
_, err = nconn.Write(b)
154+
if err != nil {
155+
return fmt.Errorf("write offer: %w", err)
156+
}
157+
158+
go func() {
159+
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
160+
if err != nil {
161+
c.logger.Fatal(context.Background(), "waiting for data channel open", slog.Error(err))
162+
}
163+
_ = conn.Close(websocket.StatusNormalClosure, "rtc connected")
164+
}()
165+
166+
decoder := json.NewDecoder(nconn)
167+
for {
168+
var msg proto.Message
169+
err = decoder.Decode(&msg)
170+
if err == io.EOF {
171+
break
172+
}
173+
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
174+
break
175+
}
176+
if err != nil {
177+
return fmt.Errorf("read msg: %w", err)
178+
}
179+
if msg.Candidate != "" {
180+
c.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
181+
err = proto.AcceptICECandidate(rtc, &msg)
182+
if err != nil {
183+
return fmt.Errorf("accept ice: %w", err)
184+
}
185+
}
186+
if msg.Answer != nil {
187+
c.logger.Debug(context.Background(), "got answer", slog.F("answer", msg.Answer))
188+
err = rtc.SetRemoteDescription(*msg.Answer)
189+
if err != nil {
190+
return fmt.Errorf("set remote: %w", err)
191+
}
192+
}
193+
}
194+
195+
// Once we're open... let's test out the ping.
196+
pingProto := "ping"
197+
pingChannel, err := rtc.CreateDataChannel("pinger", &webrtc.DataChannelInit{
198+
Protocol: &pingProto,
199+
})
200+
if err != nil {
201+
return fmt.Errorf("create ping channel")
202+
}
203+
pingChannel.OnOpen(func() {
204+
defer func() {
205+
_ = pingChannel.Close()
206+
}()
207+
t1 := time.Now()
208+
rw, _ := pingChannel.Detach()
209+
defer func() {
210+
_ = rw.Close()
211+
}()
212+
_, _ = rw.Write([]byte("hello"))
213+
b := make([]byte, 64)
214+
_, _ = rw.Read(b)
215+
c.logger.Info(c.ctx, "your latency directly to the agent", slog.F("ms", time.Since(t1).Milliseconds()))
216+
})
217+
218+
if c.stdio {
219+
// At this point the RTC is connected and data channel is opened...
220+
rw, err := channel.Detach()
221+
if err != nil {
222+
return fmt.Errorf("detach channel: %w", err)
223+
}
224+
go func() {
225+
_, _ = io.Copy(rw, os.Stdin)
226+
}()
227+
_, err = io.Copy(os.Stdout, rw)
228+
if err != nil {
229+
return fmt.Errorf("copy: %w", err)
230+
}
231+
return nil
232+
}
233+
234+
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", c.localPort))
235+
if err != nil {
236+
return fmt.Errorf("listen: %w", err)
237+
}
238+
239+
for {
240+
conn, err := listener.Accept()
241+
if err != nil {
242+
return fmt.Errorf("accept: %w", err)
243+
}
244+
go func() {
245+
defer func() {
246+
_ = conn.Close()
247+
}()
248+
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
249+
if err != nil {
250+
c.logger.Warn(context.Background(), "create data channel for proxying", slog.Error(err))
251+
return
252+
}
253+
defer func() {
254+
_ = channel.Close()
255+
}()
256+
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
257+
if err != nil {
258+
c.logger.Warn(context.Background(), "wait for data channel open", slog.Error(err))
259+
return
260+
}
261+
rw, err := channel.Detach()
262+
if err != nil {
263+
c.logger.Warn(context.Background(), "detach channel", slog.Error(err))
264+
return
265+
}
266+
267+
go func() {
268+
_, _ = io.Copy(conn, rw)
269+
}()
270+
_, _ = io.Copy(rw, conn)
271+
}()
272+
}
273+
}

0 commit comments

Comments
 (0)