-
Notifications
You must be signed in to change notification settings - Fork 926
/
Copy pathvalidation.go
216 lines (187 loc) · 6.32 KB
/
validation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
package validation
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/pkg/errors"
"golang.org/x/net/idna"
)
const (
// defaultScheme is the scheme applied when the input URL does not carry one.
defaultScheme = "http"
// accessDomain is the suffix a verified token's issuer is required to end with.
accessDomain = "cloudflareaccess.com"
// accessCertPath is appended to the validated domain URL to locate the remote key set.
accessCertPath = "/cdn-cgi/access/certs"
// accessJwtHeader is the HTTP request header from which the Access JWT is read.
accessJwtHeader = "Cf-access-jwt-assertion"
)
var (
	// supportedProtocols lists the URL schemes that validateScheme accepts.
	supportedProtocols = []string{
		"http",
		"https",
		"rdp",
		"ssh",
		"smb",
		"tcp",
	}

	// validationTimeout is a 30-second duration. It is not referenced by the
	// code visible in this file section.
	validationTimeout = 30 * time.Second
)
// ValidateHostname normalizes a user-supplied hostname.
//
// An empty input yields an empty result and no error. When the input looks
// like a URL (it contains ":" or its escaped form "%3A"), it is unescaped,
// parsed as a URL, and only the host portion, converted to ASCII via IDNA, is
// returned. Otherwise the whole input is converted to ASCII via IDNA, parsed
// as a URL reference, and its request URI is returned.
//
// Each failure wraps the underlying cause, so the returned error explains
// what was actually wrong with the input.
func ValidateHostname(hostname string) (string, error) {
	if hostname == "" {
		return "", nil
	}

	// The user gave a URL (it contains a scheme separator), not just a hostname.
	if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
		unescapeHostname, err := url.PathUnescape(hostname)
		if err != nil {
			// On failure the unescaped string is empty, so report the cause instead.
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters: %w", hostname, err)
		}
		hostnameToURL, err := url.Parse(unescapeHostname)
		if err != nil {
			// On failure the parsed URL is nil, so report the cause instead.
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format: %w", hostname, err)
		}
		asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encoding: %w", hostname, err)
		}
		return asciiHostname, nil
	}

	asciiHostname, err := idna.ToASCII(hostname)
	if err != nil {
		return "", fmt.Errorf("Hostname %s has invalid ASCII encoding: %w", hostname, err)
	}
	hostnameToURL, err := url.Parse(asciiHostname)
	if err != nil {
		// Name the offending hostname; the parsed URL is nil on this path.
		return "", fmt.Errorf("Hostname %s is not valid: %w", hostname, err)
	}
	return hostnameToURL.RequestURI(), nil
}
// ValidateUrl normalizes originUrl and returns the result as a parsed
// *url.URL. When the input carries no scheme, the default one (http://) is
// prepended.
//
// Path handling depends on whether a scheme is present. With a scheme, the
// path is dropped:
//
//	ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
//
// Without a scheme, the path survives:
//
//	ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
//
// The inconsistency is arguably a bug, but it is kept because changing it
// might break some cloudflared users.
func ValidateUrl(originUrl string) (*url.URL, error) {
	normalized, err := validateUrlString(originUrl)
	if err == nil {
		return url.Parse(normalized)
	}
	return nil, err
}
// validateUrlString normalizes originUrl into a "scheme://host[:port]" string.
//
// The input is handled in this order:
//   - an empty string is rejected;
//   - a bare IP literal, or any input wrapped in "[" and "]", is formatted
//     with the default scheme;
//   - "ip:port" is formatted with the default scheme;
//   - a URL with both a scheme and a host must use a supported scheme, and
//     only its scheme, host and port are kept;
//   - anything else is treated as scheme-less and gets the default scheme. In
//     the host:port form, everything after the last colon (including a path)
//     is preserved.
func validateUrlString(originUrl string) (string, error) {
	if originUrl == "" {
		return "", fmt.Errorf("URL should not be empty")
	}

	if net.ParseIP(originUrl) != nil {
		return validateIP("", originUrl, "")
	}
	if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
		// net.ParseIP doesn't recognize bracketed literals such as [::1].
		return validateIP("", originUrl[1:len(originUrl)-1], "")
	}

	host, port, err := net.SplitHostPort(originUrl)
	// The user might pass an IP address together with a port.
	if err == nil && net.ParseIP(host) != nil {
		return validateIP("", host, port)
	}

	unescapedUrl, err := url.PathUnescape(originUrl)
	if err != nil {
		// The unescaped string is empty on failure, so report the cause instead.
		return "", fmt.Errorf("URL %s has invalid escape characters: %w", originUrl, err)
	}

	parsedUrl, err := url.Parse(unescapedUrl)
	if err != nil {
		return "", fmt.Errorf("URL %s has invalid format", originUrl)
	}

	// For input of the form host:port, IsAbs() takes the host for the scheme,
	// so a non-empty Host is required as well.
	if parsedUrl.IsAbs() && parsedUrl.Host != "" {
		if err := validateScheme(parsedUrl.Scheme); err != nil {
			return "", err
		}
		// The earlier IP checks miss http://[::1] and http://[::1]:8080.
		if net.ParseIP(parsedUrl.Hostname()) != nil {
			return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
		}
		hostname, err := ValidateHostname(parsedUrl.Hostname())
		if err != nil {
			return "", fmt.Errorf("URL %s has invalid format", originUrl)
		}
		if parsedUrl.Port() != "" {
			return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
		}
		return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
	}

	// No usable scheme from here on.
	if host == "" {
		hostname, err := ValidateHostname(originUrl)
		if err != nil {
			return "", fmt.Errorf("URL %s has invalid format", originUrl)
		}
		return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
	}

	hostname, err := ValidateHostname(host)
	if err != nil {
		return "", fmt.Errorf("URL %s has invalid format", originUrl)
	}
	// This is why the path is preserved when originUrl doesn't have a scheme:
	// `port` comes from net.SplitHostPort and still carries whatever followed
	// the colon. Using parsedUrl.Port() here instead would remove the path.
	return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
}
// validateScheme returns nil when scheme exactly matches one of the entries
// in supportedProtocols, and an error naming the scheme otherwise.
func validateScheme(scheme string) error {
	supported := false
	for i := 0; i < len(supportedProtocols) && !supported; i++ {
		supported = supportedProtocols[i] == scheme
	}
	if !supported {
		return fmt.Errorf("Currently Cloudflare Tunnel does not support %s protocol.", scheme)
	}
	return nil
}
// validateIP formats host (and port, when given) as "scheme://host[:port]".
// An empty scheme falls back to defaultScheme. The host is not re-validated
// here, and the returned error is always nil.
func validateIP(scheme, host, port string) (string, error) {
	if scheme == "" {
		scheme = defaultScheme
	}

	var authority string
	switch {
	case port != "":
		authority = net.JoinHostPort(host, port)
	case strings.Contains(host, ":"):
		// A colon in a port-less host means IPv6, which needs brackets.
		authority = "[" + host + "]"
	default:
		authority = host
	}
	return scheme + "://" + authority, nil
}
// Access checks if a JWT from Cloudflare Access is valid.
type Access struct {
// verifier is the OIDC ID-token verifier that Validate uses to check tokens.
verifier *oidc.IDTokenVerifier
}
// NewAccessValidator builds an Access whose verifier fetches signing keys
// from the certificate path under domain and accepts tokens issued by issuer
// for the client ID applicationAUD.
//
// Both domain and issuer are normalized with validateUrlString; an error from
// either is returned unchanged.
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
	certsBase, err := validateUrlString(domain)
	if err != nil {
		return nil, err
	}

	normalizedIssuer, err := validateUrlString(issuer)
	if err != nil {
		return nil, err
	}
	// An issuer URL from Cloudflare Access will always use HTTPS.
	normalizedIssuer = strings.Replace(normalizedIssuer, "http:", "https:", 1)

	remoteKeys := oidc.NewRemoteKeySet(ctx, certsBase+accessCertPath)
	oidcConfig := &oidc.Config{ClientID: applicationAUD}
	verifier := oidc.NewVerifier(normalizedIssuer, remoteKeys, oidcConfig)
	return &Access{verifier: verifier}, nil
}
// Validate verifies jwt with the configured OIDC verifier and then checks
// that the verified token's issuer ends with accessDomain.
//
// The raw token is deliberately kept out of the returned errors: it is a
// bearer credential, and error values commonly end up in logs.
func (a *Access) Validate(ctx context.Context, jwt string) error {
	token, err := a.verifier.Verify(ctx, jwt)
	if err != nil {
		return errors.Wrap(err, "token is invalid")
	}

	// Perform extra sanity checks, just to be safe.
	if token == nil {
		return fmt.Errorf("token is nil")
	}
	if !strings.HasSuffix(token.Issuer, accessDomain) {
		return fmt.Errorf("token has non-cloudflare issuer of %s", token.Issuer)
	}
	return nil
}
// ValidateRequest reads the Access JWT from the request header named by
// accessJwtHeader and passes it to Validate, returning Validate's result.
func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
	assertion := r.Header.Get(accessJwtHeader)
	return a.Validate(ctx, assertion)
}