Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 63fbaa3

Browse files
committed
Return proper errors from dial ice
1 parent 1c95749 commit 63fbaa3

File tree

7 files changed

+367
-249
lines changed

7 files changed

+367
-249
lines changed

wsnet/broker_test.go

Lines changed: 0 additions & 55 deletions
This file was deleted.

wsnet/dial.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ import (
1313
"nhooyr.io/websocket"
1414
)
1515

16+
// DialConfig provides options to configure the Dial for a connection.
1617
type DialConfig struct {
17-
ICEServers []ICEServer
18+
ICEServers []webrtc.ICEServer
1819
}
1920

2021
// Dial connects to the broker and negotiates a connection to a listener.
22+
//
2123
func Dial(ctx context.Context, broker string, config *DialConfig) (*Dialer, error) {
2224
if config == nil {
2325
config = &DialConfig{}
@@ -46,17 +48,11 @@ func Dial(ctx context.Context, broker string, config *DialConfig) (*Dialer, erro
4648
if err != nil {
4749
return nil, fmt.Errorf("create peer connection: %w", err)
4850
}
49-
rtc.OnICEConnectionStateChange(func(is webrtc.ICEConnectionState) {
50-
fmt.Printf("WE CONNECTED!: %s\n", is)
51-
})
52-
rtc.OnICECandidate(func(i *webrtc.ICECandidate) {
53-
fmt.Printf("WE GOT ICE: %+v\n", i)
54-
})
5551

5652
flushCandidates := proxyICECandidates(rtc, nconn)
5753

58-
ctrl, err := rtc.CreateDataChannel("control", &webrtc.DataChannelInit{
59-
Protocol: stringPtr("control"),
54+
ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{
55+
Protocol: stringPtr(controlChannel),
6056
Ordered: boolPtr(true),
6157
})
6258
if err != nil {

wsnet/dial_test.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,40 @@ package wsnet
22

33
import (
44
"context"
5-
"fmt"
6-
"testing"
5+
"errors"
76

8-
"github.com/pion/ice/v2"
9-
"github.com/pion/turn/v2"
107
"github.com/pion/webrtc/v3"
118
)
129

13-
func TestDial(t *testing.T) {
14-
t.Run("Example", func(t *testing.T) {
15-
connectAddr, _ := listenBroker(t)
16-
turnAddr := listenTURN(t, ice.ProtoTypeTCP, "wowie", true)
10+
func ExampleDial_basic() {
11+
servers := []webrtc.ICEServer{{
12+
URLs: []string{"turns:master.cdr.dev"},
13+
Username: "kyle",
14+
Credential: "pass",
15+
CredentialType: webrtc.ICECredentialTypePassword,
16+
}}
1717

18-
dialer, err := Dial(context.Background(), connectAddr, &DialConfig{
19-
[]ICEServer{{
20-
URLs: []string{turnAddr},
21-
Username: "insecure",
22-
Credential: "pass",
23-
CredentialType: webrtc.ICECredentialTypePassword,
24-
}},
25-
})
26-
if err != nil {
27-
t.Error(err)
18+
for _, server := range servers {
19+
err := DialICE(server, DefaultICETimeout)
20+
if errors.Is(err, ErrInvalidCredentials) {
21+
// You could do something...
2822
}
29-
fmt.Printf("Dialer: %+v\n", dialer)
30-
})
31-
}
32-
33-
func testTURN() {
23+
if errors.Is(err, ErrMismatchedProtocol) {
24+
// Likely they used TURN when they should have used TURN.
25+
// Or they could have used TURN instead of TURNS.
26+
}
27+
}
3428

35-
turn.NewServer(turn.ServerConfig{})
29+
dialer, err := Dial(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialConfig{
30+
ICEServers: servers,
31+
})
32+
if err != nil {
33+
// Do something...
34+
}
35+
conn, err := dialer.DialContext(context.Background(), "tcp", "localhost:13337")
36+
if err != nil {
37+
// Something...
38+
}
39+
defer conn.Close()
40+
// You now have access to the proxied remote port in `conn`.
3641
}

wsnet/listen.go

Lines changed: 107 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"io"
87
"net"
98

109
"cdr.dev/coder-cli/coder-sdk"
1110
"github.com/hashicorp/yamux"
11+
"github.com/pion/datachannel"
1212
"github.com/pion/webrtc/v3"
1313
"nhooyr.io/websocket"
1414
)
1515

16+
// Listen connects to the broker and returns a Listener that's triggered
17+
// when a new connection is requested from a Dialer.
18+
//
19+
// LocalAddr on connections indicates the target specified by the dialer.
1620
func Listen(ctx context.Context, broker string) (net.Listener, error) {
1721
conn, resp, err := websocket.Dial(ctx, broker, nil)
1822
if err != nil {
@@ -26,100 +30,174 @@ func Listen(ctx context.Context, broker string) (net.Listener, error) {
2630
nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
2731
session, err := yamux.Server(nconn, nil)
2832
if err != nil {
29-
return nil, fmt.Errorf("")
33+
return nil, fmt.Errorf("create multiplex: %w", err)
3034
}
31-
return nil, nil
35+
l := &listener{
36+
ws: conn,
37+
conns: make(chan net.Conn),
38+
}
39+
go func() {
40+
for {
41+
conn, err := session.Accept()
42+
if err != nil {
43+
l.acceptError = err
44+
l.Close()
45+
return
46+
}
47+
l.negotiate(conn)
48+
}
49+
}()
50+
return l, nil
3251
}
3352

3453
type listener struct {
35-
session *yamux.Session
36-
}
54+
acceptError error
55+
ws *websocket.Conn
3756

38-
func (l *listener) Accept() (net.Conn, error) {
39-
conn, err := l.session.Accept()
40-
if err != nil {
41-
return nil, err
42-
}
57+
conns chan net.Conn
58+
}
4359

60+
// Negotiates the handshake protocol over the connection provided.
61+
func (l *listener) negotiate(conn net.Conn) {
4462
var (
45-
decoder = json.NewDecoder(conn)
46-
closeError = func(err error) error {
63+
err error
64+
decoder = json.NewDecoder(conn)
65+
rtc *webrtc.PeerConnection
66+
// Sends the error provided then closes the connection.
67+
// If RTC isn't connected, we'll close it.
68+
closeError = func(err error) {
4769
d, _ := json.Marshal(&protoMessage{
4870
Error: err.Error(),
4971
})
5072
_, _ = conn.Write(d)
5173
_ = conn.Close()
52-
return err
74+
if rtc != nil {
75+
if rtc.ConnectionState() != webrtc.PeerConnectionStateConnected {
76+
rtc.Close()
77+
rtc = nil
78+
}
79+
}
5380
}
54-
rtc *webrtc.PeerConnection
5581
)
5682

5783
for {
5884
var msg protoMessage
5985
err = decoder.Decode(&msg)
60-
if err == io.EOF {
61-
break
62-
}
6386
if err != nil {
64-
return nil, err
87+
closeError(err)
88+
return
6589
}
6690

6791
if msg.Candidate != "" {
6892
if rtc == nil {
69-
return nil, closeError(fmt.Errorf("Offer must be sent before candidates"))
93+
closeError(fmt.Errorf("offer must be sent before candidates"))
94+
return
7095
}
7196

7297
err = rtc.AddICECandidate(webrtc.ICECandidateInit{
7398
Candidate: msg.Candidate,
7499
})
75100
if err != nil {
76-
return nil, closeError(fmt.Errorf("accept ice candidate: %w", err))
101+
closeError(fmt.Errorf("accept ice candidate: %w", err))
102+
return
77103
}
78104
}
79105

80106
if msg.Offer != nil {
81107
if msg.Servers == nil {
82-
return nil, closeError(fmt.Errorf("ICEServers must be provided"))
108+
closeError(fmt.Errorf("ICEServers must be provided"))
109+
return
83110
}
84111
rtc, err = newPeerConnection(msg.Servers)
85112
if err != nil {
86-
return nil, closeError(err)
113+
closeError(err)
114+
return
87115
}
116+
rtc.OnDataChannel(l.handle)
88117
flushCandidates := proxyICECandidates(rtc, conn)
89118
err = rtc.SetRemoteDescription(*msg.Offer)
90119
if err != nil {
91-
return nil, closeError(fmt.Errorf("apply offer: %w", err))
120+
closeError(fmt.Errorf("apply offer: %w", err))
121+
return
92122
}
93123
answer, err := rtc.CreateAnswer(nil)
94124
if err != nil {
95-
return nil, closeError(fmt.Errorf("create answer: %w", err))
125+
closeError(fmt.Errorf("create answer: %w", err))
126+
return
96127
}
97128
err = rtc.SetLocalDescription(answer)
98129
if err != nil {
99-
return nil, closeError(fmt.Errorf("set local answer: %w", err))
130+
closeError(fmt.Errorf("set local answer: %w", err))
131+
return
100132
}
101133
flushCandidates()
102134

103135
data, err := json.Marshal(&protoMessage{
104136
Answer: rtc.LocalDescription(),
105137
})
106138
if err != nil {
107-
return nil, closeError(fmt.Errorf("marshal: %w", err))
139+
closeError(fmt.Errorf("marshal: %w", err))
140+
return
108141
}
109142
_, err = conn.Write(data)
110143
if err != nil {
111-
return nil, closeError(fmt.Errorf("write: %w", err))
144+
closeError(fmt.Errorf("write: %w", err))
145+
return
112146
}
113147
}
114148
}
149+
}
150+
151+
func (l *listener) handle(dc *webrtc.DataChannel) {
152+
// if dc.Protocol() == controlChannel {
153+
// return
154+
// }
155+
156+
fmt.Printf("GOT CHANNEL %s\n", dc.Protocol())
157+
158+
// dc.OnOpen(func() {
159+
// rw, err := dc.Detach()
160+
// })
161+
}
115162

116-
return nil, nil
163+
// Accept accepts a new connection.
164+
func (l *listener) Accept() (net.Conn, error) {
165+
return <-l.conns, l.acceptError
117166
}
118167

168+
// Close closes the broker socket.
119169
func (l *listener) Close() error {
120-
return nil
170+
close(l.conns)
171+
return l.ws.Close(websocket.StatusNormalClosure, "")
121172
}
122173

174+
// Since this listener is bound to the WebSocket, we could
175+
// return that resolved Addr, but until we need it we won't.
123176
func (l *listener) Addr() net.Addr {
124177
return nil
125178
}
179+
180+
type dataChannelConn struct {
181+
rw datachannel.ReadWriteCloser
182+
localAddr net.Addr
183+
}
184+
185+
func (d *dataChannelConn) Read(b []byte) (n int, err error) {
186+
return d.rw.Read(b)
187+
}
188+
189+
func (d *dataChannelConn) Write(b []byte) (n int, err error) {
190+
return d.rw.Write(b)
191+
}
192+
193+
func (d *dataChannelConn) Close() error {
194+
return d.Close()
195+
}
196+
197+
func (d *dataChannelConn) LocalAddr() net.Addr {
198+
return d.localAddr
199+
}
200+
201+
func (d *dataChannelConn) RemoteAddr() net.Addr {
202+
return nil
203+
}

0 commit comments

Comments
 (0)