Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 739c8e4

Browse files
f0sseldeansheather
andauthored
Add coder agent start (#311)
* Add coder agent start * Recrease go.sum Co-authored-by: Dean Sheather <[email protected]>
1 parent 996a3cd commit 739c8e4

File tree

11 files changed

+522
-18
lines changed

11 files changed

+522
-18
lines changed

go.mod

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ require (
99
github.com/fatih/color v1.10.0
1010
github.com/google/go-cmp v0.5.5
1111
github.com/gorilla/websocket v1.4.2
12+
github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864
1213
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
1314
github.com/klauspost/compress v1.10.8 // indirect
1415
github.com/manifoldco/promptui v0.8.0
16+
github.com/pion/webrtc/v3 v3.0.20
1517
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
1618
github.com/rjeczalik/notify v0.9.2
1719
github.com/spf13/cobra v1.1.3
18-
github.com/stretchr/testify v1.6.1 // indirect
19-
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
20-
golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect
2120
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208
22-
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13
21+
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005
22+
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
2323
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
2424
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
2525
nhooyr.io/websocket v1.8.6

go.sum

Lines changed: 102 additions & 9 deletions
Large diffs are not rendered by default.

internal/cmd/agent.go

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net"
9+
"net/url"
10+
"os"
11+
"strings"
12+
"time"
13+
14+
"cdr.dev/slog"
15+
"cdr.dev/slog/sloggers/sloghuman"
16+
"github.com/hashicorp/yamux"
17+
"github.com/pion/webrtc/v3"
18+
"github.com/spf13/cobra"
19+
"golang.org/x/xerrors"
20+
"nhooyr.io/websocket"
21+
22+
"cdr.dev/coder-cli/internal/x/xcobra"
23+
"cdr.dev/coder-cli/internal/x/xwebrtc"
24+
"cdr.dev/coder-cli/pkg/proto"
25+
)
26+
27+
func agentCmd() *cobra.Command {
28+
cmd := &cobra.Command{
29+
Use: "agent",
30+
Short: "Run the workspace agent",
31+
Long: "Connect to Coder and start running a p2p agent",
32+
Hidden: true,
33+
}
34+
35+
cmd.AddCommand(
36+
startCmd(),
37+
)
38+
return cmd
39+
}
40+
41+
func startCmd() *cobra.Command {
42+
var (
43+
token string
44+
)
45+
cmd := &cobra.Command{
46+
Use: "start [coderURL] --token=[token]",
47+
Args: xcobra.ExactArgs(1),
48+
Short: "starts the coder agent",
49+
Long: "starts the coder agent",
50+
Example: `# start the agent and connect with a Coder agent token
51+
52+
coder agent start https://my-coder.com --token xxxx-xxxx
53+
54+
# start the agent and use CODER_AGENT_TOKEN env var for auth token
55+
56+
coder agent start https://my-coder.com
57+
`,
58+
RunE: func(cmd *cobra.Command, args []string) error {
59+
ctx := cmd.Context()
60+
log := slog.Make(sloghuman.Sink(cmd.OutOrStdout()))
61+
62+
// Pull the URL from the args and do some sanity check.
63+
rawURL := args[0]
64+
if rawURL == "" || !strings.HasPrefix(rawURL, "http") {
65+
return xerrors.Errorf("invalid URL")
66+
}
67+
u, err := url.Parse(rawURL)
68+
if err != nil {
69+
return xerrors.Errorf("parse url: %w", err)
70+
}
71+
// Remove the trailing '/' if any.
72+
u.Path = "/api/private/envagent/listen"
73+
74+
if token == "" {
75+
var ok bool
76+
token, ok = os.LookupEnv("CODER_AGENT_TOKEN")
77+
if !ok {
78+
return xerrors.New("must pass --token or set the CODER_AGENT_TOKEN env variable")
79+
}
80+
}
81+
82+
q := u.Query()
83+
q.Set("service_token", token)
84+
u.RawQuery = q.Encode()
85+
86+
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
87+
defer cancelFunc()
88+
log.Info(ctx, "connecting to broker", slog.F("url", u.String()))
89+
conn, res, err := websocket.Dial(ctx, u.String(), nil)
90+
if err != nil {
91+
return fmt.Errorf("dial: %w", err)
92+
}
93+
_ = res.Body.Close()
94+
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
95+
session, err := yamux.Server(nc, nil)
96+
if err != nil {
97+
return fmt.Errorf("open: %w", err)
98+
}
99+
log.Info(ctx, "connected to broker. awaiting connection requests")
100+
for {
101+
st, err := session.AcceptStream()
102+
if err != nil {
103+
return fmt.Errorf("accept stream: %w", err)
104+
}
105+
stream := &stream{
106+
logger: log.Named(fmt.Sprintf("stream %d", st.StreamID())),
107+
stream: st,
108+
}
109+
go stream.listen()
110+
}
111+
},
112+
}
113+
114+
cmd.Flags().StringVar(&token, "token", "", "coder agent token")
115+
return cmd
116+
}
117+
118+
type stream struct {
119+
stream *yamux.Stream
120+
logger slog.Logger
121+
122+
rtc *webrtc.PeerConnection
123+
}
124+
125+
// writes an error and closes.
126+
func (s *stream) fatal(err error) {
127+
_ = s.write(proto.Message{
128+
Error: err.Error(),
129+
})
130+
s.logger.Error(context.Background(), err.Error(), slog.Error(err))
131+
_ = s.stream.Close()
132+
}
133+
134+
func (s *stream) listen() {
135+
decoder := json.NewDecoder(s.stream)
136+
for {
137+
var msg proto.Message
138+
err := decoder.Decode(&msg)
139+
if err == io.EOF {
140+
break
141+
}
142+
if err != nil {
143+
s.fatal(err)
144+
return
145+
}
146+
s.processMessage(msg)
147+
}
148+
}
149+
150+
func (s *stream) write(msg proto.Message) error {
151+
d, err := json.Marshal(&msg)
152+
if err != nil {
153+
return err
154+
}
155+
_, err = s.stream.Write(d)
156+
if err != nil {
157+
return err
158+
}
159+
return nil
160+
}
161+
162+
func (s *stream) processMessage(msg proto.Message) {
163+
s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg))
164+
165+
if msg.Error != "" {
166+
s.fatal(xerrors.New(msg.Error))
167+
return
168+
}
169+
170+
if msg.Candidate != "" {
171+
if s.rtc == nil {
172+
s.fatal(xerrors.New("rtc connection must be started before candidates are sent"))
173+
return
174+
}
175+
176+
s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
177+
err := proto.AcceptICECandidate(s.rtc, &msg)
178+
if err != nil {
179+
s.fatal(err)
180+
return
181+
}
182+
}
183+
184+
if msg.Offer != nil {
185+
rtc, err := xwebrtc.NewPeerConnection()
186+
if err != nil {
187+
s.fatal(fmt.Errorf("create connection: %w", err))
188+
return
189+
}
190+
flushCandidates := proto.ProxyICECandidates(rtc, s.stream)
191+
192+
err = rtc.SetRemoteDescription(*msg.Offer)
193+
if err != nil {
194+
s.fatal(fmt.Errorf("set remote desc: %w", err))
195+
return
196+
}
197+
answer, err := rtc.CreateAnswer(nil)
198+
if err != nil {
199+
s.fatal(fmt.Errorf("create answer: %w", err))
200+
return
201+
}
202+
err = rtc.SetLocalDescription(answer)
203+
if err != nil {
204+
s.fatal(fmt.Errorf("set local desc: %w", err))
205+
return
206+
}
207+
flushCandidates()
208+
209+
err = s.write(proto.Message{
210+
Answer: rtc.LocalDescription(),
211+
})
212+
if err != nil {
213+
s.fatal(fmt.Errorf("send local desc: %w", err))
214+
return
215+
}
216+
217+
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
218+
s.logger.Info(context.Background(), "state changed", slog.F("new", pcs))
219+
})
220+
rtc.OnDataChannel(s.processDataChannel)
221+
s.rtc = rtc
222+
}
223+
}
224+
225+
func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
226+
if channel.Protocol() == "ping" {
227+
channel.OnOpen(func() {
228+
rw, err := channel.Detach()
229+
if err != nil {
230+
return
231+
}
232+
d := make([]byte, 64)
233+
_, _ = rw.Read(d)
234+
_, _ = rw.Write(d)
235+
})
236+
return
237+
}
238+
239+
prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
240+
if err != nil {
241+
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
242+
return
243+
}
244+
if prto != "tcp" {
245+
s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto))
246+
return
247+
}
248+
249+
conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
250+
if err != nil {
251+
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
252+
return
253+
}
254+
255+
channel.OnOpen(func() {
256+
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
257+
rw, err := channel.Detach()
258+
if err != nil {
259+
_ = channel.Close()
260+
s.logger.Error(context.Background(), "detach client data channel", slog.Error(err))
261+
return
262+
}
263+
go func() {
264+
_, _ = io.Copy(rw, conn)
265+
}()
266+
go func() {
267+
_, _ = io.Copy(conn, rw)
268+
}()
269+
})
270+
}

