diff --git a/CHANGELOG.md b/CHANGELOG.md index 0922b8bd..edda063e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# confluent-kafka-javascript 1.6.0 + +v1.6.0 is a feature release. It is supported for all usage. + +### Enhancements + +1. References librdkafka v2.12.0. Refer to the [librdkafka v2.12.0 release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.12.0) for more information. +2. OAuth OIDC method for Schema Registry metadata based authentication with + an Azure IMDS endpoint using an attached managed identity as principal (#). + + # confluent-kafka-javascript 1.5.0 v1.5.0 is a feature release. It is supported for all usage. diff --git a/deps/librdkafka b/deps/librdkafka index 69b1865e..5e9034be 160000 --- a/deps/librdkafka +++ b/deps/librdkafka @@ -1 +1 @@ -Subproject commit 69b1865efdc0118cd017760d038d34e52fb3f0d0 +Subproject commit 5e9034bebb67872778ce667852217eb11f949b42 diff --git a/lib/error.js b/lib/error.js index d9d468c0..135db2ea 100644 --- a/lib/error.js +++ b/lib/error.js @@ -28,7 +28,7 @@ LibrdKafkaError.wrap = errorWrap; * @constant * @memberof RdKafka */ -// ====== Generated from librdkafka 2.11.1 file src-cpp/rdkafkacpp.h ====== +// ====== Generated from librdkafka dev_oauthbearer_metadata_based file src-cpp/rdkafkacpp.h ====== LibrdKafkaError.codes = { /* Internal errors to rdkafka: */ diff --git a/package.json b/package.json index 828a8197..631cfd0f 100644 --- a/package.json +++ b/package.json @@ -2,8 +2,8 @@ "name": "@confluentinc/kafka-javascript", "version": "1.5.0", "description": "Node.js bindings for librdkafka", - "librdkafka": "2.11.1", - "librdkafka_win": "2.11.1", + "librdkafka": "dev_oauthbearer_metadata_based", + "librdkafka_win": "dev_oauthbearer_metadata_based", "main": "lib/index.js", "types": "types/index.d.ts", "scripts": { diff --git a/schemaregistry-examples/src/constants.ts b/schemaregistry-examples/src/constants.ts index a30f5881..a3c4e5c9 100644 --- a/schemaregistry-examples/src/constants.ts +++ b/schemaregistry-examples/src/constants.ts @@ -1,6 +1,7 @@ import { BasicAuthCredentials } from '@confluentinc/schemaregistry'; const issuerEndpointUrl = ''; // e.g. 'https://dev-123456.okta.com/oauth2/default/v1/token'; +const azureIMDSIssuerEndpointQuery = 'api-version=&resource=&client_id='; // e.g. 'api-version=&resource=api://&client_id='; const oauthClientId = ''; const oauthClientSecret = ''; const scope = ''; // e.g. 'schemaregistry'; @@ -23,6 +24,7 @@ const basicAuthCredentials: BasicAuthCredentials = { }; export { - issuerEndpointUrl, oauthClientId, oauthClientSecret, scope, identityPoolId, kafkaLogicalCluster, schemaRegistryLogicalCluster, + issuerEndpointUrl, + azureIMDSIssuerEndpointQuery, oauthClientId, oauthClientSecret, scope, identityPoolId, kafkaLogicalCluster, schemaRegistryLogicalCluster, baseUrl, clusterBootstrapUrl, clusterApiKey, clusterApiSecret, basicAuthCredentials, localAuthCredentials }; \ No newline at end of file diff --git a/schemaregistry-examples/src/kafka-oauth.ts b/schemaregistry-examples/src/kafka-oauth.ts index 0e044d87..a0111e2c 100644 --- a/schemaregistry-examples/src/kafka-oauth.ts +++ b/schemaregistry-examples/src/kafka-oauth.ts @@ -7,7 +7,9 @@ import { KafkaJS } from '@confluentinc/kafka-javascript'; import { clusterBootstrapUrl, baseUrl, - issuerEndpointUrl, oauthClientId, oauthClientSecret, scope, + issuerEndpointUrl, + azureIMDSIssuerEndpointQuery, + oauthClientId, oauthClientSecret, scope, identityPoolId, schemaRegistryLogicalCluster, kafkaLogicalCluster } from "./constants"; import axios from 'axios'; @@ -134,4 +136,91 @@ async function kafkaProducerAvro() { await producer.disconnect(); } -kafkaProducerAvro(); \ No newline at end of file +async function kafkaProducerAvroAzureIMDS() { + + const createAxiosDefaults: CreateAxiosDefaults = { + timeout: 10000 + }; + + const bearerAuthCredentials: BearerAuthCredentials = { + credentialsSource: 'OAUTHBEARER_AZURE_IMDS', + issuerEndpointQuery: azureIMDSIssuerEndpointQuery, + logicalCluster: schemaRegistryLogicalCluster, + identityPoolId: identityPoolId, + } + + const clientConfig: ClientConfig = { + baseURLs: [baseUrl], + createAxiosDefaults: createAxiosDefaults, + cacheCapacity: 512, + cacheLatestTtlSecs: 60, + bearerAuthCredentials + }; + + const schemaRegistryClient = new SchemaRegistryClient(clientConfig); + + const kafka: KafkaJS.Kafka = new KafkaJS.Kafka({ + 'bootstrap.servers': clusterBootstrapUrl, + 'security.protocol': 'sasl_ssl', + 'sasl.mechanism': 'OAUTHBEARER', + 'sasl.oauthbearer.method': 'oidc', + 'sasl.oauthbearer.metadata.authentication.type': 'azure_imds', + 'sasl.oauthbearer.config': 'query=' + azureIMDSIssuerEndpointQuery, + 'sasl.oauthbearer.extensions': `logicalCluster=${kafkaLogicalCluster},identityPoolId=${identityPoolId}` + }); + + const producer: KafkaJS.Producer = kafka.producer({ + kafkaJS: { + allowAutoTopicCreation: true, + acks: 1, + compression: KafkaJS.CompressionTypes.GZIP, + } + }); + + console.log("Producer created"); + + const schemaString: string = JSON.stringify({ + type: 'record', + name: 'User', + fields: [ + { name: 'name', type: 'string' }, + { name: 'age', type: 'int' }, + ], + }); + + const schemaInfo: SchemaInfo = { + schemaType: 'AVRO', + schema: schemaString, + }; + + const userTopic = 'example-user-topic'; + await schemaRegistryClient.register(userTopic + "-value", schemaInfo); + + const userInfo = { name: 'Alice N Bob', age: 30 }; + + const avroSerializerConfig: AvroSerializerConfig = { useLatestVersion: true }; + + const serializer: AvroSerializer = new AvroSerializer(schemaRegistryClient, SerdeType.VALUE, avroSerializerConfig); + + const outgoingMessage = { + key: "1", + value: await serializer.serialize(userTopic, userInfo) + }; + + console.log("Outgoing message: ", outgoingMessage); + + await producer.connect(); + + await producer.send({ + topic: userTopic, + messages: [outgoingMessage] + }); + + await producer.disconnect(); +} + +async function main() { + await kafkaProducerAvro(); + await kafkaProducerAvroAzureIMDS(); +} +main(); \ No newline at end of file diff --git a/schemaregistry/oauth/abstract-oauth-client.ts b/schemaregistry/oauth/abstract-oauth-client.ts new file mode 100644 index 00000000..9bf94d23 --- /dev/null +++ b/schemaregistry/oauth/abstract-oauth-client.ts @@ -0,0 +1,102 @@ +import { sleep, fullJitter, isRetriable } from '../retry-helper'; +import { isBoom } from '@hapi/boom'; +import { + _BearerTokenProvider as BearerTokenProvider, + _BearerTokenProviderBuilder as BearerTokenProviderBuilder +} from './bearer-token-provider'; +import { BearerAuthCredentials } from '../rest-service'; + +abstract class AbstractBearerTokenProviderBuilder implements BearerTokenProviderBuilder { + + protected bearerAuthCredentials : BearerAuthCredentials; + + constructor( + bearerAuthCredentials: BearerAuthCredentials) { + this.bearerAuthCredentials = bearerAuthCredentials; + } + + protected validate() { + const headers = ['logicalCluster', 'identityPoolId']; + const missingHeader = headers.find(header => !(header in this.bearerAuthCredentials)); + + if (missingHeader) { + throw new Error(`Bearer auth header '${missingHeader}' not provided`); + } + } + + abstract build(maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number): BearerTokenProvider; +} + +abstract class AbstractOauthTokenProvider implements BearerTokenProvider { + + private additionalHeaders: Record; + + constructor(bearerAuthCredentials: BearerAuthCredentials) { + this.additionalHeaders = { + 'target-sr-cluster': bearerAuthCredentials.logicalCluster!, + 'Confluent-Identity-Pool-Id': bearerAuthCredentials.identityPoolId!, + }; + } + + abstract getAccessToken(): Promise + + abstract tokenExpired(): boolean; + + getAdditionalHeaders(): Record { + return this.additionalHeaders; + } +} + +abstract class AbstractOAuthClient extends AbstractOauthTokenProvider { + private token: string | null = null; + private maxRetries: number; + private retriesWaitMs: number; + private retriesMaxWaitMs: number; + + constructor(bearerAuthCredentials: BearerAuthCredentials, + maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number + ) { + super(bearerAuthCredentials); + this.maxRetries = maxRetries; + this.retriesWaitMs = retriesWaitMs; + this.retriesMaxWaitMs = retriesMaxWaitMs; + } + + abstract fetchToken(): Promise; + + override async getAccessToken(): Promise { + if (this.token === null || this.tokenExpired()) { + await this.generateAccessToken(); + if (this.token === null) + throw new Error(`token must be available here`); + } + + return this.token; + } + + async generateAccessToken(): Promise { + for (let i = 0; i < this.maxRetries + 1; i++) { + try { + this.token = await this.fetchToken(); + return; + } catch (error: any) { + if (isBoom(error) && i < this.maxRetries) { + const statusCode = error.output.statusCode; + if (isRetriable(statusCode)) { + const waitTime = fullJitter(this.retriesWaitMs, this.retriesMaxWaitMs, i); + await sleep(waitTime); + continue; + } + } + throw new Error(`Failed to get token from server: ${error}`); + } + } + } +} + +// internal/testing usage only +export { + AbstractBearerTokenProviderBuilder as _AbstractBearerTokenProviderBuilder, + AbstractOauthTokenProvider as _AbstractOauthTokenProvider, + AbstractOAuthClient as _AbstractOAuthClient, +} \ No newline at end of file diff --git a/schemaregistry/oauth/bearer-token-provider.ts b/schemaregistry/oauth/bearer-token-provider.ts new file mode 100644 index 00000000..aea3e661 --- /dev/null +++ b/schemaregistry/oauth/bearer-token-provider.ts @@ -0,0 +1,16 @@ +interface BearerTokenProvider { + getAccessToken(): Promise; + getAdditionalHeaders(): Record; + tokenExpired(): boolean; +} + +interface BearerTokenProviderBuilder { + build(maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number): BearerTokenProvider +} + +// internal/testing usage only +export { + BearerTokenProvider as _BearerTokenProvider, + BearerTokenProviderBuilder as _BearerTokenProviderBuilder +} + diff --git a/schemaregistry/oauth/oauth-client-azure-imds.ts b/schemaregistry/oauth/oauth-client-azure-imds.ts new file mode 100644 index 00000000..e8035960 --- /dev/null +++ b/schemaregistry/oauth/oauth-client-azure-imds.ts @@ -0,0 +1,102 @@ +import { + _AbstractOAuthClient as AbstractOAuthClient, + _AbstractBearerTokenProviderBuilder as AbstractBearerTokenProviderBuilder +} from './abstract-oauth-client'; +import Wreck from '@hapi/wreck'; +import { BearerAuthCredentials } from '../rest-service'; +import { + _BearerTokenProvider as BearerTokenProvider +} from './bearer-token-provider'; + +const TOKEN_EXPIRATION_THRESHOLD_PERCENTAGE = 0.8; + +class AzureIMDSBearerToken { + access_token?: string = undefined; + expires_in?: string = undefined; + expires_on?: string = undefined; +} + +class AzureIMDSOAuthClientBuilder extends AbstractBearerTokenProviderBuilder { + constructor( + bearerAuthCredentials: BearerAuthCredentials) { + super(bearerAuthCredentials); + } + + protected override validate() { + super.validate(); + if (!this.bearerAuthCredentials.issuerEndpointUrl && + !this.bearerAuthCredentials.issuerEndpointQuery) + throw new Error(`Missing required configuration property: issuerEndpointQuery`); + } + + override build(maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number): BearerTokenProvider { + this.validate(); + return new AzureIMDSOAuthClient(this.bearerAuthCredentials, maxRetries, retriesWaitMs, retriesMaxWaitMs); + } +} + +class AzureIMDSOAuthClient extends AbstractOAuthClient { + + private tokenEndpoint: string; + private tokenObject?: AzureIMDSBearerToken; + private static readonly DEFAULT_AZURE_IMDS_TOKEN_ENDPOINT : string = 'http://169.254.169.254/metadata/identity/oauth2/token'; + + constructor(bearerAuthCredentials: BearerAuthCredentials, + maxRetries: number, retriesWaitMs: number, + retriesMaxWaitMs: number + ) { + super(bearerAuthCredentials, maxRetries, retriesWaitMs, retriesMaxWaitMs); + this.tokenEndpoint = bearerAuthCredentials.issuerEndpointUrl || + AzureIMDSOAuthClient.DEFAULT_AZURE_IMDS_TOKEN_ENDPOINT; + if (bearerAuthCredentials.issuerEndpointQuery) { + const url = new URL(this.tokenEndpoint); + url.search = bearerAuthCredentials.issuerEndpointQuery; + url.hash = ''; + this.tokenEndpoint = url.toString(); + } + } + + override async fetchToken(): Promise { + const { payload } = await Wreck.get( + this.tokenEndpoint, { + headers: { + Metadata: 'true' + }, + json: 'force', + timeout: 30000 // 30 seconds limit for each request + }); + this.tokenObject = payload; + return this.getAccessTokenString(); + } + + override tokenExpired(): boolean { + if (!this.tokenObject?.expires_in || !this.tokenObject?.expires_on) + return true; + + const expiresIn = +this.tokenObject.expires_in; + let expiresOn = +this.tokenObject.expires_on; + if (isNaN(expiresIn) || isNaN(expiresOn)) + return true; + + const expiryWindow = expiresIn * 1000 * TOKEN_EXPIRATION_THRESHOLD_PERCENTAGE; + expiresOn = expiresOn * 1000; + return expiresOn < Date.now() + expiryWindow; + } + + private getAccessTokenString(): string { + const accessToken = this.tokenObject?.access_token; + + if (typeof accessToken !== 'string') { + throw new Error('Access token is not available'); + } + + return accessToken; + } +} + +// internal/testing usage only +export { + AzureIMDSOAuthClientBuilder as _AzureIMDSOAuthClientBuilder, + AzureIMDSOAuthClient as _AzureIMDSOAuthClient, + AzureIMDSBearerToken as _AzureIMDSBearerToken +} \ No newline at end of file diff --git a/schemaregistry/oauth/oauth-client.ts b/schemaregistry/oauth/oauth-client.ts index 512b749c..3056c745 100644 --- a/schemaregistry/oauth/oauth-client.ts +++ b/schemaregistry/oauth/oauth-client.ts @@ -1,79 +1,101 @@ import { ModuleOptions, ClientCredentials, ClientCredentialTokenConfig, AccessToken } from 'simple-oauth2'; -import { sleep, fullJitter, isRetriable } from '../retry-helper'; -import { isBoom } from '@hapi/boom'; - +import { + _AbstractBearerTokenProviderBuilder as AbstractBearerTokenProviderBuilder, + _AbstractOAuthClient as AbstractOAuthClient, +} from './abstract-oauth-client'; +import { BearerAuthCredentials } from '../rest-service'; +import { + _BearerTokenProvider as BearerTokenProvider +} from './bearer-token-provider'; const TOKEN_EXPIRATION_THRESHOLD_SECONDS = 30 * 60; // 30 minutes -export class OAuthClient { +class OAuthClientBuilder extends AbstractBearerTokenProviderBuilder { + + static readonly requiredFields = [ + 'clientId', + 'clientSecret', + 'issuerEndpointUrl', + 'scope' + ]; + + constructor( + bearerAuthCredentials: BearerAuthCredentials) { + super(bearerAuthCredentials); + } + + protected override validate() { + super.validate(); + const missingField = OAuthClientBuilder.requiredFields.find( + field => !(field in this.bearerAuthCredentials)); + + if (missingField) { + throw new Error(`OAuth credential '${missingField}' not provided`); + } + } + + override build(maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number) : BearerTokenProvider { + this.validate(); + return new OAuthClient( + this.bearerAuthCredentials, + maxRetries, + retriesWaitMs, + retriesMaxWaitMs); + } +} + +class OAuthClient extends AbstractOAuthClient { private client: ClientCredentials; - private token: AccessToken | undefined; + private tokenObject: AccessToken | undefined; private tokenParams: ClientCredentialTokenConfig; - private maxRetries: number; - private retriesWaitMs: number; - private retriesMaxWaitMs: number; - constructor(clientId: string, clientSecret: string, tokenHost: string, tokenPath: string, scope: string, + constructor(bearerAuthCredentials: BearerAuthCredentials, maxRetries: number, retriesWaitMs: number, retriesMaxWaitMs: number ) { + super(bearerAuthCredentials, maxRetries, retriesWaitMs, retriesMaxWaitMs); + + const tokenEndpoint = new URL(bearerAuthCredentials.issuerEndpointUrl!); const clientConfig: ModuleOptions = { client: { - id: clientId, - secret: clientSecret, + id: bearerAuthCredentials.clientId!, + secret: bearerAuthCredentials.clientSecret!, }, auth: { - tokenHost: tokenHost, - tokenPath: tokenPath + tokenHost: tokenEndpoint.origin, + tokenPath: tokenEndpoint.pathname }, options: { credentialsEncodingMode: 'loose' } } - this.tokenParams = { scope }; + this.tokenParams = { scope: bearerAuthCredentials.scope! }; this.client = new ClientCredentials(clientConfig); - - this.maxRetries = maxRetries; - this.retriesWaitMs = retriesWaitMs; - this.retriesMaxWaitMs = retriesMaxWaitMs; } - async getAccessToken(): Promise { - if (!this.token || this.token.expired(TOKEN_EXPIRATION_THRESHOLD_SECONDS)) { - await this.generateAccessToken(); - } - - return this.getAccessTokenString(); + override async fetchToken(): Promise { + this.tokenObject = await this.client.getToken(this.tokenParams); + return this.getAccessTokenString(); } - async generateAccessToken(): Promise { - for (let i = 0; i < this.maxRetries + 1; i++) { - try { - const token = await this.client.getToken(this.tokenParams); - this.token = token; - return; - } catch (error: any) { - if (isBoom(error) && i < this.maxRetries) { - const statusCode = error.output.statusCode; - if (isRetriable(statusCode)) { - const waitTime = fullJitter(this.retriesWaitMs, this.retriesMaxWaitMs, i); - await sleep(waitTime); - continue; - } - } - throw new Error(`Failed to get token from server: ${error}`); - } - } + override tokenExpired(): boolean { + return this.tokenObject === undefined || + this.tokenObject.expired(TOKEN_EXPIRATION_THRESHOLD_SECONDS); } - async getAccessTokenString(): Promise { - const accessToken = this.token?.token?.['access_token']; + private getAccessTokenString(): string { + const accessToken = this.tokenObject?.token?.['access_token']; - if (typeof accessToken === 'string') { - return accessToken; + if (typeof accessToken !== 'string') { + throw new Error('Access token is not available'); } - throw new Error('Access token is not available'); + return accessToken; } } +// internal/testing usage only +export { + OAuthClient as _OAuthClient, + OAuthClientBuilder as _OAuthClientBuilder, +} \ No newline at end of file diff --git a/schemaregistry/oauth/static-token-provider.ts b/schemaregistry/oauth/static-token-provider.ts new file mode 100644 index 00000000..0290364a --- /dev/null +++ b/schemaregistry/oauth/static-token-provider.ts @@ -0,0 +1,52 @@ +import { BearerAuthCredentials } from '../rest-service'; +import { + _AbstractBearerTokenProviderBuilder as AbstractBearerTokenProviderBuilder, + _AbstractOauthTokenProvider as AbstractOauthTokenProvider, +} from './abstract-oauth-client'; +import { + _BearerTokenProvider as BearerTokenProvider +} from './bearer-token-provider'; + + +class StaticTokenProviderBuilder extends AbstractBearerTokenProviderBuilder { + + constructor( + bearerAuthCredentials: BearerAuthCredentials) { + super(bearerAuthCredentials); + } + + protected override validate() { + super.validate(); + if (!this.bearerAuthCredentials.token) { + throw new Error('Bearer token not provided'); + } + } + + override build() : BearerTokenProvider { + this.validate(); + return new StaticTokenProvider(this.bearerAuthCredentials); + } +} + +class StaticTokenProvider extends AbstractOauthTokenProvider { + + private token: string; + + constructor(bearerAuthCredentials: BearerAuthCredentials) { + super(bearerAuthCredentials); + this.token = bearerAuthCredentials.token!; + } + + getAccessToken(): Promise { + return Promise.resolve(this.token); + } + + tokenExpired(): boolean { + return false; + } +} + +// internal/testing usage only +export { + StaticTokenProviderBuilder as _StaticTokenProviderBuilder +} diff --git a/schemaregistry/rest-service.ts b/schemaregistry/rest-service.ts index e183de90..3c9e480a 100644 --- a/schemaregistry/rest-service.ts +++ b/schemaregistry/rest-service.ts @@ -1,8 +1,20 @@ import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse, CreateAxiosDefaults } from 'axios'; -import { OAuthClient } from './oauth/oauth-client'; import { RestError } from './rest-error'; import axiosRetry from "axios-retry"; import { fullJitter, isRetriable, isSuccess } from './retry-helper'; +import { + _BearerTokenProvider as BearerTokenProvider, + _BearerTokenProviderBuilder as BearerTokenProviderBuilder +} from './oauth/bearer-token-provider'; +import { + _StaticTokenProviderBuilder as StaticTokenProviderBuilder +} from './oauth/static-token-provider'; +import { + _OAuthClientBuilder as OAuthClientBuilder, +} from './oauth/oauth-client'; +import { + _AzureIMDSOAuthClientBuilder as AzureIMDSOAuthClientBuilder, +} from './oauth/oauth-client-azure-imds'; /* * Confluent-Schema-Registry-TypeScript - Node.js wrapper for Confluent Schema Registry * @@ -25,9 +37,10 @@ export interface SaslInfo { } export interface BearerAuthCredentials { - credentialsSource: 'STATIC_TOKEN' | 'OAUTHBEARER', + credentialsSource: 'STATIC_TOKEN' | 'OAUTHBEARER' | 'OAUTHBEARER_AZURE_IMDS', token?: string, issuerEndpointUrl?: string, + issuerEndpointQuery?: string, clientId?: string, clientSecret?: string, scope?: string, @@ -53,8 +66,12 @@ const toBase64 = (str: string): string => Buffer.from(str).toString('base64'); export class RestService { private client: AxiosInstance; private baseURLs: string[]; - private oauthClient?: OAuthClient; - private oauthBearer: boolean = false; + private bearerTokenProvider?: BearerTokenProvider; + private static oauthBearerTokenProviderBuilders : Record BearerTokenProviderBuilder> = { + 'STATIC_TOKEN': (credentials) => new StaticTokenProviderBuilder(credentials), + 'OAUTHBEARER': (credentials) => new OAuthClientBuilder(credentials), + 'OAUTHBEARER_AZURE_IMDS': (credentials) => new AzureIMDSOAuthClientBuilder(credentials) + } constructor(baseURLs: string[], isForward?: boolean, axiosDefaults?: CreateAxiosDefaults, basicAuthCredentials?: BasicAuthCredentials, bearerAuthCredentials?: BearerAuthCredentials, @@ -123,39 +140,17 @@ export class RestService { throw new Error(`Bearer auth header '${missingHeader}' not provided`); } - this.setHeaders({ - 'Confluent-Identity-Pool-Id': bearerAuthCredentials.identityPoolId!, - 'target-sr-cluster': bearerAuthCredentials.logicalCluster! - }); - - switch (bearerAuthCredentials.credentialsSource) { - case 'STATIC_TOKEN': - if (!bearerAuthCredentials.token) { - throw new Error('Bearer token not provided'); - } - this.setAuth(undefined, bearerAuthCredentials.token); - break; - case 'OAUTHBEARER': - this.oauthBearer = true; - const requiredFields = [ - 'clientId', - 'clientSecret', - 'issuerEndpointUrl', - 'scope' - ]; - const missingField = requiredFields.find(field => !(field in bearerAuthCredentials)); - - if (missingField) { - throw new Error(`OAuth credential '${missingField}' not provided`); - } - const issuerEndPointUrl = new URL(bearerAuthCredentials.issuerEndpointUrl!); - this.oauthClient = new OAuthClient(bearerAuthCredentials.clientId!, bearerAuthCredentials.clientSecret!, - issuerEndPointUrl.origin, issuerEndPointUrl.pathname, bearerAuthCredentials.scope!, - maxRetries, retriesWaitMs, retriesMaxWaitMs); - break; - default: - throw new Error('Invalid bearer auth credentials source'); + if (!(bearerAuthCredentials.credentialsSource in + RestService.oauthBearerTokenProviderBuilders)) { + throw new Error('Invalid bearer auth credentials source'); } + + this.bearerTokenProvider = RestService.oauthBearerTokenProviderBuilders[ + bearerAuthCredentials.credentialsSource](bearerAuthCredentials).build( + maxRetries, retriesWaitMs, retriesMaxWaitMs + ); + + this.setHeaders(this.bearerTokenProvider!.getAdditionalHeaders()); } } @@ -166,7 +161,7 @@ export class RestService { config?: AxiosRequestConfig, ): Promise> { - if (this.oauthBearer) { + if (this.bearerTokenProvider && this.bearerTokenProvider.tokenExpired()) { await this.setOAuthBearerToken(); } @@ -213,11 +208,7 @@ export class RestService { } async setOAuthBearerToken(): Promise { - if (!this.oauthClient) { - throw new Error('OAuthClient not initialized'); - } - - const bearerToken: string = await this.oauthClient.getAccessToken(); + const bearerToken: string = await this.bearerTokenProvider!.getAccessToken(); this.setAuth(undefined, bearerToken); } diff --git a/schemaregistry/test/oauth-client-azure-imds.spec.ts b/schemaregistry/test/oauth-client-azure-imds.spec.ts new file mode 100644 index 00000000..618002c9 --- /dev/null +++ b/schemaregistry/test/oauth-client-azure-imds.spec.ts @@ -0,0 +1,187 @@ +import { + _AzureIMDSOAuthClient as AzureIMDSOAuthClient, + _AzureIMDSOAuthClientBuilder as AzureIMDSOAuthClientBuilder, + _AzureIMDSBearerToken as AzureIMDSBearerToken, + } from '../oauth/oauth-client-azure-imds'; +import { beforeEach, afterEach, describe, expect, it, jest } from '@jest/globals'; +import Wreck from '@hapi/wreck'; +import * as retryHelper from '@confluentinc/schemaregistry/retry-helper'; +import { maxRetries, retriesWaitMs, retriesMaxWaitMs } from './test-constants'; +import { boomify } from '@hapi/boom'; +import { BearerAuthCredentials } from '../rest-service'; +import Http from 'http'; + +const mockError = boomify(new Error('Error Message'), { statusCode: 429 }); +const mockErrorNonRetry = boomify(new Error('Error Message'), { statusCode: 401 }); + +describe('AzureIMDSOAuthClient', () => { + const tokenEndpoint = 'https://example.com/token'; + const tokenEndpointQuery = 'resource=&api-version=&client_id='; + + let oauthClient: AzureIMDSOAuthClient; + + const res : Http.IncomingMessage = {} as Http.IncomingMessage; + const WreckGetSpy = jest.spyOn(Wreck, 'get'); + const mockToken: AzureIMDSBearerToken = { + access_token: 'mockAccessToken', + expires_in: '3600', + expires_on: (Math.floor(Date.now() / 1000) + 3600).toString(), // 1 hour from now + }; + const mockTokenExpired: AzureIMDSBearerToken = { + access_token: 'mockAccessToken', + expires_in: '3600', + expires_on: (Math.floor(Date.now() / 1000) - 7200).toString(), // 2 hours ago + }; + const basicConfig: BearerAuthCredentials = { + credentialsSource: 'OAUTHBEARER_AZURE_IMDS', + logicalCluster: 'clusterId', + identityPoolId: 'identityPoolId', + }; + + beforeEach(() => { + + const bearerAuthCredentials: BearerAuthCredentials = { + ...basicConfig, + issuerEndpointQuery: tokenEndpointQuery + }; + + oauthClient = new AzureIMDSOAuthClientBuilder( + bearerAuthCredentials, + ).build( + maxRetries, retriesWaitMs, retriesMaxWaitMs + ) as AzureIMDSOAuthClient; + + jest.spyOn(retryHelper, 'isRetriable'); + jest.spyOn(retryHelper, 'fullJitter'); + jest.spyOn(retryHelper, 'sleep'); + + jest.spyOn(oauthClient, 'generateAccessToken'); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should fail when no endpoint or query parameters are provided', async () => { + const bearerAuthCredentialsException: BearerAuthCredentials = { + ...basicConfig + } + expect(() => + new AzureIMDSOAuthClientBuilder( + bearerAuthCredentialsException, + ).build(maxRetries, retriesWaitMs, retriesMaxWaitMs) + ).toThrow(new Error('Missing required configuration property: issuerEndpointQuery')); + }); + + it('should retrieve an access token successfully', async () => { + WreckGetSpy.mockResolvedValueOnce({ payload: mockToken, res }); + + const token = await oauthClient.getAccessToken(); + expect(token).toBe('mockAccessToken'); + expect(WreckGetSpy).toHaveBeenCalledWith( + expect.stringContaining(`http://169.254.169.254/metadata/identity/oauth2/token?${tokenEndpointQuery}`), + expect.any(Object)); + }); + + it('should succeed when an endpoint is provided', async () => { + WreckGetSpy.mockResolvedValueOnce({ payload: mockToken, res }); + const bearerAuthCredentialsEndpoint: BearerAuthCredentials = { + ...basicConfig, + issuerEndpointUrl: tokenEndpoint + } + + await expect(async () => { + const oauthClientEndpoint = new AzureIMDSOAuthClientBuilder( + bearerAuthCredentialsEndpoint, + ).build(maxRetries, retriesWaitMs, retriesMaxWaitMs); + + const token = await oauthClientEndpoint.getAccessToken(); + expect(token).toBe('mockAccessToken'); + expect(WreckGetSpy).toHaveBeenCalledWith( + expect.stringContaining(`${tokenEndpoint}`), + expect.any(Object)); + }).not.toThrow(); + }); + + it('should succeed when both an endpoint and a query are provided', async () => { + WreckGetSpy.mockResolvedValueOnce({ payload: mockToken, res }); + const bearerAuthCredentialsEndpointQuery: BearerAuthCredentials = { + ...basicConfig, + issuerEndpointUrl: tokenEndpoint, + issuerEndpointQuery: tokenEndpointQuery, + } + + await expect(async () => { + const oauthClientEndpoint = new AzureIMDSOAuthClientBuilder( + bearerAuthCredentialsEndpointQuery, + ).build(maxRetries, retriesWaitMs, retriesMaxWaitMs); + + const token = await oauthClientEndpoint.getAccessToken(); + expect(token).toBe('mockAccessToken'); + expect(WreckGetSpy).toHaveBeenCalledWith( + expect.stringContaining(`${tokenEndpoint}?${tokenEndpointQuery}`), + expect.any(Object)); + }).not.toThrow(); + }); + + it('should retry on retriable errors and succeed', async () => { + + // Fail twice with retriable errors, then succeed + WreckGetSpy + .mockRejectedValueOnce(mockError) + .mockRejectedValueOnce(mockError) + .mockResolvedValue({ payload: mockToken, res }); + + const token = await oauthClient.getAccessToken(); + + expect(token).toBe('mockAccessToken'); + expect(retryHelper.fullJitter).toHaveBeenCalledTimes(maxRetries); + expect(retryHelper.fullJitter).toHaveBeenCalledWith(retriesWaitMs, retriesMaxWaitMs, 0); + expect(retryHelper.fullJitter).toHaveBeenCalledWith(retriesWaitMs, retriesMaxWaitMs, 1); + + expect(retryHelper.isRetriable).toHaveBeenCalledTimes(maxRetries); + expect(retryHelper.sleep).toHaveBeenCalledTimes(maxRetries); + }); + + it('should fail immediately on non-retriable errors', async () => { + WreckGetSpy.mockRejectedValueOnce(mockErrorNonRetry); + await expect(oauthClient.getAccessToken()).rejects.toThrowError(); + + expect(retryHelper.isRetriable).toHaveBeenCalledTimes(1); + expect(retryHelper.fullJitter).not.toHaveBeenCalled(); + expect(retryHelper.sleep).not.toHaveBeenCalled(); + }); + + it('should fail after exhausting all retries', async () => { + WreckGetSpy.mockRejectedValue(mockError); + + await expect(oauthClient.getAccessToken()).rejects.toThrowError(); + + + expect(retryHelper.isRetriable).toHaveBeenCalledTimes(maxRetries); + + expect(retryHelper.fullJitter).toHaveBeenCalledTimes(maxRetries); + expect(retryHelper.fullJitter).toHaveBeenCalledWith(retriesWaitMs, retriesMaxWaitMs, 0); + expect(retryHelper.fullJitter).toHaveBeenCalledWith(retriesWaitMs, retriesMaxWaitMs, 1); + expect(retryHelper.sleep).toHaveBeenCalledTimes(maxRetries); + }); + + it('should not refresh token when not expired', async () => { + WreckGetSpy.mockResolvedValueOnce({ payload: mockToken, res }); + + await oauthClient.getAccessToken(); + await oauthClient.getAccessToken(); + + expect(oauthClient.generateAccessToken).toHaveBeenCalledTimes(1); + }); + + it('should refresh token when expired', async () => { + WreckGetSpy.mockResolvedValueOnce({ payload: mockTokenExpired, res }); + WreckGetSpy.mockResolvedValueOnce({ payload: mockToken, res }); + + await oauthClient.getAccessToken(); + await oauthClient.getAccessToken(); + + expect(oauthClient.generateAccessToken).toHaveBeenCalledTimes(2); + }); +}); diff --git a/schemaregistry/test/oauth-client.spec.ts b/schemaregistry/test/oauth-client.spec.ts index 9ac7c990..200db290 100644 --- a/schemaregistry/test/oauth-client.spec.ts +++ b/schemaregistry/test/oauth-client.spec.ts @@ -1,9 +1,10 @@ -import { OAuthClient } from '../oauth/oauth-client'; +import { _OAuthClient as OAuthClient } from '../oauth/oauth-client'; import { ClientCredentials, AccessToken } from 'simple-oauth2'; import { beforeEach, afterEach, describe, expect, it, jest } from '@jest/globals'; import * as retryHelper from '@confluentinc/schemaregistry/retry-helper'; import { maxRetries, retriesWaitMs, retriesMaxWaitMs } from './test-constants'; import { boomify } from '@hapi/boom'; +import { BearerAuthCredentials } from '../rest-service'; jest.mock('simple-oauth2'); @@ -61,12 +62,16 @@ describe('OAuthClient', () => { }; beforeEach(() => { - oauthClient = new OAuthClient( + const bearerAuthCredentials: BearerAuthCredentials = { + credentialsSource: 'OAUTHBEARER', + issuerEndpointUrl: `${tokenHost}${tokenPath}`, clientId, clientSecret, - tokenHost, - tokenPath, - scope, + scope + } + + oauthClient = new OAuthClient( + bearerAuthCredentials, maxRetries, retriesWaitMs, retriesMaxWaitMs diff --git a/types/config.d.ts b/types/config.d.ts index 0df9d404..e59bb99a 100644 --- a/types/config.d.ts +++ b/types/config.d.ts @@ -1,4 +1,4 @@ -// ====== Generated from librdkafka 2.11.1 file CONFIGURATION.md ====== +// ====== Generated from librdkafka dev_oauthbearer_metadata_based file CONFIGURATION.md ====== // Code that generated this is a derivative work of the code from Nam Nguyen // https://gist.github.com/ntgn81/066c2c8ec5b4238f85d1e9168a04e3fb @@ -713,6 +713,13 @@ export interface GlobalConfig { */ "sasl.oauthbearer.assertion.jwt.template.file"?: string; + /** + * Type of metadata-based authentication to use for OAUTHBEARER/OIDC `azure_imds` authenticates using the Azure IMDS endpoint. Sets a default value for `sasl.oauthbearer.token.endpoint.url` if missing. Configuration values specific of chosen authentication type can be passed through `sasl.oauthbearer.config`. + * + * @default none + */ + "sasl.oauthbearer.metadata.authentication.type"?: 'none' | 'azure_imds'; + /** * List of plugin libraries to load (; separated). The library search path is platform dependent (see dlopen(3) for Unix and LoadLibrary() for Windows). If no filename extension is specified the platform-specific extension (such as .dll or .so) will be appended automatically. */ diff --git a/types/errors.d.ts b/types/errors.d.ts index fba2f51a..a1880fd2 100644 --- a/types/errors.d.ts +++ b/types/errors.d.ts @@ -1,4 +1,4 @@ -// ====== Generated from librdkafka 2.11.1 file src-cpp/rdkafkacpp.h ====== +// ====== Generated from librdkafka dev_oauthbearer_metadata_based file src-cpp/rdkafkacpp.h ====== export const CODES: { ERRORS: { /* Internal errors to rdkafka: */ /** Begin internal error codes (**-200**) */