diff --git a/wsnet/dial.go b/wsnet/dial.go index 6900bacf..23581eaf 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -67,7 +67,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { return nil, fmt.Errorf("set local offer: %w", err) } - offerMessage, err := json.Marshal(&protoMessage{ + offerMessage, err := json.Marshal(&BrokerMessage{ Offer: &offer, Servers: iceServers, }) @@ -124,7 +124,7 @@ func (d *Dialer) negotiate() (err error) { }() for { - var msg protoMessage + var msg BrokerMessage err = decoder.Decode(&msg) if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { break @@ -218,24 +218,23 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. errCh := make(chan error) go func() { - var init dialChannelMessage - err = json.NewDecoder(rw).Decode(&init) + var res DialChannelResponse + err = json.NewDecoder(rw).Decode(&res) if err != nil { - errCh <- fmt.Errorf("read init: %w", err) + errCh <- fmt.Errorf("read dial response: %w", err) return } - if init.Err == "" { + if res.Err == "" { close(errCh) return } - err := errors.New(init.Err) - if init.Net != "" { - errCh <- &net.OpError{ - Op: init.Op, - Net: init.Net, + err := errors.New(res.Err) + if res.Code == CodeDialErr { + err = &net.OpError{ + Op: res.Op, + Net: res.Net, Err: err, } - return } errCh <- err }() diff --git a/wsnet/listen.go b/wsnet/listen.go index 92f7671f..1496e19c 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net" - "strings" "sync" "time" @@ -18,8 +17,26 @@ import ( "cdr.dev/coder-cli/coder-sdk" ) +// Codes for DialChannelResponse. +const ( + CodeDialErr = "dial_error" + CodePermissionErr = "permission_error" + CodeBadAddressErr = "bad_address_error" +) + var connectionRetryInterval = time.Second +// DialChannelResponse is used to notify a dial channel of a +// listening state. Modeled after net.OpError, and marshalled +// to that if Net is not "". +type DialChannelResponse struct { + Code string + Err string + // Fields are set if the code is CodeDialErr. + Net string + Op string +} + // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. func Listen(ctx context.Context, broker string) (io.Closer, error) { @@ -124,7 +141,7 @@ func (l *listener) negotiate(conn net.Conn) { // Sends the error provided then closes the connection. // If RTC isn't connected, we'll close it. closeError = func(err error) { - d, _ := json.Marshal(&protoMessage{ + d, _ := json.Marshal(&BrokerMessage{ Error: err.Error(), }) _, _ = conn.Write(d) @@ -139,7 +156,7 @@ func (l *listener) negotiate(conn net.Conn) { ) for { - var msg protoMessage + var msg BrokerMessage err = decoder.Decode(&msg) if err != nil { closeError(err) @@ -190,7 +207,7 @@ func (l *listener) negotiate(conn net.Conn) { l.connClosersMut.Lock() l.connClosers = append(l.connClosers, rtc) l.connClosersMut.Unlock() - rtc.OnDataChannel(l.handle) + rtc.OnDataChannel(l.handle(msg)) err = rtc.SetRemoteDescription(*msg.Offer) if err != nil { closeError(fmt.Errorf("apply offer: %w", err)) @@ -208,7 +225,7 @@ func (l *listener) negotiate(conn net.Conn) { } flushCandidates() - data, err := json.Marshal(&protoMessage{ + data, err := json.Marshal(&BrokerMessage{ Answer: rtc.LocalDescription(), }) if err != nil { @@ -233,70 +250,89 @@ func (l *listener) negotiate(conn net.Conn) { } } -func (l *listener) handle(dc *webrtc.DataChannel) { - if dc.Protocol() == controlChannel { - // The control channel handles pings. +func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) { + return func(dc *webrtc.DataChannel) { + if dc.Protocol() == controlChannel { + // The control channel handles pings. + dc.OnOpen(func() { + rw, err := dc.Detach() + if err != nil { + return + } + // We'll read and write back a single byte for ping/pongin'. + d := make([]byte, 1) + for { + _, err = rw.Read(d) + if errors.Is(err, io.EOF) { + return + } + if err != nil { + continue + } + _, _ = rw.Write(d) + } + }) + return + } + dc.OnOpen(func() { rw, err := dc.Detach() if err != nil { return } - // We'll read and write back a single byte for ping/pongin'. - d := make([]byte, 1) - for { - _, err = rw.Read(d) - if errors.Is(err, io.EOF) { + + var init DialChannelResponse + sendInitMessage := func() { + initData, err := json.Marshal(&init) + if err != nil { + rw.Close() return } + _, err = rw.Write(initData) if err != nil { - continue + return + } + if init.Err != "" { + // If an error occurred, we're safe to close the connection. + dc.Close() + return } - _, _ = rw.Write(d) } - }) - return - } - dc.OnOpen(func() { - rw, err := dc.Detach() - if err != nil { - return - } - parts := strings.SplitN(dc.Protocol(), ":", 2) - network := parts[0] - addr := parts[1] + network, addr, err := msg.getAddress(dc.Protocol()) + if err != nil { + init.Code = CodeBadAddressErr + init.Err = err.Error() + var policyErr notPermittedByPolicyErr + if errors.As(err, &policyErr) { + init.Code = CodePermissionErr + } + sendInitMessage() + return + } - var init dialChannelMessage - conn, err := net.Dial(network, addr) - if err != nil { - init.Err = err.Error() - if op, ok := err.(*net.OpError); ok { - init.Net = op.Net - init.Op = op.Op + conn, err := net.Dial(network, addr) + if err != nil { + init.Code = CodeDialErr + init.Err = err.Error() + if op, ok := err.(*net.OpError); ok { + init.Net = op.Net + init.Op = op.Op + } } - } - initData, err := json.Marshal(&init) - if err != nil { - rw.Close() - return - } - _, err = rw.Write(initData) - if err != nil { - return - } - if init.Err != "" { - // If an error occurred, we're safe to close the connection. - dc.Close() - return - } - defer conn.Close() - defer dc.Close() + sendInitMessage() + if init.Err != "" { + return + } + defer conn.Close() + defer dc.Close() - go func() { - _, _ = io.Copy(rw, conn) - }() - _, _ = io.Copy(conn, rw) - }) + go func() { + _, _ = io.Copy(rw, conn) + }() + _, _ = io.Copy(conn, rw) + }) + } } // Close closes the broker socket and all created RTC connections. diff --git a/wsnet/proto.go b/wsnet/proto.go index cbe3ac82..754fffac 100644 --- a/wsnet/proto.go +++ b/wsnet/proto.go @@ -1,10 +1,45 @@ package wsnet import ( + "fmt" + "math/bits" + "net" + "strconv" + "strings" + "github.com/pion/webrtc/v3" ) -// protoMessage is used for brokering a dialer and listener. +// DialPolicy a single network + address + port combinations that a connection +// is permitted to use. +type DialPolicy struct { + // If network is empty, it applies to all networks. + Network string `json:"network"` + // Host is the IP or hostname of the address. It should not contain the + // port.If empty, it applies to all hosts. "localhost", [::1], and any IPv4 + // address under "127.0.0.0/8" can be used interchangeably. + Host string `json:"address"` + // If port is 0, it applies to all ports. + Port uint16 `json:"port"` +} + +// permits checks if a DialPolicy permits a specific network + host + port +// combination. The host must be put through normalizeHost first. +func (p DialPolicy) permits(network, host string, port uint16) bool { + if p.Network != "" && p.Network != network { + return false + } + if p.Host != "" && canonicalizeHost(p.Host) != host { + return false + } + if p.Port != 0 && p.Port != port { + return false + } + + return true +} + +// BrokerMessage is used for brokering a dialer and listener. // // Dialers initiate an exchange by providing an Offer, // along with a list of ICE servers for the listener to @@ -12,10 +47,13 @@ import ( // // The listener should respond with an offer, then both // sides can begin exchanging candidates. -type protoMessage struct { +type BrokerMessage struct { // Dialer -> Listener Offer *webrtc.SessionDescription `json:"offer"` Servers []webrtc.ICEServer `json:"servers"` + // Policies denote which addresses the client can dial. If empty or nil, all + // addresses are permitted. + Policies []DialPolicy `json:"ports"` // Listener -> Dialer Error string `json:"error"` @@ -25,11 +63,73 @@ type protoMessage struct { Candidate string `json:"candidate"` } -// dialChannelMessage is used to notify a dial channel of a -// listening state. Modeled after net.OpError, and marshalled -// to that if Net is not "". -type dialChannelMessage struct { - Err string - Net string - Op string +// getAddress parses the data channel's protocol into an address suitable for +// net.Dial. It also verifies that the BrokerMessage permits connecting to said +// address. +func (msg BrokerMessage) getAddress(protocol string) (netwk, addr string, err error) { + parts := strings.SplitN(protocol, ":", 2) + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid dial address: %v", protocol) + } + host, port, err := net.SplitHostPort(parts[1]) + if err != nil { + return "", "", fmt.Errorf("invalid dial address: %v", protocol) + } + + var ( + network = parts[0] + normalHost = canonicalizeHost(host) + // Still return the original host value, not the canonical value. + fullAddr = net.JoinHostPort(host, port) + ) + if network == "" { + return "", "", fmt.Errorf("invalid dial address %q network: %v", protocol, network) + } + if host == "" { + return "", "", fmt.Errorf("invalid dial address %q host: %v", protocol, host) + } + + portParsed, err := strconv.Atoi(port) + if err != nil || portParsed < 0 || bits.Len(uint(portParsed)) > 16 { + return "", "", fmt.Errorf("invalid dial address %q port: %v", protocol, port) + } + if len(msg.Policies) == 0 { + return network, fullAddr, nil + } + + portParsedU16 := uint16(portParsed) + for _, p := range msg.Policies { + if p.permits(network, normalHost, portParsedU16) { + return network, fullAddr, nil + } + } + + return "", "", fmt.Errorf("connections are not permitted to %q by policy", protocol) +} + +// canonicalizeHost converts all representations of "localhost" to "localhost". +func canonicalizeHost(addr string) string { + addr = strings.TrimPrefix(addr, "[") + addr = strings.TrimSuffix(addr, "]") + + ip := net.ParseIP(addr) + if ip == nil { + return addr + } + + if ip.IsLoopback() { + return "localhost" + } + return addr +} + +type notPermittedByPolicyErr struct { + protocol string +} + +var _ error = notPermittedByPolicyErr{} + +// Error implements error. +func (e notPermittedByPolicyErr) Error() string { + return fmt.Sprintf("connections are not permitted to %q by policy", e.protocol) } diff --git a/wsnet/proto_test.go b/wsnet/proto_test.go new file mode 100644 index 00000000..89999f6b --- /dev/null +++ b/wsnet/proto_test.go @@ -0,0 +1,235 @@ +package wsnet + +import ( + "fmt" + "testing" + + "cdr.dev/slog/sloggers/slogtest/assert" +) + +func Test_BrokerMessage(t *testing.T) { + t.Run("getAddress", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { + var ( + msg = BrokerMessage{ + Policies: nil, + } + network = "tcp" + addr = "localhost:1234" + ) + + protocol := formatAddress(network, addr) + gotNetwork, gotAddr, err := msg.getAddress(protocol) + assert.Success(t, "got address", err) + assert.Equal(t, "networks equal", network, gotNetwork) + assert.Equal(t, "addresses equal", addr, gotAddr) + + msg.Policies = []DialPolicy{} + gotNetwork, gotAddr, err = msg.getAddress(protocol) + assert.Success(t, "got address", err) + assert.Equal(t, "networks equal", network, gotNetwork) + assert.Equal(t, "addresses equal", addr, gotAddr) + }) + + t.Run("InvalidProtocol", func(t *testing.T) { + cases := []struct { + protocol string + errContains string + }{ + { + protocol: "", + errContains: "invalid", + }, + { + protocol: "a:b", + errContains: "invalid", + }, + { + protocol: "a:b:c:d", + errContains: "invalid", + }, + { + protocol: ":localhost:1234", + errContains: "network", + }, + { + protocol: "tcp::1234", + errContains: "host", + }, + { + protocol: "tcp:localhost:", + errContains: "port", + }, + { + protocol: "tcp:localhost:asdf", + errContains: "port", + }, + { + protocol: "tcp:localhost:-1", + errContains: "port", + }, + { + // Overflow uint16. + protocol: fmt.Sprintf("tcp:localhost:%v", uint(1)<<16), + errContains: "port", + }, + } + + var msg BrokerMessage + for i, c := range cases { + amsg := fmt.Sprintf("case %v %q: ", i, c) + gotNetwork, gotAddr, err := msg.getAddress(c.protocol) + assert.Error(t, amsg+"successfully got invalid address", err) + assert.ErrorContains(t, fmt.Sprintf("%verr contains %q", amsg, c.errContains), err, c.errContains) + assert.Equal(t, amsg+"empty network", "", gotNetwork) + assert.Equal(t, amsg+"empty address", "", gotAddr) + } + }) + + t.Run("ChecksPolicies", func(t *testing.T) { + // ok == true tests automatically have a bunch of non-matching dial + // policies injected in front of them. + cases := []struct { + network string + host string + port uint16 + policy DialPolicy + ok bool + }{ + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "localhost", 1234), + ok: true, + }, + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("udp", "example.com", 51), + ok: false, + }, + // Network checks. + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("", "localhost", 1234), + ok: true, + }, + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("udp", "localhost", 1234), + ok: false, + }, + // Host checks. + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "", 1234), + ok: true, + }, + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "127.0.0.1", 1234), + ok: true, + }, + { + network: "tcp", + host: "127.0.0.1", + port: 1234, + policy: dialPolicy("tcp", "127.1.2.3", 1234), + ok: true, + }, + { + network: "tcp", + host: "[::1]", + port: 1234, + policy: dialPolicy("tcp", "127.1.2.3", 1234), + ok: true, + }, + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "example.com", 1234), + ok: false, + }, + { + network: "tcp", + host: "example.com", + port: 1234, + policy: dialPolicy("tcp", "localhost", 1234), + ok: false, + }, + // Port checks. + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "localhost", 5678), + ok: false, + }, + { + network: "tcp", + host: "localhost", + port: 1234, + policy: dialPolicy("tcp", "localhost", 0), + ok: true, + }, + } + + for i, c := range cases { + var ( + amsg = fmt.Sprintf("case %v '%+v': ", i, c) + msg = BrokerMessage{ + Policies: []DialPolicy{c.policy}, + } + ) + + // Add nonsense policies before the matching policy. + if c.ok { + msg.Policies = []DialPolicy{ + dialPolicy("asdf", "localhost", 1234), + dialPolicy("tcp", "asdf", 1234), + dialPolicy("tcp", "localhost", 17208), + c.policy, + } + } + + // Test DialPolicy. + assert.Equal(t, amsg+"policy matches", c.ok, c.policy.permits(c.network, canonicalizeHost(c.host), c.port)) + + // Test BrokerMessage. + protocol := formatAddress(c.network, fmt.Sprintf("%v:%v", c.host, c.port)) + gotNetwork, gotAddr, err := msg.getAddress(protocol) + if c.ok { + assert.Success(t, amsg, err) + } else { + assert.Error(t, amsg+"successfully got invalid address", err) + assert.ErrorContains(t, amsg+"err contains 'not permitted'", err, "not permitted") + assert.Equal(t, amsg+"empty network", "", gotNetwork) + assert.Equal(t, amsg+"empty address", "", gotAddr) + } + } + }) + }) +} + +func formatAddress(network, addr string) string { + return fmt.Sprintf("%v:%v", network, addr) +} + +func dialPolicy(network, host string, port uint16) DialPolicy { + return DialPolicy{ + Network: network, + Host: host, + Port: port, + } +} diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 9c07663d..f5c7c5f3 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -186,7 +186,7 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() { queue = []*webrtc.ICECandidate{} flushed = false write = func(i *webrtc.ICECandidate) { - b, _ := json.Marshal(&protoMessage{ + b, _ := json.Marshal(&BrokerMessage{ Candidate: i.ToJSON().Candidate, }) _, _ = w.Write(b)