internal/cmd/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ func Make() *cobra.Command {
3737
imgsCmd(),
3838
providersCmd(),
3939
genDocsCmd(app),
40+
agentCmd(),
4041
)
4142
app.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "show verbose output")
4243
return app

internal/cmd/ssh.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99
"path/filepath"
1010

1111
"github.com/spf13/cobra"
12-
"golang.org/x/crypto/ssh/terminal"
12+
"golang.org/x/term"
1313
"golang.org/x/xerrors"
1414

1515
"cdr.dev/coder-cli/coder-sdk"
1616
"cdr.dev/coder-cli/pkg/clog"
1717
)
1818

1919
var (
20-
showInteractiveOutput = terminal.IsTerminal(int(os.Stdout.Fd()))
20+
showInteractiveOutput = term.IsTerminal(int(os.Stdout.Fd()))
2121
)
2222

2323
func sshCmd() *cobra.Command {

internal/x/xterminal/terminal.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
package xterminal
44

55
import (
6-
"golang.org/x/crypto/ssh/terminal"
6+
"golang.org/x/term"
77
)
88

99
// State differs per-platform.
1010
type State struct {
11-
s *terminal.State
11+
s *term.State
1212
}
1313

1414
// MakeOutputRaw does nothing on non-Windows platforms.
@@ -20,5 +20,5 @@ func Restore(fd uintptr, state *State) error {
2020
return nil
2121
}
2222

23-
return terminal.Restore(int(fd), state.s)
23+
return term.Restore(int(fd), state.s)
2424
}

