Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Add retries to coder agent start cmd #316

Merged
merged 2 commits into from
Apr 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agent/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package agent is for interacting with p2p server and clients
package agent
89 changes: 89 additions & 0 deletions agent/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package agent

import (
"context"
"fmt"
"net/url"
"time"

"cdr.dev/slog"
"github.com/hashicorp/yamux"
"go.coder.com/retry"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
)

const (
listenRoute = "/api/private/envagent/listen"
)

// Server connects to a Coder deployment and listens for p2p connections.
type Server struct {
log slog.Logger
listenURL *url.URL
}

// ServerArgs are the required arguments to create an agent server.
type ServerArgs struct {
Log slog.Logger
CoderURL *url.URL
Token string
}

// NewServer creates a new agent server.
func NewServer(args ServerArgs) (*Server, error) {
lURL, err := formatListenURL(args.CoderURL, args.Token)
if err != nil {
return nil, xerrors.Errorf("formatting listen url: %w", err)
}

return &Server{
log: args.Log,
listenURL: lURL,
}, nil
}

// Run will listen and proxy new peer connections on a retry loop.
func (s *Server) Run(ctx context.Context) error {
err := retry.New(time.Second).Context(ctx).Backoff(15 * time.Second).Run(func() error {
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
defer cancelFunc()
s.log.Info(ctx, "connecting to coder", slog.F("url", s.listenURL.String()))
conn, _, err := websocket.Dial(ctx, s.listenURL.String(), nil)
if err != nil {
return fmt.Errorf("dial: %w", err)
}
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Server(nc, nil)
if err != nil {
return fmt.Errorf("open: %w", err)
}
s.log.Info(ctx, "connected to coder. awaiting connection requests")
for {
st, err := session.AcceptStream()
if err != nil {
return fmt.Errorf("accept stream: %w", err)
}
stream := &stream{
logger: s.log.Named(fmt.Sprintf("stream %d", st.StreamID())),
stream: st,
}
go stream.listen()
}
})

return err
}

func formatListenURL(coderURL *url.URL, token string) (*url.URL, error) {
if coderURL.Scheme != "http" && coderURL.Scheme != "https" {
return nil, xerrors.Errorf("invalid URL scheme")
}

coderURL.Path = listenRoute
q := coderURL.Query()
q.Set("service_token", token)
coderURL.RawQuery = q.Encode()

return coderURL, nil
}
185 changes: 185 additions & 0 deletions agent/stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package agent

import (
"context"
"encoding/json"
"fmt"
"io"
"net"

"cdr.dev/slog"
"github.com/hashicorp/yamux"
"github.com/pion/webrtc/v3"
"golang.org/x/xerrors"

"cdr.dev/coder-cli/internal/x/xwebrtc"
"cdr.dev/coder-cli/pkg/proto"
)

type stream struct {
stream *yamux.Stream
logger slog.Logger

rtc *webrtc.PeerConnection
}

// writes an error and closes.
func (s *stream) fatal(err error) {
_ = s.write(proto.Message{
Error: err.Error(),
})
s.logger.Error(context.Background(), err.Error(), slog.Error(err))
_ = s.stream.Close()
}

func (s *stream) listen() {
decoder := json.NewDecoder(s.stream)
for {
var msg proto.Message
err := decoder.Decode(&msg)
if err == io.EOF {
break
}
if err != nil {
s.fatal(err)
return
}
s.processMessage(msg)
}
}

func (s *stream) write(msg proto.Message) error {
d, err := json.Marshal(&msg)
if err != nil {
return err
}
_, err = s.stream.Write(d)
if err != nil {
return err
}
return nil
}

func (s *stream) processMessage(msg proto.Message) {
s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg))

if msg.Error != "" {
s.fatal(xerrors.New(msg.Error))
return
}

if msg.Candidate != "" {
if s.rtc == nil {
s.fatal(xerrors.New("rtc connection must be started before candidates are sent"))
return
}

s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
err := proto.AcceptICECandidate(s.rtc, &msg)
if err != nil {
s.fatal(err)
return
}
}

if msg.Offer != nil {
rtc, err := xwebrtc.NewPeerConnection()
if err != nil {
s.fatal(fmt.Errorf("create connection: %w", err))
return
}
flushCandidates := proto.ProxyICECandidates(rtc, s.stream)

err = rtc.SetRemoteDescription(*msg.Offer)
if err != nil {
s.fatal(fmt.Errorf("set remote desc: %w", err))
return
}
answer, err := rtc.CreateAnswer(nil)
if err != nil {
s.fatal(fmt.Errorf("create answer: %w", err))
return
}
err = rtc.SetLocalDescription(answer)
if err != nil {
s.fatal(fmt.Errorf("set local desc: %w", err))
return
}
flushCandidates()

err = s.write(proto.Message{
Answer: rtc.LocalDescription(),
})
if err != nil {
s.fatal(fmt.Errorf("send local desc: %w", err))
return
}

rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
s.logger.Info(context.Background(), "state changed", slog.F("new", pcs))
})
rtc.OnDataChannel(s.processDataChannel)
s.rtc = rtc
}
}

func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
if channel.Protocol() == "ping" {
channel.OnOpen(func() {
rw, err := channel.Detach()
if err != nil {
return
}
d := make([]byte, 64)
_, err = rw.Read(d)
if err != nil {
s.logger.Error(context.Background(), "read ping", slog.Error(err))
return
}
_, err = rw.Write(d)
if err != nil {
s.logger.Error(context.Background(), "write ping", slog.Error(err))
return
}
})
return
}

prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
if err != nil {
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
return
}
if prto != "tcp" {
s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto))
return
}

conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
if err != nil {
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
return
}

channel.OnOpen(func() {
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
rw, err := channel.Detach()
if err != nil {
_ = channel.Close()
s.logger.Error(context.Background(), "detach client data channel", slog.Error(err))
return
}
go func() {
_, err = io.Copy(rw, conn)
if err != nil {
s.logger.Error(context.Background(), "copy to conn", slog.Error(err))
}
}()
go func() {
_, _ = io.Copy(conn, rw)
if err != nil {
s.logger.Error(context.Background(), "copy from conn", slog.Error(err))
}
}()
})
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/rjeczalik/notify v0.9.2
github.com/spf13/cobra v1.1.3
go.coder.com/retry v1.2.0
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q
go.coder.com/cli v0.4.0/go.mod h1:hRTOURCR3LJF1FRW9arecgrzX+AHG7mfYMwThPIgq+w=
go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512 h1:DjCS6dRQh+1PlfiBmnabxfdrzenb0tAwJqFxDEH/s9g=
go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512/go.mod h1:83JsYgXYv0EOaXjIMnaZ1Fl6ddNB3fJnDZ/8845mUJ8=
go.coder.com/retry v1.2.0 h1:ODdUPu9cb9pcbeAM5j2YqJHUgfFbN60vmhtlWIKZGLo=
go.coder.com/retry v1.2.0/go.mod h1:ihkJszQk8F+yaFL2pcIku9MzbYo+U8vka4IsvQSXVfE=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
Expand Down
Loading