Skip to content

Commit d7f24b9

Browse files
committed
Make ParseURI() compatible with lib/pq's TLS keywords.
Add support for: - `sslrootcert` - `sslcert` - `sslkey` All three arguments, like thir `gitub.com/lib/pq` counterparts, are filesystem paths.
1 parent 4506a3e commit d7f24b9

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

conn.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ import (
44
"context"
55
"crypto/md5"
66
"crypto/tls"
7+
"crypto/x509"
78
"encoding/binary"
89
"encoding/hex"
910
"fmt"
1011
"io"
12+
"io/ioutil"
1113
"net"
1214
"net/url"
1315
"os"
@@ -706,9 +708,46 @@ func ParseURI(uri string) (ConnConfig, error) {
706708
return cp, err
707709
}
708710

711+
// Extract optional TLS parameters and reconstruct a coherent tls.Config based
712+
// on the DSN input. Reuse the same keywords found in github.com/lib/pq.
713+
if cp.TLSConfig != nil {
714+
{
715+
caCertPool := x509.NewCertPool()
716+
717+
caPath := url.Query().Get("sslrootcert")
718+
caCert, err := ioutil.ReadFile(caPath)
719+
if err != nil {
720+
return cp, errors.Wrapf(err, "unable to read CA file %q", caPath)
721+
}
722+
723+
if !caCertPool.AppendCertsFromPEM(caCert) {
724+
return cp, errors.Wrap(err, "unable to add CA to cert pool")
725+
}
726+
727+
cp.TLSConfig.RootCAs = caCertPool
728+
cp.TLSConfig.ClientCAs = caCertPool
729+
}
730+
731+
sslcert := url.Query().Get("sslcert")
732+
sslkey := url.Query().Get("sslkey")
733+
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
734+
return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
735+
}
736+
737+
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
738+
if err != nil {
739+
return cp, errors.Wrap(err, "unable to read cert")
740+
}
741+
742+
cp.TLSConfig.Certificates = []tls.Certificate{cert}
743+
}
744+
709745
ignoreKeys := map[string]struct{}{
710-
"sslmode": {},
711746
"connect_timeout": {},
747+
"sslcert": {},
748+
"sslkey": {},
749+
"sslmode": {},
750+
"sslrootcert": {},
712751
}
713752

714753
cp.RuntimeParams = make(map[string]string)

conn_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ func TestConnectWithTLSFallback(t *testing.T) {
228228
}
229229

230230
connConfig.UseFallbackTLS = true
231-
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
231+
connConfig.FallbackTLSConfig = tlsConnConfig.TLSConfig
232+
connConfig.FallbackTLSConfig.InsecureSkipVerify = true
232233

233234
conn, err = pgx.Connect(connConfig)
234235
if err != nil {

0 commit comments

Comments
 (0)