Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

wsnet dial policy #364

Merged
merged 6 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add wsnet address/policy tests
  • Loading branch information
deansheather committed Jun 4, 2021
commit 27d5c988466201f12eeb24f77a8cb2bb50f96bb7
4 changes: 2 additions & 2 deletions wsnet/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
return nil, fmt.Errorf("set local offer: %w", err)
}

offerMessage, err := json.Marshal(&ProtoMessage{
offerMessage, err := json.Marshal(&BrokerMessage{
Offer: &offer,
Servers: iceServers,
})
Expand Down Expand Up @@ -124,7 +124,7 @@ func (d *Dialer) negotiate() (err error) {
}()

for {
var msg ProtoMessage
var msg BrokerMessage
err = decoder.Decode(&msg)
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
break
Expand Down
68 changes: 5 additions & 63 deletions wsnet/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import (
"errors"
"fmt"
"io"
"math/bits"
"net"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -131,7 +128,7 @@ func (l *listener) negotiate(conn net.Conn) {
// Sends the error provided then closes the connection.
// If RTC isn't connected, we'll close it.
closeError = func(err error) {
d, _ := json.Marshal(&ProtoMessage{
d, _ := json.Marshal(&BrokerMessage{
Error: err.Error(),
})
_, _ = conn.Write(d)
Expand All @@ -146,7 +143,7 @@ func (l *listener) negotiate(conn net.Conn) {
)

for {
var msg ProtoMessage
var msg BrokerMessage
err = decoder.Decode(&msg)
if err != nil {
closeError(err)
Expand Down Expand Up @@ -215,7 +212,7 @@ func (l *listener) negotiate(conn net.Conn) {
}
flushCandidates()

data, err := json.Marshal(&ProtoMessage{
data, err := json.Marshal(&BrokerMessage{
Answer: rtc.LocalDescription(),
})
if err != nil {
Expand All @@ -240,7 +237,7 @@ func (l *listener) negotiate(conn net.Conn) {
}
}

func (l *listener) handle(msg ProtoMessage) func(dc *webrtc.DataChannel) {
func (l *listener) handle(msg BrokerMessage) func(dc *webrtc.DataChannel) {
return func(dc *webrtc.DataChannel) {
if dc.Protocol() == controlChannel {
// The control channel handles pings.
Expand Down Expand Up @@ -289,7 +286,7 @@ func (l *listener) handle(msg ProtoMessage) func(dc *webrtc.DataChannel) {
}
}

network, addr, err := getAddress(msg, dc.Protocol())
network, addr, err := msg.getAddress(dc.Protocol())
if err != nil {
init.Err = err.Error()
sendInitMessage()
Expand Down Expand Up @@ -336,58 +333,3 @@ func (l *listener) Close() error {
func (l *listener) Addr() net.Addr {
return nil
}

// normalizeHost converts all representations of "localhost" to "localhost".
func normalizeHost(addr string) string {
ip := net.ParseIP(addr)
if ip == nil {
return addr
}

if localNet.Contains(ip) {
return "localhost"
}
return addr
}

// getAddress parses the data channel's protocol into an address suitable for
// net.Dial. It also verifies that the ProtoMessage permits connecting to said
// address.
func getAddress(msg ProtoMessage, protocol string) (netwk, addr string, err error) {
parts := strings.SplitN(protocol, ":", 3)
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid dial address: %v", protocol)
}

var (
network = parts[0]
host = normalizeHost(parts[1])
port = parts[2]
fullAddr = net.JoinHostPort(host, port)
)
if len(msg.Policies) == 0 {
return network, fullAddr, nil
}

portParsed, err := strconv.Atoi(port)
if err != nil || portParsed < 0 || bits.Len(uint(portParsed)) > 16 {
return "", "", fmt.Errorf("invalid dial address %q port: %v", protocol, port)
}
portParsedU16 := uint16(portParsed)

for _, p := range msg.Policies {
if p.Network != "" && p.Network != network {
continue
}
if p.Host != "" && normalizeHost(p.Host) != host {
continue
}
if p.Port != 0 && p.Port != portParsedU16 {
continue
}

return network, fullAddr, nil
}

return "", "", fmt.Errorf("connections are not permitted to %q", err)
}
91 changes: 87 additions & 4 deletions wsnet/proto.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package wsnet

import (
"fmt"
"math/bits"
"net"
"strconv"
"strings"

"github.com/pion/webrtc/v3"
)

Expand All @@ -9,22 +15,39 @@ import (
type DialPolicy struct {
// If network is empty, it applies to all networks.
Network string `json:"network"`
// If host is empty, it applies to all hosts. "localhost" and any IP address
// under "127.0.0.0/8" can be used interchangeably.
// Host is the IP or hostname of the address. It should not contain the
// port.If empty, it applies to all hosts. "localhost", [::1], and any IPv4
// address under "127.0.0.0/8" can be used interchangeably.
Host string `json:"address"`
// If port is 0, it applies to all ports.
Port uint16 `json:"port"`
}

// ProtoMessage is used for brokering a dialer and listener.
// permits checks if a DialPolicy permits a specific network + host + port
// combination. The host must be put through normalizeHost first.
func (p DialPolicy) permits(network, host string, port uint16) bool {
if p.Network != "" && p.Network != network {
return false
}
if p.Host != "" && canonicalizeHost(p.Host) != host {
return false
}
if p.Port != 0 && p.Port != port {
return false
}

return true
}

// BrokerMessage is used for brokering a dialer and listener.
//
// Dialers initiate an exchange by providing an Offer,
// along with a list of ICE servers for the listener to
// peer with.
//
// The listener should respond with an offer, then both
// sides can begin exchanging candidates.
type ProtoMessage struct {
type BrokerMessage struct {
// Dialer -> Listener
Offer *webrtc.SessionDescription `json:"offer"`
Servers []webrtc.ICEServer `json:"servers"`
Expand All @@ -40,6 +63,50 @@ type ProtoMessage struct {
Candidate string `json:"candidate"`
}

// getAddress parses the data channel's protocol into an address suitable for
// net.Dial. It also verifies that the BrokerMessage permits connecting to said
// address.
func (msg BrokerMessage) getAddress(protocol string) (netwk, addr string, err error) {
parts := strings.SplitN(protocol, ":", 2)
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid dial address: %v", protocol)
}
host, port, err := net.SplitHostPort(parts[1])
if err != nil {
return "", "", fmt.Errorf("invalid dial address: %v", protocol)
}

var (
network = parts[0]
normalHost = canonicalizeHost(host)
// Still return the original host value, not the canonical value.
fullAddr = net.JoinHostPort(host, port)
)
if network == "" {
return "", "", fmt.Errorf("invalid dial address %q network: %v", protocol, network)
}
if host == "" {
return "", "", fmt.Errorf("invalid dial address %q host: %v", protocol, host)
}

portParsed, err := strconv.Atoi(port)
if err != nil || portParsed < 0 || bits.Len(uint(portParsed)) > 16 {
return "", "", fmt.Errorf("invalid dial address %q port: %v", protocol, port)
}
if len(msg.Policies) == 0 {
return network, fullAddr, nil
}

portParsedU16 := uint16(portParsed)
for _, p := range msg.Policies {
if p.permits(network, normalHost, portParsedU16) {
return network, fullAddr, nil
}
}

return "", "", fmt.Errorf("connections are not permitted to %q by policy", protocol)
}

// dialChannelMessage is used to notify a dial channel of a
// listening state. Modeled after net.OpError, and marshalled
// to that if Net is not "".
Expand All @@ -48,3 +115,19 @@ type dialChannelMessage struct {
Net string
Op string
}

// canonicalizeHost converts all representations of "localhost" to "localhost".
func canonicalizeHost(addr string) string {
addr = strings.TrimPrefix(addr, "[")
addr = strings.TrimSuffix(addr, "]")

ip := net.ParseIP(addr)
if ip == nil {
return addr
}

if ip.IsLoopback() {
return "localhost"
}
return addr
}
Loading