-
Notifications
You must be signed in to change notification settings - Fork 18
Add coder agent start #311
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
package cmd | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net" | ||
"net/url" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
"cdr.dev/slog" | ||
"cdr.dev/slog/sloggers/sloghuman" | ||
"github.com/hashicorp/yamux" | ||
"github.com/pion/webrtc/v3" | ||
"github.com/spf13/cobra" | ||
"golang.org/x/xerrors" | ||
"nhooyr.io/websocket" | ||
|
||
"cdr.dev/coder-cli/internal/x/xcobra" | ||
"cdr.dev/coder-cli/internal/x/xwebrtc" | ||
"cdr.dev/coder-cli/pkg/proto" | ||
) | ||
|
||
func agentCmd() *cobra.Command { | ||
cmd := &cobra.Command{ | ||
Use: "agent", | ||
Short: "Run the workspace agent", | ||
Long: "Connect to Coder and start running a p2p agent", | ||
Hidden: true, | ||
} | ||
|
||
cmd.AddCommand( | ||
startCmd(), | ||
) | ||
return cmd | ||
} | ||
|
||
func startCmd() *cobra.Command { | ||
var ( | ||
token string | ||
) | ||
cmd := &cobra.Command{ | ||
Use: "start [coderURL] --token=[token]", | ||
Args: xcobra.ExactArgs(1), | ||
Short: "starts the coder agent", | ||
Long: "starts the coder agent", | ||
Example: `# start the agent and connect with a Coder agent token | ||
|
||
coder agent start https://my-coder.com --token xxxx-xxxx | ||
|
||
# start the agent and use CODER_AGENT_TOKEN env var for auth token | ||
|
||
coder agent start https://my-coder.com | ||
`, | ||
RunE: func(cmd *cobra.Command, args []string) error { | ||
ctx := cmd.Context() | ||
log := slog.Make(sloghuman.Sink(cmd.OutOrStdout())) | ||
|
||
// Pull the URL from the args and do some sanity check. | ||
rawURL := args[0] | ||
if rawURL == "" || !strings.HasPrefix(rawURL, "http") { | ||
return xerrors.Errorf("invalid URL") | ||
} | ||
u, err := url.Parse(rawURL) | ||
if err != nil { | ||
return xerrors.Errorf("parse url: %w", err) | ||
} | ||
// Remove the trailing '/' if any. | ||
u.Path = "/api/private/envagent/listen" | ||
|
||
if token == "" { | ||
var ok bool | ||
token, ok = os.LookupEnv("CODER_AGENT_TOKEN") | ||
if !ok { | ||
return xerrors.New("must pass --token or set the CODER_AGENT_TOKEN env variable") | ||
} | ||
} | ||
|
||
q := u.Query() | ||
q.Set("service_token", token) | ||
u.RawQuery = q.Encode() | ||
|
||
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) | ||
defer cancelFunc() | ||
log.Info(ctx, "connecting to broker", slog.F("url", u.String())) | ||
conn, res, err := websocket.Dial(ctx, u.String(), nil) | ||
if err != nil { | ||
return fmt.Errorf("dial: %w", err) | ||
} | ||
_ = res.Body.Close() | ||
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is an MVP but there's a lot going on in this command so all of this should move into it's own package (or potentially even it's own repo if we want to use it within our monorepo as well). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I pinky promise I'll do that if this ends up being a thing we stick with ;) |
||
session, err := yamux.Server(nc, nil) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpicky since it doesn't really matter since yamux is bidirectional, but this should be a client because it's on the client end of a websocket connection There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kylecarbs Any objections? (since this part is your code) |
||
if err != nil { | ||
return fmt.Errorf("open: %w", err) | ||
} | ||
log.Info(ctx, "connected to broker. awaiting connection requests") | ||
for { | ||
st, err := session.AcceptStream() | ||
if err != nil { | ||
return fmt.Errorf("accept stream: %w", err) | ||
} | ||
stream := &stream{ | ||
logger: log.Named(fmt.Sprintf("stream %d", st.StreamID())), | ||
stream: st, | ||
} | ||
go stream.listen() | ||
} | ||
}, | ||
} | ||
|
||
cmd.Flags().StringVar(&token, "token", "", "coder agent token") | ||
return cmd | ||
} | ||
|
||
type stream struct { | ||
stream *yamux.Stream | ||
logger slog.Logger | ||
|
||
rtc *webrtc.PeerConnection | ||
} | ||
|
||
// writes an error and closes. | ||
func (s *stream) fatal(err error) { | ||
_ = s.write(proto.Message{ | ||
Error: err.Error(), | ||
}) | ||
s.logger.Error(context.Background(), err.Error(), slog.Error(err)) | ||
_ = s.stream.Close() | ||
} | ||
|
||
func (s *stream) listen() { | ||
decoder := json.NewDecoder(s.stream) | ||
for { | ||
var msg proto.Message | ||
err := decoder.Decode(&msg) | ||
if err == io.EOF { | ||
break | ||
} | ||
if err != nil { | ||
s.fatal(err) | ||
return | ||
} | ||
s.processMessage(msg) | ||
} | ||
} | ||
|
||
func (s *stream) write(msg proto.Message) error { | ||
d, err := json.Marshal(&msg) | ||
if err != nil { | ||
return err | ||
} | ||
_, err = s.stream.Write(d) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func (s *stream) processMessage(msg proto.Message) { | ||
s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg)) | ||
|
||
if msg.Error != "" { | ||
s.fatal(xerrors.New(msg.Error)) | ||
return | ||
} | ||
|
||
if msg.Candidate != "" { | ||
if s.rtc == nil { | ||
s.fatal(xerrors.New("rtc connection must be started before candidates are sent")) | ||
return | ||
} | ||
|
||
s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate)) | ||
err := proto.AcceptICECandidate(s.rtc, &msg) | ||
if err != nil { | ||
s.fatal(err) | ||
return | ||
} | ||
} | ||
|
||
if msg.Offer != nil { | ||
rtc, err := xwebrtc.NewPeerConnection() | ||
if err != nil { | ||
s.fatal(fmt.Errorf("create connection: %w", err)) | ||
return | ||
} | ||
flushCandidates := proto.ProxyICECandidates(rtc, s.stream) | ||
|
||
err = rtc.SetRemoteDescription(*msg.Offer) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("set remote desc: %w", err)) | ||
return | ||
} | ||
answer, err := rtc.CreateAnswer(nil) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("create answer: %w", err)) | ||
return | ||
} | ||
err = rtc.SetLocalDescription(answer) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("set local desc: %w", err)) | ||
return | ||
} | ||
flushCandidates() | ||
|
||
err = s.write(proto.Message{ | ||
Answer: rtc.LocalDescription(), | ||
}) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("send local desc: %w", err)) | ||
return | ||
} | ||
|
||
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { | ||
s.logger.Info(context.Background(), "state changed", slog.F("new", pcs)) | ||
}) | ||
rtc.OnDataChannel(s.processDataChannel) | ||
s.rtc = rtc | ||
} | ||
} | ||
|
||
func (s *stream) processDataChannel(channel *webrtc.DataChannel) { | ||
if channel.Protocol() == "ping" { | ||
channel.OnOpen(func() { | ||
rw, err := channel.Detach() | ||
if err != nil { | ||
return | ||
} | ||
d := make([]byte, 64) | ||
_, _ = rw.Read(d) | ||
_, _ = rw.Write(d) | ||
}) | ||
return | ||
} | ||
|
||
prto, port, err := xwebrtc.ParseProxyDataChannel(channel) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err)) | ||
return | ||
} | ||
if prto != "tcp" { | ||
s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto)) | ||
return | ||
} | ||
|
||
conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port)) | ||
if err != nil { | ||
s.fatal(fmt.Errorf("failed to dial client port: %d", port)) | ||
return | ||
} | ||
|
||
channel.OnOpen(func() { | ||
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port)) | ||
rw, err := channel.Detach() | ||
if err != nil { | ||
_ = channel.Close() | ||
s.logger.Error(context.Background(), "detach client data channel", slog.Error(err)) | ||
return | ||
} | ||
go func() { | ||
_, _ = io.Copy(rw, conn) | ||
}() | ||
go func() { | ||
_, _ = io.Copy(conn, rw) | ||
}() | ||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package xwebrtc | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/pion/webrtc/v3" | ||
) | ||
|
||
// WaitForDataChannelOpen waits for the data channel to have the open state. | ||
// By default, it waits 15 seconds. | ||
func WaitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error { | ||
if channel.ReadyState() == webrtc.DataChannelStateOpen { | ||
return nil | ||
} | ||
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) | ||
defer cancelFunc() | ||
channel.OnOpen(func() { | ||
cancelFunc() | ||
}) | ||
<-ctx.Done() | ||
if ctx.Err() == context.DeadlineExceeded { | ||
return ctx.Err() | ||
} | ||
return nil | ||
} | ||
|
||
// NewProxyDataChannel creates a new data channel for proxying. | ||
func NewProxyDataChannel(conn *webrtc.PeerConnection, name, protocol string, port uint16) (*webrtc.DataChannel, error) { | ||
proto := fmt.Sprintf("%s:%d", protocol, port) | ||
ordered := true | ||
return conn.CreateDataChannel(name, &webrtc.DataChannelInit{ | ||
Protocol: &proto, | ||
Ordered: &ordered, | ||
}) | ||
} | ||
|
||
// ParseProxyDataChannel parses a data channel to get the protocol and port. | ||
func ParseProxyDataChannel(channel *webrtc.DataChannel) (string, uint16, error) { | ||
if channel.Protocol() == "" { | ||
return "", 0, errors.New("data channel is not a proxy") | ||
} | ||
host, port, err := net.SplitHostPort(channel.Protocol()) | ||
if err != nil { | ||
return "", 0, fmt.Errorf("split protocol: %w", err) | ||
} | ||
p, err := strconv.ParseInt(port, 10, 16) | ||
if err != nil { | ||
return "", 0, fmt.Errorf("parse port: %w", err) | ||
} | ||
return host, uint16(p), nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you think
coderURL
should be a--url
flag? I'm imagining a case where this is defaulted tocoder.com
in the future for a SaaS offering. However, I know this isn't ideal for now since thecoder agent start
cmd would fail without the flags specified, so more so opening this up for discussion. Could potentially defaultcoderURL
tocoder.com
in the future as well if it's not specified as another option.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is suuuper mvp just to get a POC working in coder, a real implementation will follow once we get an idea of how scalable and stable this is.