Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Add coder agent start #311

Merged
merged 2 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ require (
github.com/fatih/color v1.10.0
github.com/google/go-cmp v0.5.5
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
github.com/klauspost/compress v1.10.8 // indirect
github.com/manifoldco/promptui v0.8.0
github.com/pion/webrtc/v3 v3.0.20
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/rjeczalik/notify v0.9.2
github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.6.1 // indirect
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
nhooyr.io/websocket v1.8.6
Expand Down
111 changes: 102 additions & 9 deletions go.sum

Large diffs are not rendered by default.

270 changes: 270 additions & 0 deletions internal/cmd/agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
package cmd

import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/url"
"os"
"strings"
"time"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/hashicorp/yamux"
"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"nhooyr.io/websocket"

"cdr.dev/coder-cli/internal/x/xcobra"
"cdr.dev/coder-cli/internal/x/xwebrtc"
"cdr.dev/coder-cli/pkg/proto"
)

func agentCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "agent",
Short: "Run the workspace agent",
Long: "Connect to Coder and start running a p2p agent",
Hidden: true,
}

cmd.AddCommand(
startCmd(),
)
return cmd
}

func startCmd() *cobra.Command {
var (
token string
)
cmd := &cobra.Command{
Use: "start [coderURL] --token=[token]",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think coderURL should be a --url flag? I'm imagining a case where this is defaulted to coder.com in the future for a SaaS offering. However, I know this isn't ideal for now since the coder agent start cmd would fail without the flags specified, so more so opening this up for discussion. Could potentially default coderURL to coder.com in the future as well if it's not specified as another option.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is suuuper mvp just to get a POC working in coder, a real implementation will follow once we get an idea of how scalable and stable this is.

Args: xcobra.ExactArgs(1),
Short: "starts the coder agent",
Long: "starts the coder agent",
Example: `# start the agent and connect with a Coder agent token

coder agent start https://my-coder.com --token xxxx-xxxx

# start the agent and use CODER_AGENT_TOKEN env var for auth token

coder agent start https://my-coder.com
`,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
log := slog.Make(sloghuman.Sink(cmd.OutOrStdout()))

// Pull the URL from the args and do some sanity check.
rawURL := args[0]
if rawURL == "" || !strings.HasPrefix(rawURL, "http") {
return xerrors.Errorf("invalid URL")
}
u, err := url.Parse(rawURL)
if err != nil {
return xerrors.Errorf("parse url: %w", err)
}
// Remove the trailing '/' if any.
u.Path = "/api/private/envagent/listen"

if token == "" {
var ok bool
token, ok = os.LookupEnv("CODER_AGENT_TOKEN")
if !ok {
return xerrors.New("must pass --token or set the CODER_AGENT_TOKEN env variable")
}
}

q := u.Query()
q.Set("service_token", token)
u.RawQuery = q.Encode()

ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
defer cancelFunc()
log.Info(ctx, "connecting to broker", slog.F("url", u.String()))
conn, res, err := websocket.Dial(ctx, u.String(), nil)
if err != nil {
return fmt.Errorf("dial: %w", err)
}
_ = res.Body.Close()
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is an MVP but there's a lot going on in this command so all of this should move into it's own package (or potentially even it's own repo if we want to use it within our monorepo as well).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pinky promise I'll do that if this ends up being a thing we stick with ;)

session, err := yamux.Server(nc, nil)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpicky since it doesn't really matter since yamux is bidirectional, but this should be a client because it's on the client end of a websocket connection

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kylecarbs Any objections? (since this part is your code)

if err != nil {
return fmt.Errorf("open: %w", err)
}
log.Info(ctx, "connected to broker. awaiting connection requests")
for {
st, err := session.AcceptStream()
if err != nil {
return fmt.Errorf("accept stream: %w", err)
}
stream := &stream{
logger: log.Named(fmt.Sprintf("stream %d", st.StreamID())),
stream: st,
}
go stream.listen()
}
},
}

cmd.Flags().StringVar(&token, "token", "", "coder agent token")
return cmd
}

type stream struct {
stream *yamux.Stream
logger slog.Logger

rtc *webrtc.PeerConnection
}

// writes an error and closes.
func (s *stream) fatal(err error) {
_ = s.write(proto.Message{
Error: err.Error(),
})
s.logger.Error(context.Background(), err.Error(), slog.Error(err))
_ = s.stream.Close()
}

func (s *stream) listen() {
decoder := json.NewDecoder(s.stream)
for {
var msg proto.Message
err := decoder.Decode(&msg)
if err == io.EOF {
break
}
if err != nil {
s.fatal(err)
return
}
s.processMessage(msg)
}
}

func (s *stream) write(msg proto.Message) error {
d, err := json.Marshal(&msg)
if err != nil {
return err
}
_, err = s.stream.Write(d)
if err != nil {
return err
}
return nil
}

func (s *stream) processMessage(msg proto.Message) {
s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg))

