diff --git a/db/config.go b/db/config.go index d381a693..8297d603 100644 --- a/db/config.go +++ b/db/config.go @@ -38,9 +38,22 @@ func (c Config) ConnectionString() (string, error) { func buildPostgresConnectionString(c Config) (string, error) { ms := (time.Duration(c.Timeout) * time.Second).Nanoseconds() / 1000 / 1000 + sslmode := "disable" params := url.Values{} - params.Add("sslmode", "disable") + if c.RequireSSL { + if c.SkipHostnameValidation { + sslmode = "require" + } else { + if c.CACert == "" { + return "", fmt.Errorf("SSL is required but `CACert` is not provided") + } + sslmode = "verify-full" + params.Add("sslrootcert", c.CACert) + } + } + + params.Add("sslmode", sslmode) params.Add("connect_timeout", fmt.Sprintf("%d", ms)) connURL := url.URL{ diff --git a/db/config_test.go b/db/config_test.go index 8fd924f0..f9917ffb 100644 --- a/db/config_test.go +++ b/db/config_test.go @@ -5,6 +5,8 @@ import ( "code.cloudfoundry.org/cf-networking-helpers/db" + "net/url" + "github.com/go-sql-driver/mysql" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -69,6 +71,57 @@ var _ = Describe("Config", func() { Expect(err).NotTo(HaveOccurred()) Expect(connectionString).To(Equal("postgres://some-user:some-password@some-host:1234/some-database?connect_timeout=5000&sslmode=disable")) }) + + Context("when ssl is required", func() { + BeforeEach(func() { + config.RequireSSL = true + config.CACert = "/tmp/cert" + }) + + Context("when skip_hostname_validation is set", func() { + BeforeEach(func() { + config.SkipHostnameValidation = true + }) + It("sets sslmode to \"require\"", func() { + connectionString, err := config.ConnectionString() + Expect(err).NotTo(HaveOccurred()) + connUrl, err := url.Parse(connectionString) + Expect(err).NotTo(HaveOccurred()) + connQuery := connUrl.Query() + Expect(connQuery.Get("sslmode")).To(Equal("require")) + }) + }) + + Context("when skip_hostname_validation is not set", func() { + Context("when ca_cert is empty", func() { + BeforeEach(func() { + config.CACert = "" + }) + It("returns an error", func() { + _, err := config.ConnectionString() + Expect(err).To(HaveOccurred()) + }) + }) + + It("sets sslmode to \"verify-full\"", func() { + connectionString, err := config.ConnectionString() + Expect(err).NotTo(HaveOccurred()) + connUrl, err := url.Parse(connectionString) + Expect(err).NotTo(HaveOccurred()) + connQuery := connUrl.Query() + Expect(connQuery.Get("sslmode")).To(Equal("verify-full")) + }) + It("sets sslrootcert", func() { + connectionString, err := config.ConnectionString() + Expect(err).NotTo(HaveOccurred()) + connUrl, err := url.Parse(connectionString) + Expect(err).NotTo(HaveOccurred()) + connQuery := connUrl.Query() + Expect(connQuery.Get("sslrootcert")).To(Equal("/tmp/cert")) + }) + }) + + }) }) Context("when the type is mysql", func() {