Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 544f276

Browse files
authored
wsnet dial policy (#364)
* wsnet dial policy * Add wsnet address/policy tests * Add op to not permitted errors * Add Code to dial response struct * fixup! Add Code to dial response struct * fixup! Add Code to dial response struct
1 parent 765c0dd commit 544f276

File tree

5 files changed

+447
-77
lines changed

5 files changed

+447
-77
lines changed

wsnet/dial.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
6767
return nil, fmt.Errorf("set local offer: %w", err)
6868
}
6969

70-
offerMessage, err := json.Marshal(&protoMessage{
70+
offerMessage, err := json.Marshal(&BrokerMessage{
7171
Offer: &offer,
7272
Servers: iceServers,
7373
})
@@ -124,7 +124,7 @@ func (d *Dialer) negotiate() (err error) {
124124
}()
125125

126126
for {
127-
var msg protoMessage
127+
var msg BrokerMessage
128128
err = decoder.Decode(&msg)
129129
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
130130
break
@@ -218,24 +218,23 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
218218

219219
errCh := make(chan error)
220220
go func() {
221-
var init dialChannelMessage
222-
err = json.NewDecoder(rw).Decode(&init)
221+
var res DialChannelResponse
222+
err = json.NewDecoder(rw).Decode(&res)
223223
if err != nil {
224-
errCh <- fmt.Errorf("read init: %w", err)
224+
errCh <- fmt.Errorf("read dial response: %w", err)
225225
return
226226
}
227-
if init.Err == "" {
227+
if res.Err == "" {
228228
close(errCh)
229229
return
230230
}
231-
err := errors.New(init.Err)
232-
if init.Net != "" {
233-
errCh <- &net.OpError{
234-
Op: init.Op,
235-
Net: init.Net,
231+
err := errors.New(res.Err)
232+
if res.Code == CodeDialErr {
233+
err = &net.OpError{
234+
Op: res.Op,
235+
Net: res.Net,
236236
Err: err,
237237
}
238-
return
239238
}
240239
errCh <- err
241240
}()

wsnet/listen.go

Lines changed: 91 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"io"
99
"net"
10-
"strings"
1110
"sync"
1211
"time"
1312

@@ -18,8 +17,26 @@ import (
1817
"cdr.dev/coder-cli/coder-sdk"
1918
)
2019

20+
// Codes for DialChannelResponse.
21+
const (
22+
CodeDialErr = "dial_error"
23+
CodePermissionErr = "permission_error"
24+
CodeBadAddressErr = "bad_address_error"
25+
)
26+
2127
var connectionRetryInterval = time.Second
2228

29+
// DialChannelResponse is used to notify a dial channel of a
30+
// listening state. Modeled after net.OpError, and marshalled
31+
// to that if Net is not "".
32+
type DialChannelResponse struct {
33+
Code string
34+
Err string
35+
// Fields are set if the code is CodeDialErr.
36+
Net string
37+
Op string
38+
}
39+
2340
// Listen connects to the broker proxies connections to the local net.
2441
// Close will end all RTC connections.
2542
func Listen(ctx context.Context, broker string) (io.Closer, error) {
@@ -124,7 +141,7 @@ func (l *listener) negotiate(conn net.Conn) {
124141
// Sends the error provided then closes the connection.
125142
// If RTC isn't connected, we'll close it.
126143
closeError = func(err error) {
127-
d, _ := json.Marshal(&protoMessage{
144+
d, _ := json.Marshal(&BrokerMessage{
128145
Error: err.Error(),
129146
})
130147
_, _ = conn.Write(d)
@@ -139,7 +156,7 @@ func (l *listener) negotiate(conn net.Conn) {
139156
)
140157

141158
for {
142-
var msg protoMessage
159+
var msg BrokerMessage
143160
err = decoder.Decode(&msg)
144161
if err != nil {
145162
closeError(err)
@@ -190,7 +207,7 @@ func (l *listener) negotiate(conn net.Conn) {
190207
l.connClosersMut.Lock()
191208
l.connClosers = append(l.connClosers, rtc)
192209
l.connClosersMut.Unlock()
193-
rtc.OnDataChannel(l.handle)
210+
rtc.OnDataChannel(l.handle(msg))
194211
err = rtc.SetRemoteDescription(*msg.Offer)
195212
if err != nil {
196213
closeError(fmt.Errorf("apply offer: %w", err))
@@ -208,7 +225,7 @@ func (l *listener) negotiate(conn net.Conn) {
208225
}
209226
flushCandidates()
210227

211-
data, err := json.Marshal(&protoMessage{
228+
data, err := json.Marshal(&BrokerMessage{
212229
Answer: rtc.LocalDescription(),
213230
})
214231
if err != nil {
@@ -233,70 +250,89 @@ func (l *listener) negotiate(conn net.Conn) {
233250
}
234251
}
235252

236-
func (l *listener) handle(dc *webrtc.DataChannel) {
237-
if dc.Protocol() == controlChannel {
238-
// The control channel handles pings.
253+
func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) {
254+
return func(dc *webrtc.DataChannel) {
255+
if dc.Protocol() == controlChannel {
256+
// The control channel handles pings.
257+
dc.OnOpen(func() {
258+
rw, err := dc.Detach()
259+
if err != nil {
260+
return
261+
}
262+
// We'll read and write back a single byte for ping/pongin'.
263+
d := make([]byte, 1)
264+
for {
265+
_, err = rw.Read(d)
266+
if errors.Is(err, io.EOF) {
267+
return
268+
}
269+
if err != nil {
270+
continue
271+
}
272+
_, _ = rw.Write(d)
273+
}
274+
})
275+
return
276+
}
277+
239278
dc.OnOpen(func() {
240279
rw, err := dc.Detach()
241280
if err != nil {
242281
return
243282
}
244-
// We'll read and write back a single byte for ping/pongin'.
245-
d := make([]byte, 1)
246-
for {
247-
_, err = rw.Read(d)
248-
if errors.Is(err, io.EOF) {
283+
284+
var init DialChannelResponse
285+
sendInitMessage := func() {
286+
initData, err := json.Marshal(&init)
287+
if err != nil {
288+
rw.Close()
249289
return
250290
}
291+
_, err = rw.Write(initData)
251292
if err != nil {
252-
continue
293+
return
294+
}
295+
if init.Err != "" {
296+
// If an error occurred, we're safe to close the connection.
297+
dc.Close()
298+
return
253299
}
254-
_, _ = rw.Write(d)
255300
}
256-
})
257-
return
258-
}
259301

260-
dc.OnOpen(func() {
261-
rw, err := dc.Detach()
262-
if err != nil {
263-
return
264-
}
265-
parts := strings.SplitN(dc.Protocol(), ":", 2)
266-
network := parts[0]
267-
addr := parts[1]
302+
network, addr, err := msg.getAddress(dc.Protocol())
303+
if err != nil {
304+
init.Code = CodeBadAddressErr
305+
init.Err = err.Error()
306+
var policyErr notPermittedByPolicyErr
307+
if errors.As(err, &policyErr) {
308+
init.Code = CodePermissionErr
309+
}
310+
sendInitMessage()
311+
return
312+
}
268313

269-
var init dialChannelMessage
270-
conn, err := net.Dial(network, addr)
271-
if err != nil {
272-
init.Err = err.Error()
273-
if op, ok := err.(*net.OpError); ok {
274-
init.Net = op.Net
275-
init.Op = op.Op
314+
conn, err := net.Dial(network, addr)
315+
if err != nil {
316+
init.Code = CodeDialErr
317+
init.Err = err.Error()
318+
if op, ok := err.(*net.OpError); ok {
319+
init.Net = op.Net
320+
init.Op = op.Op
321+
}
276322
}
277-
}
278-
initData, err := json.Marshal(&init)
279-
if err != nil {
280-
rw.Close()
281-
return
282-
}
283-
_, err = rw.Write(initData)
284-
if err != nil {
285-
return
286-
}
287-
if init.Err != "" {
288-
// If an error occurred, we're safe to close the connection.
289-
dc.Close()
290-
return
291-
}
292-
defer conn.Close()
293-
defer dc.Close()
323+
sendInitMessage()
324+
if init.Err != "" {
325+
return
326+
}
327+
defer conn.Close()
328+
defer dc.Close()
294329

295-
go func() {
296-
_, _ = io.Copy(rw, conn)
297-
}()
298-
_, _ = io.Copy(conn, rw)
299-
})
330+
go func() {
331+
_, _ = io.Copy(rw, conn)
332+
}()
333+
_, _ = io.Copy(conn, rw)
334+
})
335+
}
300336
}
301337

302338
// Close closes the broker socket and all created RTC connections.

0 commit comments

Comments
 (0)