if msg.Error != "" {
s.fatal(xerrors.New(msg.Error))
return
}

if msg.Candidate != "" {
if s.rtc == nil {
s.fatal(xerrors.New("rtc connection must be started before candidates are sent"))
return
}

s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
err := proto.AcceptICECandidate(s.rtc, &msg)
if err != nil {
s.fatal(err)
return
}
}

if msg.Offer != nil {
rtc, err := xwebrtc.NewPeerConnection()
if err != nil {
s.fatal(fmt.Errorf("create connection: %w", err))
return
}
flushCandidates := proto.ProxyICECandidates(rtc, s.stream)

err = rtc.SetRemoteDescription(*msg.Offer)
if err != nil {
s.fatal(fmt.Errorf("set remote desc: %w", err))
return
}
answer, err := rtc.CreateAnswer(nil)
if err != nil {
s.fatal(fmt.Errorf("create answer: %w", err))
return
}
err = rtc.SetLocalDescription(answer)
if err != nil {
s.fatal(fmt.Errorf("set local desc: %w", err))
return
}
flushCandidates()

err = s.write(proto.Message{
Answer: rtc.LocalDescription(),
})
if err != nil {
s.fatal(fmt.Errorf("send local desc: %w", err))
return
}

rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
s.logger.Info(context.Background(), "state changed", slog.F("new", pcs))
})
rtc.OnDataChannel(s.processDataChannel)
s.rtc = rtc
}
}

func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
if channel.Protocol() == "ping" {
channel.OnOpen(func() {
rw, err := channel.Detach()
if err != nil {
return
}
d := make([]byte, 64)
_, _ = rw.Read(d)
_, _ = rw.Write(d)
})
return
}

prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
if err != nil {
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
return
}
if prto != "tcp" {
s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto))
return
}

conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
if err != nil {
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
return
}

channel.OnOpen(func() {
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
rw, err := channel.Detach()
if err != nil {
_ = channel.Close()
s.logger.Error(context.Background(), "detach client data channel", slog.Error(err))
return
}
go func() {
_, _ = io.Copy(rw, conn)
}()
go func() {
_, _ = io.Copy(conn, rw)
}()
})
}
1 change: 1 addition & 0 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func Make() *cobra.Command {
imgsCmd(),
providersCmd(),
genDocsCmd(app),
agentCmd(),
)
app.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "show verbose output")
return app
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"path/filepath"

"github.com/spf13/cobra"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/term"
"golang.org/x/xerrors"

"cdr.dev/coder-cli/coder-sdk"
"cdr.dev/coder-cli/pkg/clog"
)

var (
showInteractiveOutput = terminal.IsTerminal(int(os.Stdout.Fd()))
showInteractiveOutput = term.IsTerminal(int(os.Stdout.Fd()))
)

func sshCmd() *cobra.Command {
Expand Down
6 changes: 3 additions & 3 deletions internal/x/xterminal/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
package xterminal

import (
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/term"
)

// State differs per-platform.
type State struct {
s *terminal.State
s *term.State
}

// MakeOutputRaw does nothing on non-Windows platforms.
Expand All @@ -20,5 +20,5 @@ func Restore(fd uintptr, state *State) error {
return nil
}

return terminal.Restore(int(fd), state.s)
return term.Restore(int(fd), state.s)
}
56 changes: 56 additions & 0 deletions internal/x/xwebrtc/channel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package xwebrtc

import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"

"github.com/pion/webrtc/v3"
)

// WaitForDataChannelOpen waits for the data channel to have the open state.
// By default, it waits 15 seconds.
func WaitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error {
if channel.ReadyState() == webrtc.DataChannelStateOpen {
return nil
}
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
defer cancelFunc()
channel.OnOpen(func() {
cancelFunc()
})
<-ctx.Done()
if ctx.Err() == context.DeadlineExceeded {
return ctx.Err()
}
return nil
}

// NewProxyDataChannel creates a new data channel for proxying.
func NewProxyDataChannel(conn *webrtc.PeerConnection, name, protocol string, port uint16) (*webrtc.DataChannel, error) {
proto := fmt.Sprintf("%s:%d", protocol, port)
ordered := true
return conn.CreateDataChannel(name, &webrtc.DataChannelInit{
Protocol: &proto,
Ordered: &ordered,
})
}

// ParseProxyDataChannel parses a data channel to get the protocol and port.
func ParseProxyDataChannel(channel *webrtc.DataChannel) (string, uint16, error) {
if channel.Protocol() == "" {
return "", 0, errors.New("data channel is not a proxy")
}
host, port, err := net.SplitHostPort(channel.Protocol())
if err != nil {
return "", 0, fmt.Errorf("split protocol: %w", err)
}
p, err := strconv.ParseInt(port, 10, 16)
if err != nil {
return "", 0, fmt.Errorf("parse port: %w", err)
}
return host, uint16(p), nil
}
Loading