Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

feat: Add DialWebsocket func to enable Dial through net.Conn #335

Merged
merged 1 commit into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/cmd/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *tunnneler) start(ctx context.Context) error {
}

c.log.Info(ctx, "Connecting to workspace...")
wd, err := wsnet.Dial(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server})
wd, err := wsnet.DialWebsocket(ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), []webrtc.ICEServer{server})
if err != nil {
return xerrors.Errorf("creating workspace dialer: %w", err)
}
Expand Down
34 changes: 19 additions & 15 deletions wsnet/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@ import (
"cdr.dev/coder-cli/coder-sdk"
)

// Dial connects to the broker and negotiates a connection to a listener.
func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) {
if iceServers == nil {
iceServers = []webrtc.ICEServer{}
}

// DialWebsocket dials the broker with a WebSocket and negotiates a connection.
func DialWebsocket(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*Dialer, error) {
conn, resp, err := websocket.Dial(ctx, broker, nil)
if err != nil {
if resp != nil {
Expand All @@ -40,13 +36,21 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D
// We should close the socket intentionally.
_ = conn.Close(websocket.StatusInternalError, "an error occurred")
}()
return Dial(ctx, nconn, iceServers)
}

// Dial negotiates a connection to a listener.
func Dial(ctx context.Context, conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
if iceServers == nil {
iceServers = []webrtc.ICEServer{}
}

rtc, err := newPeerConnection(iceServers)
if err != nil {
return nil, fmt.Errorf("create peer connection: %w", err)
}

flushCandidates := proxyICECandidates(rtc, nconn)
flushCandidates := proxyICECandidates(rtc, conn)

ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{
Protocol: stringPtr(controlChannel),
Expand All @@ -72,34 +76,34 @@ func Dial(ctx context.Context, broker string, iceServers []webrtc.ICEServer) (*D
if err != nil {
return nil, fmt.Errorf("marshal offer message: %w", err)
}
_, err = nconn.Write(offerMessage)
_, err = conn.Write(offerMessage)
if err != nil {
return nil, fmt.Errorf("write offer: %w", err)
}
flushCandidates()

dialer := &Dialer{
ws: conn,
conn: conn,
ctrl: ctrl,
rtc: rtc,
}

return dialer, dialer.negotiate(nconn)
return dialer, dialer.negotiate()
}

// Dialer enables arbitrary dialing to any network and address
// inside a workspace. The opposing end of the WebSocket messages
// should be proxied with a Listener.
type Dialer struct {
ws *websocket.Conn
conn net.Conn
ctrl *webrtc.DataChannel
ctrlrw datachannel.ReadWriteCloser
rtc *webrtc.PeerConnection
}

func (d *Dialer) negotiate(nconn net.Conn) (err error) {
func (d *Dialer) negotiate() (err error) {
var (
decoder = json.NewDecoder(nconn)
decoder = json.NewDecoder(d.conn)
errCh = make(chan error)
// If candidates are sent before an offer, we place them here.
// We currently have no assurances to ensure this can't happen,
Expand All @@ -111,15 +115,15 @@ func (d *Dialer) negotiate(nconn net.Conn) (err error) {
defer close(errCh)
err := waitForDataChannelOpen(context.Background(), d.ctrl)
if err != nil {
_ = d.ws.Close(websocket.StatusAbnormalClosure, "timeout")
_ = d.conn.Close()
errCh <- err
return
}
d.ctrlrw, err = d.ctrl.Detach()
if err != nil {
errCh <- err
}
_ = d.ws.Close(websocket.StatusNormalClosure, "connected")
_ = d.conn.Close()
}()

for {
Expand Down
10 changes: 5 additions & 5 deletions wsnet/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func ExampleDial_basic() {
}
}

dialer, err := Dial(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers)
dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", servers)
if err != nil {
// Do something...
}
Expand All @@ -49,7 +49,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Error(err)
}
dialer, err := Dial(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
if err != nil {
t.Error(err)
}
Expand All @@ -65,7 +65,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Error(err)
}
dialer, err := Dial(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Error(err)
}
dialer, err := Dial(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -133,7 +133,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Error(err)
}
dialer, err := Dial(context.Background(), connectAddr, nil)
dialer, err := DialWebsocket(context.Background(), connectAddr, nil)
if err != nil {
t.Error(err)
}
Expand Down