Skip to content

Commit b90308a

Browse files
committed
Add support for AWS RDS IAM authentication
This change adds support for AWS' RDS IAM authentication feature: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAM.html The READMEs have been with details on how to configure this exporter for use with RDS IAM.
1 parent efd6766 commit b90308a

File tree

7 files changed

+228
-55
lines changed

7 files changed

+228
-55
lines changed

Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
FROM debian:10-slim
22
RUN useradd -u 20001 postgres_exporter
33

4+
# Install certs and create home directory needed by the AWS SDK
5+
RUN apt update && apt install ca-certificates -y \
6+
&& mkdir /home/postgres_exporter \
7+
&& chown postgres_exporter:postgres_exporter /home/postgres_exporter \
8+
&& rm -rf /var/lib/{apt,dpkg,cache,log}/
9+
410
USER postgres_exporter
511

612
ARG binary

README-RDS.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,21 @@ Running postgres-exporter in a container like so:
3636
```
3737
+ lastly, you must reboot the RDS instance.
3838

39+
### AWS RDS IAM Authentication
40+
41+
To use [AWS RDS IAM Authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html)
42+
set the data source password to the magic value `AWS_RDS_IAM_AUTH`.
43+
44+
An AWS session is constructed using the default providers of the `aws-go-sdk` and should work will all common
45+
configuration options. Using EC2 IAM roles and `AWS_WEB_IDENTITY_TOKEN_FILE`s, via IAM Roles for Service Accounts,
46+
are specifically known to work.
47+
48+
Troubleshooting:
49+
- Do not set `sslmode=disabled`
50+
- Set the `AWS_REGION` environment variable if the running instance does not have access to the EC2 metadata endpoint
51+
- If using `AWS_WEB_IDENTITY_TOKEN_FILE` with kubernetes you likely need to configure your deployment to set the
52+
`securityContext`s `fsGroup` value to `20001` (the ID set for the `postgres_exporter` user in the Dockerfile).
53+
```
54+
securityContext:
55+
fsGroup: 20001
56+
```

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ ALTER USER postgres_exporter SET SEARCH_PATH TO postgres_exporter,pg_catalog;
228228
-- If deploying as non-superuser (for example in AWS RDS), uncomment the GRANT
229229
-- line below and replace <MASTER_USER> with your root user.
230230
-- GRANT postgres_exporter TO <MASTER_USER>;
231+
-- If using AWS RDS IAM authentication, uncomment the line below
232+
-- GRANT rds_iam TO postgres_exporter;
231233
CREATE SCHEMA IF NOT EXISTS postgres_exporter;
232234
GRANT USAGE ON SCHEMA postgres_exporter TO postgres_exporter;
233235
GRANT CONNECT ON DATABASE postgres TO postgres_exporter;
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package main
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws/ec2metadata"
5+
"github.com/aws/aws-sdk-go/aws/session"
6+
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
7+
"net/url"
8+
"os"
9+
)
10+
11+
// DsnFactory creates DSNs
12+
type DsnFactory interface {
13+
// Creates a new DSN
14+
NewDsn() (string, error)
15+
16+
// Return a stable, valid, & parsable DSN value for logging, map keys, etc.
17+
DsnID() string
18+
}
19+
20+
// RawDsnFactory is a no-op factory that returns the original 'raw' DSN.
21+
type RawDsnFactory struct {
22+
dsn string
23+
}
24+
25+
// DsnID returns the original DSN
26+
func (f *RawDsnFactory) DsnID() string {
27+
return f.dsn
28+
}
29+
30+
// NewDsn returns the original DSN. No mutations are required DSNs created by the RawDsnFactory
31+
func (f *RawDsnFactory) NewDsn() (string, error) {
32+
return f.dsn, nil
33+
}
34+
35+
// AwsRdsIamFactory generates a DSN configured with AWS RDS IAM authentication credentials
36+
type AwsRdsIamFactory struct {
37+
dsn string
38+
}
39+
40+
// DsnID returns the original DSN. The password is _not_ replaced with an AWS RDS IAM generated password token
41+
func (f *AwsRdsIamFactory) DsnID() string {
42+
return f.dsn
43+
}
44+
45+
// NewDsn builds a new DSN with the password set to one obtained from the AWS RDS IAM token API
46+
func (f *AwsRdsIamFactory) NewDsn() (string, error) {
47+
var err error
48+
var sess *session.Session
49+
sess, err = session.NewSession()
50+
if err != nil {
51+
return "", err
52+
}
53+
54+
var u *url.URL
55+
u, err = url.Parse(f.dsn)
56+
if err != nil {
57+
return "", err
58+
}
59+
60+
region := *sess.Config.Region
61+
62+
if len(region) == 0 {
63+
if len(os.Getenv("AWS_REGION")) > 0 {
64+
region = os.Getenv("AWS_REGION")
65+
} else {
66+
var r string
67+
r, err = ec2metadata.New(sess).Region()
68+
69+
if err != nil {
70+
return "", err
71+
}
72+
region = r
73+
}
74+
}
75+
76+
var token string
77+
token, err = rdsutils.BuildAuthToken(u.Host, region, u.User.Username(), sess.Config.Credentials)
78+
if err != nil {
79+
return "", err
80+
}
81+
82+
u.User = url.UserPassword(u.User.Username(), token)
83+
84+
return u.String(), nil
85+
}

cmd/postgres_exporter/postgres_exporter.go

Lines changed: 81 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -950,8 +950,9 @@ func NewServers(opts ...ServerOpt) *Servers {
950950
}
951951
}
952952

953-
// GetServer returns established connection from a collection.
954-
func (s *Servers) GetServer(dsn string) (*Server, error) {
953+
// GetServer returns established connection from a collection. If a connection
954+
// is invalid a new connection is established.
955+
func (s *Servers) GetServer(dsnID string, dsn string) (*Server, error) {
955956
s.m.Lock()
956957
defer s.m.Unlock()
957958
var err error
@@ -963,17 +964,17 @@ func (s *Servers) GetServer(dsn string) (*Server, error) {
963964
if errCount++; errCount > retries {
964965
return nil, err
965966
}
966-
server, ok = s.servers[dsn]
967+
server, ok = s.servers[dsnID]
967968
if !ok {
968969
server, err = NewServer(dsn, s.opts...)
969970
if err != nil {
970971
time.Sleep(time.Duration(errCount) * time.Second)
971972
continue
972973
}
973-
s.servers[dsn] = server
974+
s.servers[dsnID] = server
974975
}
975976
if err = server.Ping(); err != nil {
976-
delete(s.servers, dsn)
977+
delete(s.servers, dsnID)
977978
time.Sleep(time.Duration(errCount) * time.Second)
978979
continue
979980
}
@@ -1002,7 +1003,7 @@ type Exporter struct {
10021003
disableDefaultMetrics, disableSettingsMetrics, autoDiscoverDatabases bool
10031004

10041005
excludeDatabases []string
1005-
dsn []string
1006+
dsnFactories []DsnFactory
10061007
userQueriesPath string
10071008
constantLabels prometheus.Labels
10081009
duration prometheus.Gauge
@@ -1088,9 +1089,9 @@ func parseConstLabels(s string) prometheus.Labels {
10881089
}
10891090

10901091
// NewExporter returns a new PostgreSQL exporter for the provided DSN.
1091-
func NewExporter(dsn []string, opts ...ExporterOpt) *Exporter {
1092+
func NewExporter(dsnFactories []DsnFactory, opts ...ExporterOpt) *Exporter {
10921093
e := &Exporter{
1093-
dsn: dsn,
1094+
dsnFactories: dsnFactories,
10941095
builtinMetricMaps: builtinMetricMaps,
10951096
}
10961097

@@ -1459,18 +1460,33 @@ func (e *Exporter) scrape(ch chan<- prometheus.Metric) {
14591460

14601461
e.totalScrapes.Inc()
14611462

1462-
dsns := e.dsn
1463+
var errorsCount int
1464+
var dsns map[string]string
1465+
1466+
dsnFactories := e.dsnFactories
1467+
14631468
if e.autoDiscoverDatabases {
14641469
dsns = e.discoverDatabaseDSNs()
1470+
} else {
1471+
dsns = make(map[string]string)
1472+
for _, dsnFactory := range dsnFactories {
1473+
dsn, err := dsnFactory.NewDsn()
1474+
1475+
if err != nil {
1476+
log.Errorf("Unable to create dsn (%s): %v", loggableDSN(dsnFactory.DsnID()), err)
1477+
1478+
errorsCount++
1479+
} else {
1480+
dsns[dsnFactory.DsnID()] = dsn
1481+
}
1482+
}
14651483
}
14661484

1467-
var errorsCount int
14681485
var connectionErrorsCount int
14691486

1470-
for _, dsn := range dsns {
1471-
if err := e.scrapeDSN(ch, dsn); err != nil {
1487+
for dsnID, dsn := range dsns {
1488+
if err := e.scrapeDSN(ch, dsnID, dsn); err != nil {
14721489
errorsCount++
1473-
14741490
log.Errorf(err.Error())
14751491

14761492
if _, ok := err.(*ErrorConnectToServer); ok {
@@ -1494,17 +1510,31 @@ func (e *Exporter) scrape(ch chan<- prometheus.Metric) {
14941510
}
14951511
}
14961512

1497-
func (e *Exporter) discoverDatabaseDSNs() []string {
1498-
dsns := make(map[string]struct{})
1499-
for _, dsn := range e.dsn {
1500-
parsedDSN, err := url.Parse(dsn)
1513+
func (e *Exporter) discoverDatabaseDSNs() map[string]string {
1514+
dsns := make(map[string]string)
1515+
for _, dsnFactory := range e.dsnFactories {
1516+
dsnID := dsnFactory.DsnID()
1517+
dsn, err := dsnFactory.NewDsn()
1518+
if err != nil {
1519+
log.Errorf("Unable to create DSN (%s): %v", loggableDSN(dsnFactory.DsnID()), err)
1520+
}
1521+
1522+
// Validate the DSN returned from the factory
1523+
parsedDsn, err := url.Parse(dsn)
15011524
if err != nil {
15021525
log.Errorf("Unable to parse DSN (%s): %v", loggableDSN(dsn), err)
15031526
continue
15041527
}
15051528

1506-
dsns[dsn] = struct{}{}
1507-
server, err := e.servers.GetServer(dsn)
1529+
// Validate and DSN Id assigned to the factory
1530+
parsedDsnID, err := url.Parse(dsnID)
1531+
if err != nil {
1532+
log.Errorf("Unable to parse DSN Id (%s): %v", loggableDSN(dsnID), err)
1533+
continue
1534+
}
1535+
1536+
dsns[dsnID] = dsn
1537+
server, err := e.servers.GetServer(dsnID, dsn)
15081538
if err != nil {
15091539
log.Errorf("Error opening connection to database (%s): %v", loggableDSN(dsn), err)
15101540
continue
@@ -1522,23 +1552,19 @@ func (e *Exporter) discoverDatabaseDSNs() []string {
15221552
if contains(e.excludeDatabases, databaseName) {
15231553
continue
15241554
}
1525-
parsedDSN.Path = databaseName
1526-
dsns[parsedDSN.String()] = struct{}{}
1527-
}
1528-
}
15291555

1530-
result := make([]string, len(dsns))
1531-
index := 0
1532-
for dsn := range dsns {
1533-
result[index] = dsn
1534-
index++
1556+
parsedDsn.Path = databaseName
1557+
parsedDsnID.Path = databaseName
1558+
1559+
dsns[parsedDsnID.String()] = parsedDsn.String()
1560+
}
15351561
}
15361562

1537-
return result
1563+
return dsns
15381564
}
15391565

1540-
func (e *Exporter) scrapeDSN(ch chan<- prometheus.Metric, dsn string) error {
1541-
server, err := e.servers.GetServer(dsn)
1566+
func (e *Exporter) scrapeDSN(ch chan<- prometheus.Metric, dsnID string, dsn string) error {
1567+
server, err := e.servers.GetServer(dsnID, dsn)
15421568

15431569
if err != nil {
15441570
return &ErrorConnectToServer{fmt.Sprintf("Error opening connection to database (%s): %s", loggableDSN(dsn), err.Error())}
@@ -1561,9 +1587,12 @@ func (e *Exporter) scrapeDSN(ch chan<- prometheus.Metric, dsn string) error {
15611587
// DATA_SOURCE_NAME always wins so we do not break older versions
15621588
// reading secrets from files wins over secrets in environment variables
15631589
// DATA_SOURCE_NAME > DATA_SOURCE_{USER|PASS}_FILE > DATA_SOURCE_{USER|PASS}
1564-
func getDataSources() []string {
1590+
func getDataSources() []DsnFactory {
15651591
var dsn = os.Getenv("DATA_SOURCE_NAME")
1566-
if len(dsn) == 0 {
1592+
var dsns []string
1593+
if len(dsn) > 0 {
1594+
dsns = strings.Split(dsn, ",")
1595+
} else {
15671596
var user string
15681597
var pass string
15691598
var uri string
@@ -1602,9 +1631,23 @@ func getDataSources() []string {
16021631

16031632
dsn = "postgresql://" + ui + "@" + uri
16041633

1605-
return []string{dsn}
1634+
dsns = []string{dsn}
1635+
}
1636+
1637+
dsnFactories := make([]DsnFactory, 0, len(dsns))
1638+
for _, dsn := range dsns {
1639+
u, _ := url.Parse(dsn)
1640+
1641+
if pw, set := u.User.Password(); set && pw == "AWS_RDS_IAM_AUTH" {
1642+
log.Infof("Building AwsRdsIamFactory for %s", loggableDSN(dsn))
1643+
dsnFactories = append(dsnFactories, &AwsRdsIamFactory{dsn: dsn})
1644+
} else {
1645+
log.Infof("Building RawDsnFactory for %s", loggableDSN(dsn))
1646+
dsnFactories = append(dsnFactories, &RawDsnFactory{dsn: dsn})
1647+
}
16061648
}
1607-
return strings.Split(dsn, ",")
1649+
1650+
return dsnFactories
16081651
}
16091652

16101653
func contains(a []string, x string) bool {
@@ -1637,12 +1680,12 @@ func main() {
16371680
return
16381681
}
16391682

1640-
dsn := getDataSources()
1641-
if len(dsn) == 0 {
1683+
dsnFactories := getDataSources()
1684+
if len(dsnFactories) == 0 {
16421685
log.Fatal("couldn't find environment variables describing the datasource to use")
16431686
}
16441687

1645-
exporter := NewExporter(dsn,
1688+
exporter := NewExporter(dsnFactories,
16461689
DisableDefaultMetrics(*disableDefaultMetrics),
16471690
DisableSettingsMetrics(*disableSettingsMetrics),
16481691
AutoDiscoverDatabases(*autoDiscoverDatabases),

0 commit comments

Comments
 (0)