1
+ package cmd
2
+
3
+ import (
4
+ "cdr.dev/coder-cli/internal/x/xcobra"
5
+ "cdr.dev/coder-cli/internal/x/xwebrtc"
6
+ "cdr.dev/coder-cli/pkg/proto"
7
+ "cdr.dev/slog"
8
+ "cdr.dev/slog/sloggers/sloghuman"
9
+ "context"
10
+ "encoding/json"
11
+ "fmt"
12
+ "github.com/hashicorp/yamux"
13
+ "github.com/pion/webrtc/v3"
14
+ "github.com/spf13/cobra"
15
+ "golang.org/x/xerrors"
16
+ "io"
17
+ "net"
18
+ "net/url"
19
+ "nhooyr.io/websocket"
20
+ "os"
21
+ "strings"
22
+ "time"
23
+ )
24
+
25
+ func agentCmd () * cobra.Command {
26
+ cmd := & cobra.Command {
27
+ Use : "agent" ,
28
+ Short : "Run the workspace agent" ,
29
+ Long : "Connect to Coder and start running a p2p agent" ,
30
+ Hidden : true ,
31
+ }
32
+
33
+ cmd .AddCommand (
34
+ startCmd (),
35
+ )
36
+ return cmd
37
+ }
38
+
39
+ func startCmd () * cobra.Command {
40
+ var (
41
+ token string
42
+ )
43
+ cmd := & cobra.Command {
44
+ Use : "start [coderURL] --token=[token]" ,
45
+ Args : xcobra .ExactArgs (1 ),
46
+ Short : "starts the coder agent" ,
47
+ Long : "starts the coder agent" ,
48
+ Example : `# start the agent and connect with a Coder agent token
49
+
50
+ coder agent start https://my-coder.com --token xxxx-xxxx
51
+
52
+ # start the agent and use CODER_AGENT_TOKEN env var for auth token
53
+
54
+ coder agent start https://my-coder.com
55
+ ` ,
56
+ RunE : func (cmd * cobra.Command , args []string ) error {
57
+ ctx := cmd .Context ()
58
+ log := slog .Make (sloghuman .Sink (cmd .OutOrStdout ()))
59
+
60
+ // Pull the URL from the args and do some sanity check.
61
+ rawURL := args [0 ]
62
+ if rawURL == "" || ! strings .HasPrefix (rawURL , "http" ) {
63
+ return xerrors .Errorf ("invalid URL" )
64
+ }
65
+ u , err := url .Parse (rawURL )
66
+ if err != nil {
67
+ return xerrors .Errorf ("parse url: %w" , err )
68
+ }
69
+ // Remove the trailing '/' if any.
70
+ u .Path = "/api/private/envagent/listen"
71
+
72
+ if token == "" {
73
+ var ok bool
74
+ token , ok = os .LookupEnv ("CODER_AGENT_TOKEN" )
75
+ if ! ok {
76
+ return xerrors .New ("must pass --token or set the CODER_AGENT_TOKEN env variable" )
77
+ }
78
+ }
79
+
80
+ q := u .Query ()
81
+ q .Set ("agent_token" , token )
82
+ u .RawQuery = q .Encode ()
83
+
84
+ ctx , cancelFunc := context .WithTimeout (context .Background (), time .Second * 15 )
85
+ defer cancelFunc ()
86
+ log .Info (ctx , "connecting to broker" , slog .F ("url" , u .String ()))
87
+ conn , _ , err := websocket .Dial (ctx , u .String (), nil )
88
+ if err != nil {
89
+ return fmt .Errorf ("dial: %w" , err )
90
+ }
91
+ nc := websocket .NetConn (context .Background (), conn , websocket .MessageBinary )
92
+ session , err := yamux .Server (nc , nil )
93
+ if err != nil {
94
+ return fmt .Errorf ("open: %w" , err )
95
+ }
96
+ log .Info (ctx , "connected to broker. awaiting connection requests" )
97
+ for {
98
+ st , err := session .AcceptStream ()
99
+ if err != nil {
100
+ return fmt .Errorf ("accept stream: %w" , err )
101
+ }
102
+ stream := & stream {
103
+ logger : log .Named (fmt .Sprintf ("stream %d" , st .StreamID ())),
104
+ stream : st ,
105
+ }
106
+ go stream .listen ()
107
+ }
108
+ },
109
+ }
110
+
111
+ cmd .Flags ().StringVar (& token , "token" , "" , "coder agent token" )
112
+ return cmd
113
+ }
114
+
115
+ type stream struct {
116
+ stream * yamux.Stream
117
+ logger slog.Logger
118
+
119
+ rtc * webrtc.PeerConnection
120
+ }
121
+
122
+ // writes an error and closes
123
+ func (s * stream ) fatal (err error ) {
124
+ s .write (proto.Message {
125
+ Error : err .Error (),
126
+ })
127
+ s .logger .Error (context .Background (), err .Error (), slog .Error (err ))
128
+ s .stream .Close ()
129
+ }
130
+
131
+ func (s * stream ) listen () {
132
+ decoder := json .NewDecoder (s .stream )
133
+ for {
134
+ var msg proto.Message
135
+ err := decoder .Decode (& msg )
136
+ if err == io .EOF {
137
+ break
138
+ }
139
+ if err != nil {
140
+ s .fatal (err )
141
+ return
142
+ }
143
+ s .processMessage (msg )
144
+ }
145
+ }
146
+
147
+ func (s * stream ) write (msg proto.Message ) error {
148
+ d , err := json .Marshal (& msg )
149
+ if err != nil {
150
+ return err
151
+ }
152
+ _ , err = s .stream .Write (d )
153
+ if err != nil {
154
+ return err
155
+ }
156
+ return nil
157
+ }
158
+
159
+ func (s * stream ) processMessage (msg proto.Message ) {
160
+ s .logger .Debug (context .Background (), "processing message" , slog .F ("msg" , msg ))
161
+
162
+ if msg .Error != "" {
163
+ s .fatal (xerrors .New (msg .Error ))
164
+ return
165
+ }
166
+
167
+ if msg .Candidate != "" {
168
+ if s .rtc == nil {
169
+ s .fatal (xerrors .New ("rtc connection must be started before candidates are sent" ))
170
+ return
171
+ }
172
+
173
+ s .logger .Debug (context .Background (), "accepted ice candidate" , slog .F ("candidate" , msg .Candidate ))
174
+ err := proto .AcceptICECandidate (s .rtc , & msg )
175
+ if err != nil {
176
+ s .fatal (err )
177
+ return
178
+ }
179
+ }
180
+
181
+ if msg .Offer != nil {
182
+ rtc , err := xwebrtc .NewPeerConnection ()
183
+ if err != nil {
184
+ s .fatal (fmt .Errorf ("create connection: %w" , err ))
185
+ return
186
+ }
187
+ flushCandidates := proto .ProxyICECandidates (rtc , s .stream )
188
+
189
+ err = rtc .SetRemoteDescription (* msg .Offer )
190
+ if err != nil {
191
+ s .fatal (fmt .Errorf ("set remote desc: %w" , err ))
192
+ return
193
+ }
194
+ answer , err := rtc .CreateAnswer (nil )
195
+ if err != nil {
196
+ s .fatal (fmt .Errorf ("create answer: %w" , err ))
197
+ return
198
+ }
199
+ err = rtc .SetLocalDescription (answer )
200
+ if err != nil {
201
+ s .fatal (fmt .Errorf ("set local desc: %w" , err ))
202
+ return
203
+ }
204
+ flushCandidates ()
205
+
206
+ err = s .write (proto.Message {
207
+ Answer : rtc .LocalDescription (),
208
+ })
209
+ if err != nil {
210
+ s .fatal (fmt .Errorf ("send local desc: %w" , err ))
211
+ return
212
+ }
213
+
214
+ rtc .OnConnectionStateChange (func (pcs webrtc.PeerConnectionState ) {
215
+ s .logger .Info (context .Background (), "state changed" , slog .F ("new" , pcs ))
216
+ })
217
+ rtc .OnDataChannel (s .processDataChannel )
218
+ s .rtc = rtc
219
+ }
220
+ }
221
+
222
+ func (s * stream ) processDataChannel (channel * webrtc.DataChannel ) {
223
+ if channel .Protocol () == "ping" {
224
+ channel .OnOpen (func () {
225
+ rw , err := channel .Detach ()
226
+ if err != nil {
227
+ return
228
+ }
229
+ d := make ([]byte , 64 )
230
+ _ , err = rw .Read (d )
231
+ rw .Write (d )
232
+ })
233
+ return
234
+ }
235
+
236
+ proto , port , err := xwebrtc .ParseProxyDataChannel (channel )
237
+ if proto != "tcp" {
238
+ s .fatal (fmt .Errorf ("client provided unsupported protocol: %s" , proto ))
239
+ return
240
+ }
241
+
242
+ conn , err := net .Dial (proto , fmt .Sprintf ("localhost:%d" , port ))
243
+ if err != nil {
244
+ s .fatal (fmt .Errorf ("failed to dial client port: %d" , port ))
245
+ return
246
+ }
247
+
248
+ channel .OnOpen (func () {
249
+ s .logger .Debug (context .Background (), "proxying data channel to local port" , slog .F ("port" , port ))
250
+ rw , err := channel .Detach ()
251
+ if err != nil {
252
+ channel .Close ()
253
+ s .logger .Error (context .Background (), "detach client data channel" , slog .Error (err ))
254
+ return
255
+ }
256
+ go io .Copy (rw , conn )
257
+ go io .Copy (conn , rw )
258
+ })
259
+ }
0 commit comments