From 639837786bc2638e10ed545028f68af5968d52a2 Mon Sep 17 00:00:00 2001 From: Garrett Date: Thu, 15 Apr 2021 20:55:12 +0000 Subject: [PATCH 1/2] Add retries to coder agent start cmd --- agent/doc.go | 2 + agent/server.go | 89 +++++++++++++++++ agent/stream.go | 171 +++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + internal/cmd/agent.go | 219 +++--------------------------------------- 6 files changed, 279 insertions(+), 205 deletions(-) create mode 100644 agent/doc.go create mode 100644 agent/server.go create mode 100644 agent/stream.go diff --git a/agent/doc.go b/agent/doc.go new file mode 100644 index 00000000..46ebd899 --- /dev/null +++ b/agent/doc.go @@ -0,0 +1,2 @@ +// Package agent is for interacting with p2p server and clients +package agent diff --git a/agent/server.go b/agent/server.go new file mode 100644 index 00000000..17bb51f7 --- /dev/null +++ b/agent/server.go @@ -0,0 +1,89 @@ +package agent + +import ( + "context" + "fmt" + "net/url" + "time" + + "cdr.dev/slog" + "github.com/hashicorp/yamux" + "go.coder.com/retry" + "golang.org/x/xerrors" + "nhooyr.io/websocket" +) + +const ( + listenRoute = "/api/private/envagent/listen" +) + +// Server connects to a Coder deployment and listens for p2p connections. +type Server struct { + log slog.Logger + listenURL *url.URL +} + +// ServerArgs are the required arguments to create an agent server. +type ServerArgs struct { + Log slog.Logger + CoderURL *url.URL + Token string +} + +// NewServer creates a new agent server. +func NewServer(args ServerArgs) (*Server, error) { + lURL, err := formatListenURL(args.CoderURL, args.Token) + if err != nil { + return nil, xerrors.Errorf("formatting listen url: %w", err) + } + + return &Server{ + log: args.Log, + listenURL: lURL, + }, nil +} + +// Run will listen and proxy new peer connections on a retry loop. +func (s *Server) Run(ctx context.Context) error { + err := retry.New(time.Second).Context(ctx).Backoff(15 * time.Second).Run(func() error { + ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) + defer cancelFunc() + s.log.Info(ctx, "connecting to coder", slog.F("url", s.listenURL.String())) + conn, _, err := websocket.Dial(ctx, s.listenURL.String(), nil) + if err != nil { + return fmt.Errorf("dial: %w", err) + } + nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary) + session, err := yamux.Server(nc, nil) + if err != nil { + return fmt.Errorf("open: %w", err) + } + s.log.Info(ctx, "connected to coder. awaiting connection requests") + for { + st, err := session.AcceptStream() + if err != nil { + return fmt.Errorf("accept stream: %w", err) + } + stream := &stream{ + logger: s.log.Named(fmt.Sprintf("stream %d", st.StreamID())), + stream: st, + } + go stream.listen() + } + }) + + return err +} + +func formatListenURL(coderURL *url.URL, token string) (*url.URL, error) { + if coderURL.Scheme != "http" && coderURL.Scheme != "https" { + return nil, xerrors.Errorf("invalid URL scheme") + } + + coderURL.Path = listenRoute + q := coderURL.Query() + q.Set("service_token", token) + coderURL.RawQuery = q.Encode() + + return coderURL, nil +} diff --git a/agent/stream.go b/agent/stream.go new file mode 100644 index 00000000..1c033141 --- /dev/null +++ b/agent/stream.go @@ -0,0 +1,171 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + + "cdr.dev/slog" + "github.com/hashicorp/yamux" + "github.com/pion/webrtc/v3" + "golang.org/x/xerrors" + + "cdr.dev/coder-cli/internal/x/xwebrtc" + "cdr.dev/coder-cli/pkg/proto" +) + +type stream struct { + stream *yamux.Stream + logger slog.Logger + + rtc *webrtc.PeerConnection +} + +// writes an error and closes. +func (s *stream) fatal(err error) { + _ = s.write(proto.Message{ + Error: err.Error(), + }) + s.logger.Error(context.Background(), err.Error(), slog.Error(err)) + _ = s.stream.Close() +} + +func (s *stream) listen() { + decoder := json.NewDecoder(s.stream) + for { + var msg proto.Message + err := decoder.Decode(&msg) + if err == io.EOF { + break + } + if err != nil { + s.fatal(err) + return + } + s.processMessage(msg) + } +} + +func (s *stream) write(msg proto.Message) error { + d, err := json.Marshal(&msg) + if err != nil { + return err + } + _, err = s.stream.Write(d) + if err != nil { + return err + } + return nil +} + +func (s *stream) processMessage(msg proto.Message) { + s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg)) + + if msg.Error != "" { + s.fatal(xerrors.New(msg.Error)) + return + } + + if msg.Candidate != "" { + if s.rtc == nil { + s.fatal(xerrors.New("rtc connection must be started before candidates are sent")) + return + } + + s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate)) + err := proto.AcceptICECandidate(s.rtc, &msg) + if err != nil { + s.fatal(err) + return + } + } + + if msg.Offer != nil { + rtc, err := xwebrtc.NewPeerConnection() + if err != nil { + s.fatal(fmt.Errorf("create connection: %w", err)) + return + } + flushCandidates := proto.ProxyICECandidates(rtc, s.stream) + + err = rtc.SetRemoteDescription(*msg.Offer) + if err != nil { + s.fatal(fmt.Errorf("set remote desc: %w", err)) + return + } + answer, err := rtc.CreateAnswer(nil) + if err != nil { + s.fatal(fmt.Errorf("create answer: %w", err)) + return + } + err = rtc.SetLocalDescription(answer) + if err != nil { + s.fatal(fmt.Errorf("set local desc: %w", err)) + return + } + flushCandidates() + + err = s.write(proto.Message{ + Answer: rtc.LocalDescription(), + }) + if err != nil { + s.fatal(fmt.Errorf("send local desc: %w", err)) + return + } + + rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + s.logger.Info(context.Background(), "state changed", slog.F("new", pcs)) + }) + rtc.OnDataChannel(s.processDataChannel) + s.rtc = rtc + } +} + +func (s *stream) processDataChannel(channel *webrtc.DataChannel) { + if channel.Protocol() == "ping" { + channel.OnOpen(func() { + rw, err := channel.Detach() + if err != nil { + return + } + d := make([]byte, 64) + _, _ = rw.Read(d) + _, _ = rw.Write(d) + }) + return + } + + prto, port, err := xwebrtc.ParseProxyDataChannel(channel) + if err != nil { + s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err)) + return + } + if prto != "tcp" { + s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto)) + return + } + + conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port)) + if err != nil { + s.fatal(fmt.Errorf("failed to dial client port: %d", port)) + return + } + + channel.OnOpen(func() { + s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port)) + rw, err := channel.Detach() + if err != nil { + _ = channel.Close() + s.logger.Error(context.Background(), "detach client data channel", slog.Error(err)) + return + } + go func() { + _, _ = io.Copy(rw, conn) + }() + go func() { + _, _ = io.Copy(conn, rw) + }() + }) +} diff --git a/go.mod b/go.mod index 8014ce0e..1a88ad94 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/rjeczalik/notify v0.9.2 github.com/spf13/cobra v1.1.3 + go.coder.com/retry v1.2.0 golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 diff --git a/go.sum b/go.sum index a9c11e0d..0ec0d443 100644 --- a/go.sum +++ b/go.sum @@ -358,6 +358,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q go.coder.com/cli v0.4.0/go.mod h1:hRTOURCR3LJF1FRW9arecgrzX+AHG7mfYMwThPIgq+w= go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512 h1:DjCS6dRQh+1PlfiBmnabxfdrzenb0tAwJqFxDEH/s9g= go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512/go.mod h1:83JsYgXYv0EOaXjIMnaZ1Fl6ddNB3fJnDZ/8845mUJ8= +go.coder.com/retry v1.2.0 h1:ODdUPu9cb9pcbeAM5j2YqJHUgfFbN60vmhtlWIKZGLo= +go.coder.com/retry v1.2.0/go.mod h1:ihkJszQk8F+yaFL2pcIku9MzbYo+U8vka4IsvQSXVfE= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/cmd/agent.go b/internal/cmd/agent.go index dbe7d80a..84995bbb 100644 --- a/internal/cmd/agent.go +++ b/internal/cmd/agent.go @@ -2,25 +2,15 @@ package cmd import ( "context" - "encoding/json" - "fmt" - "io" - "net" "net/url" "os" - "strings" - "time" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/hashicorp/yamux" - "github.com/pion/webrtc/v3" "github.com/spf13/cobra" "golang.org/x/xerrors" - "nhooyr.io/websocket" - "cdr.dev/coder-cli/internal/x/xwebrtc" - "cdr.dev/coder-cli/pkg/proto" + "cdr.dev/coder-cli/agent" ) func agentCmd() *cobra.Command { @@ -60,7 +50,7 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx if coderURL == "" { var ok bool - token, ok = os.LookupEnv("CODER_URL") + coderURL, ok = os.LookupEnv("CODER_URL") if !ok { client, err := newClient(ctx) if err != nil { @@ -71,15 +61,10 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx } } - if !strings.HasPrefix(coderURL, "http") { - return xerrors.Errorf("invalid URL") - } u, err := url.Parse(coderURL) if err != nil { return xerrors.Errorf("parse url: %w", err) } - // Remove the trailing '/' if any. - u.Path = "/api/private/envagent/listen" if token == "" { var ok bool @@ -89,43 +74,21 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx } } - if token == "" { - var ok bool - token, ok = os.LookupEnv("CODER_AGENT_TOKEN") - if !ok { - return xerrors.New("must pass --token or set the CODER_AGENT_TOKEN env variable") - } - } - - q := u.Query() - q.Set("service_token", token) - u.RawQuery = q.Encode() - - ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) - defer cancelFunc() - log.Info(ctx, "connecting to broker", slog.F("url", u.String())) - // nolint: bodyclose - conn, _, err := websocket.Dial(ctx, u.String(), nil) + server, err := agent.NewServer(agent.ServerArgs{ + Log: log, + CoderURL: u, + Token: token, + }) if err != nil { - return fmt.Errorf("dial: %w", err) + return xerrors.Errorf("creating agent server: %w", err) } - nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary) - session, err := yamux.Server(nc, nil) - if err != nil { - return fmt.Errorf("open: %w", err) - } - log.Info(ctx, "connected to broker. awaiting connection requests") - for { - st, err := session.AcceptStream() - if err != nil { - return fmt.Errorf("accept stream: %w", err) - } - stream := &stream{ - logger: log.Named(fmt.Sprintf("stream %d", st.StreamID())), - stream: st, - } - go stream.listen() + + err = server.Run(ctx) + if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("running agent server: %w", err) } + + return nil }, } @@ -134,157 +97,3 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx return cmd } - -type stream struct { - stream *yamux.Stream - logger slog.Logger - - rtc *webrtc.PeerConnection -} - -// writes an error and closes. -func (s *stream) fatal(err error) { - _ = s.write(proto.Message{ - Error: err.Error(), - }) - s.logger.Error(context.Background(), err.Error(), slog.Error(err)) - _ = s.stream.Close() -} - -func (s *stream) listen() { - decoder := json.NewDecoder(s.stream) - for { - var msg proto.Message - err := decoder.Decode(&msg) - if err == io.EOF { - break - } - if err != nil { - s.fatal(err) - return - } - s.processMessage(msg) - } -} - -func (s *stream) write(msg proto.Message) error { - d, err := json.Marshal(&msg) - if err != nil { - return err - } - _, err = s.stream.Write(d) - if err != nil { - return err - } - return nil -} - -func (s *stream) processMessage(msg proto.Message) { - s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg)) - - if msg.Error != "" { - s.fatal(xerrors.New(msg.Error)) - return - } - - if msg.Candidate != "" { - if s.rtc == nil { - s.fatal(xerrors.New("rtc connection must be started before candidates are sent")) - return - } - - s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate)) - err := proto.AcceptICECandidate(s.rtc, &msg) - if err != nil { - s.fatal(err) - return - } - } - - if msg.Offer != nil { - rtc, err := xwebrtc.NewPeerConnection() - if err != nil { - s.fatal(fmt.Errorf("create connection: %w", err)) - return - } - flushCandidates := proto.ProxyICECandidates(rtc, s.stream) - - err = rtc.SetRemoteDescription(*msg.Offer) - if err != nil { - s.fatal(fmt.Errorf("set remote desc: %w", err)) - return - } - answer, err := rtc.CreateAnswer(nil) - if err != nil { - s.fatal(fmt.Errorf("create answer: %w", err)) - return - } - err = rtc.SetLocalDescription(answer) - if err != nil { - s.fatal(fmt.Errorf("set local desc: %w", err)) - return - } - flushCandidates() - - err = s.write(proto.Message{ - Answer: rtc.LocalDescription(), - }) - if err != nil { - s.fatal(fmt.Errorf("send local desc: %w", err)) - return - } - - rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { - s.logger.Info(context.Background(), "state changed", slog.F("new", pcs)) - }) - rtc.OnDataChannel(s.processDataChannel) - s.rtc = rtc - } -} - -func (s *stream) processDataChannel(channel *webrtc.DataChannel) { - if channel.Protocol() == "ping" { - channel.OnOpen(func() { - rw, err := channel.Detach() - if err != nil { - return - } - d := make([]byte, 64) - _, _ = rw.Read(d) - _, _ = rw.Write(d) - }) - return - } - - prto, port, err := xwebrtc.ParseProxyDataChannel(channel) - if err != nil { - s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err)) - return - } - if prto != "tcp" { - s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto)) - return - } - - conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port)) - if err != nil { - s.fatal(fmt.Errorf("failed to dial client port: %d", port)) - return - } - - channel.OnOpen(func() { - s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port)) - rw, err := channel.Detach() - if err != nil { - _ = channel.Close() - s.logger.Error(context.Background(), "detach client data channel", slog.Error(err)) - return - } - go func() { - _, _ = io.Copy(rw, conn) - }() - go func() { - _, _ = io.Copy(conn, rw) - }() - }) -} From 6580c8042690ecf58a0ca67752c2baa3aa877ddd Mon Sep 17 00:00:00 2001 From: Garrett Date: Thu, 15 Apr 2021 21:28:28 +0000 Subject: [PATCH 2/2] Add more error logs --- agent/stream.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/agent/stream.go b/agent/stream.go index 1c033141..42a2233f 100644 --- a/agent/stream.go +++ b/agent/stream.go @@ -131,8 +131,16 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) { return } d := make([]byte, 64) - _, _ = rw.Read(d) - _, _ = rw.Write(d) + _, err = rw.Read(d) + if err != nil { + s.logger.Error(context.Background(), "read ping", slog.Error(err)) + return + } + _, err = rw.Write(d) + if err != nil { + s.logger.Error(context.Background(), "write ping", slog.Error(err)) + return + } }) return } @@ -162,10 +170,16 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) { return } go func() { - _, _ = io.Copy(rw, conn) + _, err = io.Copy(rw, conn) + if err != nil { + s.logger.Error(context.Background(), "copy to conn", slog.Error(err)) + } }() go func() { _, _ = io.Copy(conn, rw) + if err != nil { + s.logger.Error(context.Background(), "copy from conn", slog.Error(err)) + } }() }) }