internal/x/xwebrtc/channel.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package xwebrtc
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net"
8+
"strconv"
9+
"time"
10+
11+
"github.com/pion/webrtc/v3"
12+
)
13+
14+
// WaitForDataChannelOpen waits for the data channel to have the open state.
15+
// By default, it waits 15 seconds.
16+
func WaitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error {
17+
if channel.ReadyState() == webrtc.DataChannelStateOpen {
18+
return nil
19+
}
20+
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
21+
defer cancelFunc()
22+
channel.OnOpen(func() {
23+
cancelFunc()
24+
})
25+
<-ctx.Done()
26+
if ctx.Err() == context.DeadlineExceeded {
27+
return ctx.Err()
28+
}
29+
return nil
30+
}
31+
32+
// NewProxyDataChannel creates a new data channel for proxying.
33+
func NewProxyDataChannel(conn *webrtc.PeerConnection, name, protocol string, port uint16) (*webrtc.DataChannel, error) {
34+
proto := fmt.Sprintf("%s:%d", protocol, port)
35+
ordered := true
36+
return conn.CreateDataChannel(name, &webrtc.DataChannelInit{
37+
Protocol: &proto,
38+
Ordered: &ordered,
39+
})
40+
}
41+
42+
// ParseProxyDataChannel parses a data channel to get the protocol and port.
43+
func ParseProxyDataChannel(channel *webrtc.DataChannel) (string, uint16, error) {
44+
if channel.Protocol() == "" {
45+
return "", 0, errors.New("data channel is not a proxy")
46+
}
47+
host, port, err := net.SplitHostPort(channel.Protocol())
48+
if err != nil {
49+
return "", 0, fmt.Errorf("split protocol: %w", err)
50+
}
51+
p, err := strconv.ParseInt(port, 10, 16)
52+
if err != nil {
53+
return "", 0, fmt.Errorf("parse port: %w", err)
54+
}
55+
return host, uint16(p), nil
56+
}

0 commit comments

Comments
 (0)