From 2fe22126a4ffc796d4baff7e005f0d3f27e2b3f4 Mon Sep 17 00:00:00 2001 From: Christian Ang Date: Wed, 17 Oct 2018 09:31:38 -0700 Subject: [PATCH] WIP: Add skip hostname verification to mysql conn - Also add client cert and client key [#160840442] Signed-off-by: Michael Oleske --- db/config.go | 52 +--- db/config_test.go | 57 ++-- db/mysql_adapter.go | 17 ++ db/mysql_connection_string_builder.go | 107 +++++++ db/mysql_connection_string_builder_test.go | 309 +++++++++++++++++++++ fakes/mysql_adapter.go | 165 +++++++++++ 6 files changed, 632 insertions(+), 75 deletions(-) create mode 100644 db/mysql_adapter.go create mode 100644 db/mysql_connection_string_builder.go create mode 100644 db/mysql_connection_string_builder_test.go create mode 100644 fakes/mysql_adapter.go diff --git a/db/config.go b/db/config.go index 2c526ac3..313a4ee3 100644 --- a/db/config.go +++ b/db/config.go @@ -3,10 +3,6 @@ package db import ( "fmt" "time" - "io/ioutil" - "crypto/x509" - "crypto/tls" - "github.com/go-sql-driver/mysql" ) type Config struct { @@ -19,6 +15,8 @@ type Config struct { DatabaseName string `json:"database_name" validate:""` RequireSSL bool `json:"require_ssl" validate:""` CACert string `json:"ca_cert" validate:""` + ClientCert string + ClientKey string } func (c Config) ConnectionString() (string, error) { @@ -30,49 +28,11 @@ func (c Config) ConnectionString() (string, error) { ms := (time.Duration(c.Timeout) * time.Second).Nanoseconds() / 1000 / 1000 return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable&connect_timeout=%d", c.User, c.Password, c.Host, c.Port, c.DatabaseName, ms), nil case "mysql": - return c.buildMysqlConnectionString() + mysqlConnectionStringBuilder := &MySQLConnectionStringBuilder{ + MySQLAdapter: &MySQLAdapter{}, + } + return mysqlConnectionStringBuilder.Build(c) default: return "", fmt.Errorf("database type '%s' is not supported", c.Type) } } - -func (c Config) buildMysqlConnectionString() (string, error) { - connString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", c.User, c.Password, c.Host, c.Port, c.DatabaseName) - dbConfig, err := mysql.ParseDSN(connString) - if err != nil { - return "", fmt.Errorf("parsing db connection string: %s", err) - } - - timeoutDuration := time.Duration(c.Timeout) * time.Second - dbConfig.Timeout = timeoutDuration - dbConfig.ReadTimeout = timeoutDuration - dbConfig.WriteTimeout = timeoutDuration - - if c.RequireSSL { - certBytes, err := ioutil.ReadFile(c.CACert) - if err != nil { - return "", fmt.Errorf("reading db ca cert file: %s", err) - } - - caCertPool := x509.NewCertPool() - if ok := caCertPool.AppendCertsFromPEM(certBytes); !ok { - return "", fmt.Errorf("appending cert to pool from pem - invalid cert bytes") - } - - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, - RootCAs: caCertPool, - } - - tlsConfigName := fmt.Sprintf("%s-tls", c.DatabaseName) - - err = mysql.RegisterTLSConfig(tlsConfigName, tlsConfig) - if err != nil { - return "", fmt.Errorf("registering mysql tls config: %s", err) - } - - dbConfig.TLSConfig = tlsConfigName - } - - return dbConfig.FormatDSN(), nil -} diff --git a/db/config_test.go b/db/config_test.go index 0cdd8b2a..1441abd4 100644 --- a/db/config_test.go +++ b/db/config_test.go @@ -5,40 +5,40 @@ import ( "code.cloudfoundry.org/cf-networking-helpers/db" + "github.com/go-sql-driver/mysql" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "github.com/go-sql-driver/mysql" ) const ( DATABASE_CA_CERT = `-----BEGIN CERTIFICATE----- -MIIE4jCCAsqgAwIBAgIBATANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDEwZmYWtl -Q0EwHhcNMTgwNTEwMjM1MDM2WhcNMTkxMTEwMjM1MDM2WjARMQ8wDQYDVQQDEwZm -YWtlQ0EwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQC3u/I7qztSp8rO -S266wo53NtqdtM/8iyyCigqCwHgJ7CauGKq33zTBaUkRljeRn/AXvkChPPEA3KQm -Wrv5YkFhCq/1EOB2JUMPVrUXjP/g6HwPAHX2IvC0pJoYMmb7TloGXfTjV/A/2e41 -Q1zRSWAMDXCUfnAn6skkihV9YGipdM/r0+O9n8tb3F+Z+dYvMu89DwyptI/oNzNK -DyzkQf5WZ1PCqEow7ZcbSQP3RH2Ds6I+AG98nxB4irsmUkoZnUhQzTc9DpINmgI8 -3Yg8YyTFODZ1BbsnST1Y01rWMvVkXy/89+fBqN4kGI12CYtbh69Shr/0cim8fETT -N9CLzqDpPlnfAJGv/VwVSzwxhYuYBfh3PtlAo5OfVBhYoGq9npjV3H/j9N9r0aE4 -MkQvYaATB/fQ823mtjLDqtkIvZCXq1PZA90oQ87n1FPeklc9/T14SXNcHVMBMpSX -mPdaJvBoXjlwl1EKvZIQzz/luxMZfgqSRy4TLcJKJ+E+3bU2RZcz56r5aTV8+9aS -/SL80oQpGzXK4pWcFvELlGcW2LnP7XPE3t1HzS0kEVSFnVyw4/UJvcsyZSUl2bDs -FJl0HOkVuNtjnhCKTiRpRTYdKxKhvxp46/0FMtnujIq1WQF6yKUK3mOUpJpaJe+4 -3fd7UsX8Qz0Gwj9scxBCTTNeVXkU8QIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAQYw -EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUsUmpQ5m8K0hqWlw/7M5Q1wG0 -FO4wDQYJKoZIhvcNAQELBQADggIBAKWrPjCEYWWWnWWjIFazsbe98eSu7N4yTDtc -yN4D+k0sIWbbgKbeAV9k5N5H8p9w5tNzjsjUCK88qdEvN+0kJHWCvt2zffBlP6tX -nC2bB12CjPPYIUpnG14ghZB/Uxj240eo/1JCrsb/qTecW2H3UbLmjtmx10RJVP7U -kseGnsXQwPIEgOVHubVLkIobv88zLSJKgf8syhnbihl5/eKIBMfreaI7mW2+CqTE -Y1SfIPTpU59YHW7TNLsI9WgQNtDqCORKwzzpVnUWPfO0iQ1+wEnjfhsMDOmzfBLO -l1HspfZpRnOXZFROnuNvR1V+qyPrKMm/F01B7Z4ESxa7ktNEbrwOt2wGT/keXocw -z1LXbrG/WBjov4HCD6pXv+w4XwkR9bPEHkrMZ/INCm5oIq7JLTZcjb56aawqAk8W -0XXKhjFTIGO46GPTbcJTWxs3BJX8C2mL5aVHWekJfuXsCU/0GIxudV8VQJqulq/1 -dlZjOpycEZ11hWkZENsJ8ddDX0eYTR95MGAq8J7m1Q0Ts0X/d5ATc1mREf3wqhSn -TFFl82cBZE15vJfk5ekNof8Hx2NTZYwfplKKqb8epo2pIA/j3/PRjo80AFyicjoY -7/Xiu2K2JGmsEF3XQVowXVsxngkLZSqHml+WRweqaK48Zbojj/hUkz+xOAoucqZO -VCbEyl6T +MIIE5DCCAsygAwIBAgIBATANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwdteXNx +bENBMB4XDTE4MTAxNjIwNDMzN1oXDTIwMDQxNjIwNDMzN1owEjEQMA4GA1UEAxMH +bXlzcWxDQTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBANiduEyqVkST +hwG7TnihP/GHpvzJwdJVuhxAa0Qd++uo8VsD8cqHKX6IbmwcVUgrBDikY0kaks7g +el5bwUDg4oi/1pTdT+vbrOZLCgksO043/34zLe2itVsYVABnvqMnOZM+bAr7Im9V +wb6rfXAGulow1vmzuzxGtGGV4x5ETCOA+SnytxHH1L57RgAmUJnEK/Ks4CEWoByR +v8kejd8KScAo2pZ2ldlXJk+ggSifbyxCrMaA/E0HfA2epnhBy9nhbBXx4/p35SIW +B+Nzv9FMmb8PZ0AHw7PEg3WJzrrrYZXyKZL1EmGT9w0FzxcirYFmmWL/Zzm4lkHM +YvwOI3eCPoPFYJcgzJuJbtbcXkiCIZ2M5QunIGKgACBDw/YHNrZMHGCFplZxT2E0 +rJTq01ZWC4wKhpKGtTkvZzR9SC3tyK/hbcjjP+jtSc3ZA7V2h8Lwa8ZfydurJQ2n +fgKpiaXpiqvv36wpfXzieywooioh86qsZhUs39V8JPr8HK11PXKT9ZRwmdcTjcBA +9gfauxUBa/NKAGIPoxe0+1JTWFwxt85hEZPMq/A+6zPH6smgDVV8SayEoYQQ7yLr +/mtKmEZ7QTl/0prTBQPsfIowvHr2MIPOiNrbuoUvV+gKZfhGMeLmOa8BxQ/FHiUe +WZ5Xhd9a0xnOnQb+QdQAfAoEaMYdPeQpAgMBAAGjRTBDMA4GA1UdDwEB/wQEAwIB +BjASBgNVHRMBAf8ECDAGAQH/AgEAMB0GA1UdDgQWBBTVi+UqriPLeufRqMwe49lv +WU333DANBgkqhkiG9w0BAQsFAAOCAgEAUacizdIDaONqXYaERZky5vsTS8rCrL6w +bHj9n9iGpx9irMREnEUFuecttkuZbDWQqXiH//vuUbbE4YoiVwT2GlLlBAvpYsXY +qDn011c7ZjU4Aw5voJpaR25lCpVrztev5//Cyz9KIm0aZ01ARjWDcPSg4GRyjnFj +fIszidqquRr0lrutNEBjyKxibMoDzkgbXXCdh3hjdVJX7zhMN7wcFqcx5/8f+fb3 +EcO0qiT385LUh3qlhg5w9t/gxblwsnQK6X241O8nDoZgxvnW62RN4GqUn1ZtLeGs +pylBCQ/CePSasDN4mJRmPxMKKiKJyp0XdvgSagseq5kW+Zaz6H04QdCjgBOgrPdD +UkWnb8hQiVsboPxpc//a0JIsXZ0krb2UkSv4JrIYOVa/4lRaj9Ie5weBZkYmd+Kp +7f//4UezWSDfWv7S1GbxF+d8rWkcZksV9/es2GhspH6oM2GtE+7198R12XIq8aVk +X7e056LpGxjMy6rvnVId3NwITdk6VB5SnsJL0RaGIu1YXhs9HR/Q9TCuWABJEJ/M +P1zzwuCu9cIOfXYzAGV5miS6FsgQEvxFNp15U4bS/Mbrct/6Z6JFo96ueSjOvehb +RSper1U+5n6G+LEHYrn8mpl1T/YkVTgmrKrxNFdsw9YYWFhut8Mh/N04Pd5Y1yzi +hB1P/1ZlKVU= -----END CERTIFICATE-----` ) @@ -132,7 +132,6 @@ var _ = Describe("Config", func() { config.CACert = caCertFile.Name() }) - It("returns an error", func() { _, err := config.ConnectionString() Expect(err).To(HaveOccurred()) diff --git a/db/mysql_adapter.go b/db/mysql_adapter.go new file mode 100644 index 00000000..d4f11e53 --- /dev/null +++ b/db/mysql_adapter.go @@ -0,0 +1,17 @@ +package db + +import ( + "crypto/tls" + + "github.com/go-sql-driver/mysql" +) + +type MySQLAdapter struct{} + +func (m MySQLAdapter) ParseDSN(dsn string) (cfg *mysql.Config, err error) { + return mysql.ParseDSN(dsn) +} + +func (m MySQLAdapter) RegisterTLSConfig(key string, config *tls.Config) error { + return mysql.RegisterTLSConfig(key, config) +} diff --git a/db/mysql_connection_string_builder.go b/db/mysql_connection_string_builder.go new file mode 100644 index 00000000..ae029926 --- /dev/null +++ b/db/mysql_connection_string_builder.go @@ -0,0 +1,107 @@ +package db + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "time" + + "github.com/go-sql-driver/mysql" +) + +//go:generate counterfeiter -o ../fakes/mysql_adapter.go --fake-name MySQLAdapter . mySQLAdapter +type mySQLAdapter interface { + ParseDSN(dsn string) (cfg *mysql.Config, err error) + RegisterTLSConfig(key string, config *tls.Config) error +} + +type MySQLConnectionStringBuilder struct { + MySQLAdapter mySQLAdapter +} + +func (m *MySQLConnectionStringBuilder) Build(config Config) (string, error) { + connString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", config.User, config.Password, config.Host, config.Port, config.DatabaseName) + + dbConfig, err := m.MySQLAdapter.ParseDSN(connString) + if err != nil { + return "", fmt.Errorf("parsing db connection string: %s", err) + } + + timeoutDuration := time.Duration(config.Timeout) * time.Second + dbConfig.Timeout = timeoutDuration + dbConfig.ReadTimeout = timeoutDuration + dbConfig.WriteTimeout = timeoutDuration + + if config.RequireSSL { + dbConfig.TLSConfig = fmt.Sprintf("%s-tls", config.DatabaseName) + + certBytes, err := ioutil.ReadFile(config.CACert) + if err != nil { + return "", fmt.Errorf("reading db ca cert file: %s", err) + } + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(certBytes); !ok { + return "", fmt.Errorf("appending cert to pool from pem - invalid cert bytes") + } + + tlsConfig := &tls.Config{ + InsecureSkipVerify: false, + RootCAs: caCertPool, + } + + if config.ClientCert != "" && config.ClientKey != "" { + clientCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey) + if err != nil { + return "", fmt.Errorf("loading key pair: %s", err) + } + + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + RootCAs: caCertPool, + Certificates: []tls.Certificate{clientCert}, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return VerifyCertificatesIgnoreHostname(rawCerts, caCertPool) + }, + } + } + err = m.MySQLAdapter.RegisterTLSConfig(dbConfig.TLSConfig, tlsConfig) + if err != nil { + return "", fmt.Errorf("registering mysql tls config: %s", err) + } + } + + return dbConfig.FormatDSN(), nil +} + +func VerifyCertificatesIgnoreHostname(rawCerts [][]byte, caCertPool *x509.CertPool) error { + certs := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return fmt.Errorf("tls: failed to parse certificate from server: %s", err) + } + certs[i] = cert + } + + opts := x509.VerifyOptions{ + Roots: caCertPool, + CurrentTime: time.Now(), + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err := certs[0].Verify(opts) + if err != nil { + return err + } + + return nil +} diff --git a/db/mysql_connection_string_builder_test.go b/db/mysql_connection_string_builder_test.go new file mode 100644 index 00000000..05970e74 --- /dev/null +++ b/db/mysql_connection_string_builder_test.go @@ -0,0 +1,309 @@ +package db_test + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" + + "code.cloudfoundry.org/cf-networking-helpers/db" + "code.cloudfoundry.org/cf-networking-helpers/fakes" + + "github.com/go-sql-driver/mysql" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const ( + DATABASE_CLIENT_CERT = `-----BEGIN CERTIFICATE----- +MIIEJDCCAgygAwIBAgIRAPaxi331A4Tad6C4UzTuU80wDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEAxMHbXlzcWxDQTAeFw0xODEwMTYyMDQ0MjNaFw0yMDA0MTYyMDQz +MzdaMBYxFDASBgNVBAMTC215c3FsQ2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAonUIEmTXRwXeE4VpCNwj0A92XGVXPBAUrIoxPPzQiy8ey9wR +JqWKCPQY/g2vkEme/4uNIN+o8iI4COYmPaIuRv1tqot4U2/mhfeDH79+E7oc97FX +AnksLTEni+zOBtJUOQd1KOF6TlyVP3PRn9m+QQxLU0qp5TPpSIuE11E4SeVShocx +65AcQPgmu5+YBKCSVN7J5hosXm8Wtd/d28lJ36WBumZFUa+qD3DMiWUqU+AcXeCh +GJOjiE5osalg9UaWSuMnag+wvTXEWYP3qd3eThFcfS4Bj9/1XABPvdlyug15behG +2iU1Ra4UO35zRb6jF6Ax9Tc+14FDzVsaHBNg1QIDAQABo3EwbzAOBgNVHQ8BAf8E +BAMCA7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRL +AKaQdBXjRQ6AsA7oduLGpBe7sTAfBgNVHSMEGDAWgBTVi+UqriPLeufRqMwe49lv +WU333DANBgkqhkiG9w0BAQsFAAOCAgEAaCUpDpTL2f5hQNh5eLeeEgdmnci+9ju7 +MfLuhDbn9Ft4vyqoHUkbgyPThNBaA3ENWxu2Q4dsTshSAxg1QG50kZDTO9u5z/Ge +tuDGghzC6+Zw8xbfXUlpkSWgOcxUvVTKuPwNfoIgmjpFxmnJ96LeFT68ORspwrYo +7yh2ffn3oLIRMTKgVaoPqF3EzWcX/9Ij+TnnZQsfZkEQaMWrKorc/IioBLUsvhKt +YV4+GJb2jtr+T/kNgahMyeE6tZQWfAsvu11TA25QCyfuccy5EJcu89U5pqFR/0cf +jpvyUU/ODGqgMzOYEDkfZyrLfUUGCb4i1rv4uLNVIE/nzZdDM+Lw9tFL6oh8p9uP +Y8z2EqevhL2WN6RO+IpnpzMO3QZtzJMPCHFrLjHfAPD8qNB1HOp2mLQMe16ILd6P +rSFteZqLuBQ2rvej9JlU94+uMszV2/JVxmM2MZstyXGgqnD9TtLtWBKHAJg7gVqh +s+3Vg8CzDwJYMzVYRySkDIMBq+CTtxyt+AEinVum4PX8Zno43Z60k7VWQJZV1c3J +bi73lsmhEMhRmyC4rfIJQWF2r3ExV4Qhc8ISfUZgWJB7K3f1GxS1R5zGXa2GvOVQ +mar20iTXssU3qbyaOjVI3+8P1chBz3cUNLWT5WzZWB5EMdtHGyjCKnqE6XTZcVmG +Z8Xwl8mRjMI= +-----END CERTIFICATE-----` + + DATABASE_CLIENT_KEY = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAonUIEmTXRwXeE4VpCNwj0A92XGVXPBAUrIoxPPzQiy8ey9wR +JqWKCPQY/g2vkEme/4uNIN+o8iI4COYmPaIuRv1tqot4U2/mhfeDH79+E7oc97FX +AnksLTEni+zOBtJUOQd1KOF6TlyVP3PRn9m+QQxLU0qp5TPpSIuE11E4SeVShocx +65AcQPgmu5+YBKCSVN7J5hosXm8Wtd/d28lJ36WBumZFUa+qD3DMiWUqU+AcXeCh +GJOjiE5osalg9UaWSuMnag+wvTXEWYP3qd3eThFcfS4Bj9/1XABPvdlyug15behG +2iU1Ra4UO35zRb6jF6Ax9Tc+14FDzVsaHBNg1QIDAQABAoIBACnnuEpOWr2GRO+S +JTLU3iQIKQbSWTs0BrEvAF5z9DNC11XMkVv/rWh71oqJ6zRz2SCf1aqaJtE2hG+/ +NjQFxpwnOQeZ7FLRdYwu+VLSKWpbQqedxgzsRrntiP7t+YMG9BS12MHPz6Ww+gqh +DHyIRSwwSKnWg5aM2msNGhoUaEmfBVzj7i2zaO2B3MGERMIC0ZYT2qwa4LOPyxOO +UkgXrsKaGPAlLLxR0q3uLhCueFpfQSwvdl7s03s9iz3gmdccwlpixhPtiBVOAwid +IfywswqYxcYj7jotv2LazuJQNPpfLiTDOiSGFPwzs1XJiZar2Q9aMjvfg6+s2Mzb +I7Tn0e0CgYEAwYM4smPnaz0jvJjuYd1hqPE+y+TmX4hQJfdt7nggmzQ2etKGiLbG +navrMu2KKnl+7PMQG/X3ok220ojybAZlOD89gvHQhRJsnI4a2BVCFB9u4MCnIfvJ +4CxTMH2qQ3Nux8r0CNuY6x2ZQqn7i7PPE+fJnkJJ2N5WTUaxFIsD4cMCgYEA1uqa +bSa/QsWbxivxy0/hzPtLin239b9SjgS8I/YJOVlbvb58Nz6SgOvcR/zbNf5GiiZu +p5iHSEvLchDZ2f9NSaUjnSExZ67fgIaUxdfVxzHFr64GHkFrJmzttnyLcwTVmKK0 +xwQaLROsNvYoGUhoYLlGuG+Wlm3yLtlZKY/bMYcCgYBykxI3tSUo/nsxSE8kTKJt +F+F5cZ7hA2GJCTXikuejXUfAcvPK8IUqh8brUW+T9HmtK8Dm/TxQsbjEcOcwBJ1b +rz3pUOmIUL9T9mN4eyWzqmTI1+hdG6qMe1IKDO2JoEgALW9N609gLhc3PFO+hIjg +HUXn2RHGQOZSPL/ODP0QZwKBgAWnTkCo0Ec1Y4+nAElU5J+7zJTsEbbJPaa2wSxB +AKUdkKhBJotdfgUeL0FFiY62Daz8rdSC0qw4MjXh85kkeigBzBoKEX6kvwRmhete +biU7TfP9I/QPzH3KR8aRKCnyapwFS7Qgi3+8EL+xYgSoPvasaQvZA6EZa1GILixF +uIJpAoGBALXgHfbnc+BQDDfdmdGeRpD+3UyB0ThqTn7DDXjXTNlZ7tjcKjmMi1uw +oL2J2cUvbU3Y5R/ppJw0hQNgAXfiax1xuRATUHNweJXA57YwsbrRboxyyzPASQdk +NyGiWUJ68TPzhj9AkypCQfxhqK1Qve1TALPZ6N5lAE7erYsAJbGq +-----END RSA PRIVATE KEY-----` + + CERTIFICATE_FROM_ANOTHER_CA = `-----BEGIN CERTIFICATE----- +MIIEIzCCAgugAwIBAgIRAKRY04+EGcnAfPpeLf3dxZcwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEAxMHbXlzcWxDQTAeFw0xODEwMTYyMjU2NDdaFw0xODEwMTYyMzU2 +NDdaMBUxEzARBgNVBAMTCmV4cGlyZVNvb24wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQCZ/8fZc0q1I03L78hro9jr987Tn6hsJNGod61GiWOHybHezt5i ++SOp7S/fdmLQsRSopUOSmlAH6ta5QbffGbtHY2NJQKJq7N8KRt6aSfbHxPDG96Rp +Q0OZZLyiEaFz2jECoTjqwZX9duG5wA1/AVnZEKqnbAWdIWP9AOTzwdJ/ne4CLzyj +Lm/HUNi9xsZvU5xgb8ZSW3z8SOf39UedocmDcA/rTZWAkO6ELPvx4KD6t5aBC4ir +k7tGveQFxTvziZr3lNZk+NTX2OWUrz5yoH/nMiXtHe4JuytFsN5DYF1f6/3Fxl19 +AhkCkxTj238/FFLID34W7mfZbgN59ByBgnPxAgMBAAGjcTBvMA4GA1UdDwEB/wQE +AwIDuDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0OBBYEFJuI +c/+CukBgDTwmU6A+6++WwzaKMB8GA1UdIwQYMBaAFFmJZ4pg0d2UuDKpQzF3XQAM +LandMA0GCSqGSIb3DQEBCwUAA4ICAQAh7+THwf0fe3syAaPVqnpx2kswUAqP9VTw +waxXswwp632JnQa9vctuVBQ7DNwOHSixaNlM7yR+w1FlubwLzNRR5EXOgi2kl5Le +mewKBmJLpMwkmAbpCUB2B2ofJJguMe0JVQC6OC3eA3JsTc1/FtqJ4H1+RD5xT6hx +uOxla3zwfynYD4WdRMAosYVJouCScgWJpK+MWEkMCx94GUcO4Ik9acWhzBcdgaUG +qjbtTq5dHgVwernhJaiuUC2R5wEvb3rkhav2TYHJucFm0NHFbMCCYNbFAp1t1OyW +hiNrGtUGN2jBoFZ9OEZaWuY00mKs0Elp5/ugHQ5hW6HXam/4Fh95PMBR1QC+c5AC +AhdCYEXpZXkjCe5vnXHegBxAMV2FU33G9rPWWAi76sBlqjApGaYfbYJW63bhEOZT +AtnHlrPVw/GM16KkzMEEbi4lRvY4F3F2FJ+LZSMKMNs9aX/CAAWs9up3n7PcePP0 +fV70C2hVtCJbIfRPaWvrVAAktBP9xLTnzUvzijPLMEJ9o45vWdrtvyBFknQCpMts +lw6sWU26m2gvxs6CcX3yt0bt8SxjqyulqrOdFCVSjZbGMDaIamdEKnC6k5ySyizn +SM2qNm+nV5FhjsyMyzs6OuCNEZGDAqklWBAHHqLncb6elO9NZgDysB/xn6jS+zqT +F1Y5M6wvLA== +-----END CERTIFICATE-----` +) + +var _ = Describe("MySQLConnectionStringBuilder", func() { + Describe("Build", func() { + var ( + mysqlConnectionStringBuilder *db.MySQLConnectionStringBuilder + mySQLAdapter *fakes.MySQLAdapter + + config db.Config + ) + + BeforeEach(func() { + config = db.Config{ + User: "some-user", + Password: "some-password", + Host: "some-host", + Port: uint16(1234), + DatabaseName: "some-database", + Timeout: 5, + } + + mySQLAdapter = &fakes.MySQLAdapter{} + + mysqlConnectionStringBuilder = &db.MySQLConnectionStringBuilder{ + MySQLAdapter: mySQLAdapter, + } + mySQLAdapter.ParseDSNStub = func(dsn string) (cfg *mysql.Config, err error) { + return mysql.ParseDSN(dsn) + } + }) + + It("builds a connection string", func() { + connectionString, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).NotTo(HaveOccurred()) + Expect(connectionString).To(Equal("some-user:some-password@tcp(some-host:1234)/some-database?parseTime=true&readTimeout=5s&timeout=5s&writeTimeout=5s")) + }) + + Context("when mysql.ParseDSN can't parse the connection string", func() { + BeforeEach(func() { + mySQLAdapter.ParseDSNReturns(nil, errors.New("foxtrot")) + }) + + It("returns an error", func() { + _, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).To(MatchError("parsing db connection string: foxtrot")) + }) + }) + + Context("when requiring ssl", func() { + var ( + caCertPool *x509.CertPool + ) + + BeforeEach(func() { + caCertFile, err := ioutil.TempFile("", "") + _, err = caCertFile.Write([]byte(DATABASE_CA_CERT)) + Expect(err).NotTo(HaveOccurred()) + + config.RequireSSL = true + config.CACert = caCertFile.Name() + + caCertPool = x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM([]byte(DATABASE_CA_CERT)) + Expect(ok).To(BeTrue()) + }) + + It("builds a tls connection string", func() { + connectionString, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).NotTo(HaveOccurred()) + Expect(connectionString).To(Equal("some-user:some-password@tcp(some-host:1234)/some-database?parseTime=true&readTimeout=5s&timeout=5s&tls=some-database-tls&writeTimeout=5s")) + + Expect(mySQLAdapter.RegisterTLSConfigCallCount()).To(Equal(1)) + passedTLSConfigName, passedTLSConfig := mySQLAdapter.RegisterTLSConfigArgsForCall(0) + Expect(passedTLSConfigName).To(Equal("some-database-tls")) + Expect(passedTLSConfig).To(Equal(&tls.Config{ + InsecureSkipVerify: false, + RootCAs: caCertPool, + })) + }) + + Context("when a client cert and key is provided", func() { + var ( + clientCert tls.Certificate + ) + + BeforeEach(func() { + clientCertFile, err := ioutil.TempFile("", "") + _, err = clientCertFile.Write([]byte(DATABASE_CLIENT_CERT)) + Expect(err).NotTo(HaveOccurred()) + + clientKeyFile, err := ioutil.TempFile("", "") + _, err = clientKeyFile.Write([]byte(DATABASE_CLIENT_KEY)) + Expect(err).NotTo(HaveOccurred()) + + config.ClientCert = clientCertFile.Name() + config.ClientKey = clientKeyFile.Name() + + clientCert, err = tls.LoadX509KeyPair(clientCertFile.Name(), clientKeyFile.Name()) + Expect(err).NotTo(HaveOccurred()) + }) + + It("builds a tls config with the client cert and key", func() { + connectionString, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).NotTo(HaveOccurred()) + Expect(connectionString).To(Equal("some-user:some-password@tcp(some-host:1234)/some-database?parseTime=true&readTimeout=5s&timeout=5s&tls=some-database-tls&writeTimeout=5s")) + + Expect(mySQLAdapter.RegisterTLSConfigCallCount()).To(Equal(1)) + passedTLSConfigName, passedTLSConfig := mySQLAdapter.RegisterTLSConfigArgsForCall(0) + Expect(passedTLSConfigName).To(Equal("some-database-tls")) + Expect(passedTLSConfig.InsecureSkipVerify).To(BeTrue()) + Expect(passedTLSConfig.RootCAs).To(Equal(caCertPool)) + Expect(passedTLSConfig.Certificates).To(Equal([]tls.Certificate{clientCert})) + // impossible to assert VerifyPeerCertificate is set to a specfic function + Expect(passedTLSConfig.VerifyPeerCertificate).NotTo(BeNil()) + }) + + Context("when loading key pairs errors", func() { + BeforeEach(func() { + config.ClientCert = "/foo/bar" + config.ClientKey = "/foo/bar" + }) + + It("returns an error", func() { + _, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).To(MatchError("loading key pair: open /foo/bar: no such file or directory")) + }) + }) + }) + + Context("when it can't read the ca cert file", func() { + BeforeEach(func() { + config.CACert = "/foo/bar" + }) + + It("returns an error", func() { + _, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).To(MatchError("reading db ca cert file: open /foo/bar: no such file or directory")) + }) + }) + + Context("when it can't append the ca cert to the cert pool", func() { + BeforeEach(func() { + caCertFile, err := ioutil.TempFile("", "") + _, err = caCertFile.Write([]byte("bad cert")) + Expect(err).NotTo(HaveOccurred()) + + config.CACert = caCertFile.Name() + }) + + It("returns an error", func() { + _, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).To(MatchError("appending cert to pool from pem - invalid cert bytes")) + }) + }) + + Context("when it can't register TLS config", func() { + BeforeEach(func() { + mySQLAdapter.RegisterTLSConfigReturns(errors.New("bad things happened")) + }) + + It("retruns an error", func() { + _, err := mysqlConnectionStringBuilder.Build(config) + Expect(err).To(MatchError("registering mysql tls config: bad things happened")) + }) + }) + }) + }) + + Describe("VerifyCertificatesIgnoreHostname", func() { + var ( + caCertPool *x509.CertPool + ) + + BeforeEach(func() { + caCertPool = x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM([]byte(DATABASE_CA_CERT)) + Expect(ok).To(BeTrue()) + }) + + It("verifies that provided certificates are valid", func() { + block, _ := pem.Decode([]byte(DATABASE_CLIENT_CERT)) + + err := db.VerifyCertificatesIgnoreHostname([][]byte{ + block.Bytes, + }, caCertPool) + Expect(err).NotTo(HaveOccurred()) + }) + + Context("when raw certs are not parsable", func() { + It("returns an error", func() { + err := db.VerifyCertificatesIgnoreHostname([][]byte{ + []byte("foo"), + []byte("bar"), + }, nil) + Expect(err).To(MatchError("tls: failed to parse certificate from server: asn1: structure error: tags don't match (16 vs {class:1 tag:6 length:111 isCompound:true}) {optional:false explicit:false application:false private:false defaultValue: tag: stringType:0 timeType:0 set:false omitEmpty:false} certificate @2")) + }) + }) + + Context("when verifying a bad cert", func() { + + It("returns an error", func() { + block, _ := pem.Decode([]byte(CERTIFICATE_FROM_ANOTHER_CA)) + + err := db.VerifyCertificatesIgnoreHostname([][]byte{ + block.Bytes, + }, caCertPool) + + Expect(err).To(MatchError(`x509: certificate signed by unknown authority (possibly because of "crypto/rsa: verification error" while trying to verify candidate authority certificate "mysqlCA")`)) + }) + }) + }) +}) diff --git a/fakes/mysql_adapter.go b/fakes/mysql_adapter.go new file mode 100644 index 00000000..0d078e18 --- /dev/null +++ b/fakes/mysql_adapter.go @@ -0,0 +1,165 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "crypto/tls" + "sync" + + "github.com/go-sql-driver/mysql" +) + +type MySQLAdapter struct { + ParseDSNStub func(dsn string) (cfg *mysql.Config, err error) + parseDSNMutex sync.RWMutex + parseDSNArgsForCall []struct { + dsn string + } + parseDSNReturns struct { + result1 *mysql.Config + result2 error + } + parseDSNReturnsOnCall map[int]struct { + result1 *mysql.Config + result2 error + } + RegisterTLSConfigStub func(key string, config *tls.Config) error + registerTLSConfigMutex sync.RWMutex + registerTLSConfigArgsForCall []struct { + key string + config *tls.Config + } + registerTLSConfigReturns struct { + result1 error + } + registerTLSConfigReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *MySQLAdapter) ParseDSN(dsn string) (cfg *mysql.Config, err error) { + fake.parseDSNMutex.Lock() + ret, specificReturn := fake.parseDSNReturnsOnCall[len(fake.parseDSNArgsForCall)] + fake.parseDSNArgsForCall = append(fake.parseDSNArgsForCall, struct { + dsn string + }{dsn}) + fake.recordInvocation("ParseDSN", []interface{}{dsn}) + fake.parseDSNMutex.Unlock() + if fake.ParseDSNStub != nil { + return fake.ParseDSNStub(dsn) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fake.parseDSNReturns.result1, fake.parseDSNReturns.result2 +} + +func (fake *MySQLAdapter) ParseDSNCallCount() int { + fake.parseDSNMutex.RLock() + defer fake.parseDSNMutex.RUnlock() + return len(fake.parseDSNArgsForCall) +} + +func (fake *MySQLAdapter) ParseDSNArgsForCall(i int) string { + fake.parseDSNMutex.RLock() + defer fake.parseDSNMutex.RUnlock() + return fake.parseDSNArgsForCall[i].dsn +} + +func (fake *MySQLAdapter) ParseDSNReturns(result1 *mysql.Config, result2 error) { + fake.ParseDSNStub = nil + fake.parseDSNReturns = struct { + result1 *mysql.Config + result2 error + }{result1, result2} +} + +func (fake *MySQLAdapter) ParseDSNReturnsOnCall(i int, result1 *mysql.Config, result2 error) { + fake.ParseDSNStub = nil + if fake.parseDSNReturnsOnCall == nil { + fake.parseDSNReturnsOnCall = make(map[int]struct { + result1 *mysql.Config + result2 error + }) + } + fake.parseDSNReturnsOnCall[i] = struct { + result1 *mysql.Config + result2 error + }{result1, result2} +} + +func (fake *MySQLAdapter) RegisterTLSConfig(key string, config *tls.Config) error { + fake.registerTLSConfigMutex.Lock() + ret, specificReturn := fake.registerTLSConfigReturnsOnCall[len(fake.registerTLSConfigArgsForCall)] + fake.registerTLSConfigArgsForCall = append(fake.registerTLSConfigArgsForCall, struct { + key string + config *tls.Config + }{key, config}) + fake.recordInvocation("RegisterTLSConfig", []interface{}{key, config}) + fake.registerTLSConfigMutex.Unlock() + if fake.RegisterTLSConfigStub != nil { + return fake.RegisterTLSConfigStub(key, config) + } + if specificReturn { + return ret.result1 + } + return fake.registerTLSConfigReturns.result1 +} + +func (fake *MySQLAdapter) RegisterTLSConfigCallCount() int { + fake.registerTLSConfigMutex.RLock() + defer fake.registerTLSConfigMutex.RUnlock() + return len(fake.registerTLSConfigArgsForCall) +} + +func (fake *MySQLAdapter) RegisterTLSConfigArgsForCall(i int) (string, *tls.Config) { + fake.registerTLSConfigMutex.RLock() + defer fake.registerTLSConfigMutex.RUnlock() + return fake.registerTLSConfigArgsForCall[i].key, fake.registerTLSConfigArgsForCall[i].config +} + +func (fake *MySQLAdapter) RegisterTLSConfigReturns(result1 error) { + fake.RegisterTLSConfigStub = nil + fake.registerTLSConfigReturns = struct { + result1 error + }{result1} +} + +func (fake *MySQLAdapter) RegisterTLSConfigReturnsOnCall(i int, result1 error) { + fake.RegisterTLSConfigStub = nil + if fake.registerTLSConfigReturnsOnCall == nil { + fake.registerTLSConfigReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.registerTLSConfigReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *MySQLAdapter) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.parseDSNMutex.RLock() + defer fake.parseDSNMutex.RUnlock() + fake.registerTLSConfigMutex.RLock() + defer fake.registerTLSConfigMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *MySQLAdapter) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +}