diff --git a/.release-please-manifest.json b/.release-please-manifest.json index f1c1e58..bcd0522 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.5.0" + ".": "0.6.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 678cab7..6294c77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,29 @@ All notable changes to the LaunchDarkly Python AI package will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org). +## [0.6.0](https://github.com/launchdarkly/python-server-sdk-ai/compare/0.5.0...0.6.0) (2024-12-17) + + +### ⚠ BREAKING CHANGES + +* Unify tracking token to use only `TokenUsage` ([#32](https://github.com/launchdarkly/python-server-sdk-ai/issues/32)) +* Change version_key to variation_key ([#29](https://github.com/launchdarkly/python-server-sdk-ai/issues/29)) + +### Features + +* Add `LDAIConfigTracker.get_summary` method ([#31](https://github.com/launchdarkly/python-server-sdk-ai/issues/31)) ([e425b1f](https://github.com/launchdarkly/python-server-sdk-ai/commit/e425b1f9e7bf27ab195b877e62af48012eb601c1)) +* Add `track_error` to mirror `track_success` ([#33](https://github.com/launchdarkly/python-server-sdk-ai/issues/33)) ([404f704](https://github.com/launchdarkly/python-server-sdk-ai/commit/404f704dd38f4fc15c718e3dc1027efbda5f36b6)) + + +### Bug Fixes + +* Unify tracking token to use only `TokenUsage` ([#32](https://github.com/launchdarkly/python-server-sdk-ai/issues/32)) ([80e1845](https://github.com/launchdarkly/python-server-sdk-ai/commit/80e18452a936356937660eabe7a186beae4d17bd)) + + +### Code Refactoring + +* Change version_key to variation_key ([#29](https://github.com/launchdarkly/python-server-sdk-ai/issues/29)) ([fcc720a](https://github.com/launchdarkly/python-server-sdk-ai/commit/fcc720a101c97ccb92fd95509b3e7819d557dde5)) + ## [0.5.0](https://github.com/launchdarkly/python-server-sdk-ai/compare/0.4.0...0.5.0) (2024-12-09) diff --git a/PROVENANCE.md 
b/PROVENANCE.md index 68e6fa2..097c4ab 100644 --- a/PROVENANCE.md +++ b/PROVENANCE.md @@ -10,7 +10,7 @@ To verify SLSA provenance attestations, we recommend using [slsa-verifier](https ``` # Set the version of the library to verify -VERSION=0.5.0 +VERSION=0.6.0 ``` diff --git a/ldai/__init__.py b/ldai/__init__.py index 19dcef3..42c2d87 100644 --- a/ldai/__init__.py +++ b/ldai/__init__.py @@ -1 +1 @@ -__version__ = "0.5.0" # x-release-please-version +__version__ = "0.6.0" # x-release-please-version diff --git a/ldai/client.py b/ldai/client.py index 0cd2d19..6f488f3 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -129,7 +129,7 @@ class LDAIClient: """The LaunchDarkly AI SDK client object.""" def __init__(self, client: LDClient): - self.client = client + self._client = client def config( self, @@ -147,7 +147,7 @@ def config( :param variables: Additional variables for the model configuration. :return: The value of the model configuration along with a tracker used for gathering metrics. """ - variation = self.client.variation(key, context, default_value.to_dict()) + variation = self._client.variation(key, context, default_value.to_dict()) all_variables = {} if variables: @@ -184,8 +184,8 @@ def config( ) tracker = LDAIConfigTracker( - self.client, - variation.get('_ldMeta', {}).get('versionKey', ''), + self._client, + variation.get('_ldMeta', {}).get('variationKey', ''), key, context, ) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 53baa1d..48fd00c 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -15,7 +15,7 @@ def td() -> TestData: 'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}}, 'provider': {'name': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + '_ldMeta': {'enabled': True, 'variationKey': 'abcd'}, }, "green", ) @@ 
-31,7 +31,7 @@ def td() -> TestData: {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, ], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + '_ldMeta': {'enabled': True, 'variationKey': 'abcd'}, }, "green", ) @@ -44,7 +44,7 @@ def td() -> TestData: { 'model': {'name': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}}, 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}! Is your last name {{ldctx.last}}?'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + '_ldMeta': {'enabled': True, 'variationKey': 'abcd'}, } ) .variation_for_all(0) @@ -56,7 +56,7 @@ def td() -> TestData: { 'model': {'name': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}}, 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.user.name}}! Do you work for {{ldctx.org.shortname}}?'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + '_ldMeta': {'enabled': True, 'variationKey': 'abcd'}, } ) .variation_for_all(0) @@ -68,7 +68,7 @@ def td() -> TestData: { 'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.1}}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], - '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, + '_ldMeta': {'enabled': False, 'variationKey': 'abcd'}, } ) .variation_for_all(0) diff --git a/ldai/testing/test_tracker.py b/ldai/testing/test_tracker.py new file mode 100644 index 0000000..3196bfb --- /dev/null +++ b/ldai/testing/test_tracker.py @@ -0,0 +1,306 @@ +from time import sleep +from unittest.mock import MagicMock, call + +import pytest +from ldclient import Config, Context, LDClient +from ldclient.integrations.test_data import TestData + +from ldai.tracker import FeedbackKind, LDAIConfigTracker, TokenUsage + + +@pytest.fixture +def td() -> TestData: + td = TestData.data_source() + td.update( + td.flag('model-config') + .variations( + { + 'model': {'name': 'fakeModel', 'parameters': 
{'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}}, + 'provider': {'name': 'fakeProvider'}, + 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + '_ldMeta': {'enabled': True, 'variationKey': 'abcd'}, + }, + "green", + ) + .variation_for_all(0) + ) + + return td + + +@pytest.fixture +def client(td: TestData) -> LDClient: + config = Config('sdk-key', update_processor_class=td, send_events=False) + client = LDClient(config=config) + client.track = MagicMock() # type: ignore + return client + + +def test_summary_starts_empty(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + assert tracker.get_summary().duration is None + assert tracker.get_summary().feedback is None + assert tracker.get_summary().success is None + assert tracker.get_summary().usage is None + + +def test_tracks_duration(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + tracker.track_duration(100) + + client.track.assert_called_with( # type: ignore + '$ld:ai:duration:total', + context, + {'variationKey': 'variation-key', 'configKey': 'config-key'}, + 100 + ) + + assert tracker.get_summary().duration == 100 + + +def test_tracks_duration_of(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + tracker.track_duration_of(lambda: sleep(0.01)) + + calls = client.track.mock_calls # type: ignore + + assert len(calls) == 1 + assert calls[0].args[0] == '$ld:ai:duration:total' + assert calls[0].args[1] == context + assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'} + assert calls[0].args[3] == pytest.approx(10, rel=10) + + +def test_tracks_duration_of_with_exception(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", 
"config-key", context) + + def sleep_and_throw(): + sleep(0.01) + raise ValueError("Something went wrong") + + try: + tracker.track_duration_of(sleep_and_throw) + assert False, "Should have thrown an exception" + except ValueError: + pass + + calls = client.track.mock_calls # type: ignore + + assert len(calls) == 1 + assert calls[0].args[0] == '$ld:ai:duration:total' + assert calls[0].args[1] == context + assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'} + assert calls[0].args[3] == pytest.approx(10, rel=10) + + +def test_tracks_token_usage(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + tokens = TokenUsage(300, 200, 100) + tracker.track_tokens(tokens) + + calls = [ + call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 300), + call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 200), + call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 100), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().usage == tokens + + +def test_tracks_bedrock_metrics(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + bedrock_result = { + '$metadata': {'httpStatusCode': 200}, + 'usage': { + 'totalTokens': 330, + 'inputTokens': 220, + 'outputTokens': 110, + }, + 'metrics': { + 'latencyMs': 50, + } + } + tracker.track_bedrock_converse_metrics(bedrock_result) + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50), + 
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330), + call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220), + call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().success is True + assert tracker.get_summary().duration == 50 + assert tracker.get_summary().usage == TokenUsage(330, 220, 110) + + +def test_tracks_bedrock_metrics_with_error(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + bedrock_result = { + '$metadata': {'httpStatusCode': 500}, + 'usage': { + 'totalTokens': 330, + 'inputTokens': 220, + 'outputTokens': 110, + }, + 'metrics': { + 'latencyMs': 50, + } + } + tracker.track_bedrock_converse_metrics(bedrock_result) + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50), + call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330), + call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220), + call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().success is False + assert tracker.get_summary().duration == 50 + assert tracker.get_summary().usage == TokenUsage(330, 220, 110) + + +def test_tracks_openai_metrics(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", 
"config-key", context) + + class Result: + def __init__(self): + self.usage = Usage() + + class Usage: + def to_dict(self): + return { + 'total_tokens': 330, + 'prompt_tokens': 220, + 'completion_tokens': 110, + } + + tracker.track_openai_metrics(lambda: Result()) + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330), + call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220), + call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110), + ] + + client.track.assert_has_calls(calls, any_order=False) # type: ignore + + assert tracker.get_summary().usage == TokenUsage(330, 220, 110) + + +def test_tracks_openai_metrics_with_exception(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + def raise_exception(): + raise ValueError("Something went wrong") + + try: + tracker.track_openai_metrics(raise_exception) + assert False, "Should have thrown an exception" + except ValueError: + pass + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + ] + + client.track.assert_has_calls(calls, any_order=False) # type: ignore + + assert tracker.get_summary().usage is None + + +@pytest.mark.parametrize( + "kind,label", + [ + pytest.param(FeedbackKind.Positive, "positive", id="positive"), + pytest.param(FeedbackKind.Negative, "negative", id="negative"), + ], +) +def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str): + context = 
Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + + tracker.track_feedback({'kind': kind}) + + client.track.assert_called_with( # type: ignore + f'$ld:ai:feedback:user:{label}', + context, + {'variationKey': 'variation-key', 'configKey': 'config-key'}, + 1 + ) + assert tracker.get_summary().feedback == {'kind': kind} + + +def test_tracks_success(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + tracker.track_success() + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().success is True + + +def test_tracks_error(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + tracker.track_error() + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().success is False + + +def test_error_overwrites_success(client: LDClient): + context = Context.create('user-key') + tracker = LDAIConfigTracker(client, "variation-key", "config-key", context) + tracker.track_success() + tracker.track_error() + + calls = [ + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + 
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1), + ] + + client.track.assert_has_calls(calls) # type: ignore + + assert tracker.get_summary().success is False diff --git a/ldai/tracker.py b/ldai/tracker.py index d179674..8f3c15c 100644 --- a/ldai/tracker.py +++ b/ldai/tracker.py @@ -1,27 +1,11 @@ import time from dataclasses import dataclass from enum import Enum -from typing import Dict, Union +from typing import Dict, Optional from ldclient import Context, LDClient -@dataclass -class TokenMetrics: - """ - Metrics for token usage in AI operations. - - :param total: Total number of tokens used. - :param input: Number of input tokens. - :param output: Number of output tokens. - """ - - total: int - input: int - output: int # type: ignore - - -@dataclass class FeedbackKind(Enum): """ Types of feedback that can be provided for AI operations. @@ -36,99 +20,42 @@ class TokenUsage: """ Tracks token usage for AI operations. - :param total_tokens: Total number of tokens used. - :param prompt_tokens: Number of tokens in the prompt. - :param completion_tokens: Number of tokens in the completion. - """ - - total_tokens: int - prompt_tokens: int - completion_tokens: int - - def to_metrics(self): - """ - Convert token usage to metrics format. - - :return: Dictionary containing token metrics. - """ - return { - 'total': self['total_tokens'], - 'input': self['prompt_tokens'], - 'output': self['completion_tokens'], - } - - -@dataclass -class LDOpenAIUsage: - """ - LaunchDarkly-specific OpenAI usage tracking. - - :param total_tokens: Total number of tokens used. - :param prompt_tokens: Number of tokens in the prompt. - :param completion_tokens: Number of tokens in the completion. + :param total: Total number of tokens used. + :param input: Number of tokens in the prompt. + :param output: Number of tokens in the completion. 
""" - total_tokens: int - prompt_tokens: int - completion_tokens: int + total: int + input: int + output: int -@dataclass -class OpenAITokenUsage: +class LDAIMetricSummary: """ - Tracks OpenAI-specific token usage. + Summary of metrics which have been tracked. """ - def __init__(self, data: LDOpenAIUsage): - """ - Initialize OpenAI token usage tracking. - - :param data: OpenAI usage data. - """ - self.total_tokens = data.total_tokens - self.prompt_tokens = data.prompt_tokens - self.completion_tokens = data.completion_tokens - - def to_metrics(self) -> TokenMetrics: - """ - Convert OpenAI token usage to metrics format. - - :return: TokenMetrics object containing usage data. - """ - return TokenMetrics( - total=self.total_tokens, - input=self.prompt_tokens, - output=self.completion_tokens, - ) - - -@dataclass -class BedrockTokenUsage: - """ - Tracks AWS Bedrock-specific token usage. - """ + def __init__(self): + self._duration = None + self._success = None + self._feedback = None + self._usage = None - def __init__(self, data: dict): - """ - Initialize Bedrock token usage tracking. + @property + def duration(self) -> Optional[int]: + return self._duration - :param data: Dictionary containing Bedrock usage data. - """ - self.totalTokens = data.get('totalTokens', 0) - self.inputTokens = data.get('inputTokens', 0) - self.outputTokens = data.get('outputTokens', 0) + @property + def success(self) -> Optional[bool]: + return self._success - def to_metrics(self) -> TokenMetrics: - """ - Convert Bedrock token usage to metrics format. + @property + def feedback(self) -> Optional[Dict[str, FeedbackKind]]: + return self._feedback - :return: TokenMetrics object containing usage data. 
- """ - return TokenMetrics( - total=self.totalTokens, - input=self.inputTokens, - output=self.outputTokens, - ) + @property + def usage(self) -> Optional[TokenUsage]: + return self._usage class LDAIConfigTracker: @@ -137,30 +64,31 @@ class LDAIConfigTracker: """ def __init__( - self, ld_client: LDClient, version_key: str, config_key: str, context: Context + self, ld_client: LDClient, variation_key: str, config_key: str, context: Context ): """ Initialize an AI configuration tracker. :param ld_client: LaunchDarkly client instance. - :param version_key: Version key for tracking. + :param variation_key: Variation key for tracking. :param config_key: Configuration key for tracking. :param context: Context for evaluation. """ - self.ld_client = ld_client - self.version_key = version_key - self.config_key = config_key - self.context = context + self._ld_client = ld_client + self._variation_key = variation_key + self._config_key = config_key + self._context = context + self._summary = LDAIMetricSummary() def __get_track_data(self): """ Get tracking data for events. - :return: Dictionary containing version and config keys. + :return: Dictionary containing variation and config keys. """ return { - 'versionKey': self.version_key, - 'configKey': self.config_key, + 'variationKey': self._variation_key, + 'configKey': self._config_key, } def track_duration(self, duration: int) -> None: @@ -169,22 +97,29 @@ def track_duration(self, duration: int) -> None: :param duration: Duration in milliseconds. """ - self.ld_client.track( - '$ld:ai:duration:total', self.context, self.__get_track_data(), duration + self._summary._duration = duration + self._ld_client.track( + '$ld:ai:duration:total', self._context, self.__get_track_data(), duration ) def track_duration_of(self, func): """ Automatically track the duration of an AI operation. + An exception occurring during the execution of the function will still + track the duration. The exception will be re-thrown. 
+ :param func: Function to track. :return: Result of the tracked function. """ start_time = time.time() - result = func() - end_time = time.time() - duration = int((end_time - start_time) * 1000) # duration in milliseconds - self.track_duration(duration) + try: + result = func() + finally: + end_time = time.time() + duration = int((end_time - start_time) * 1000) # duration in milliseconds + self.track_duration(duration) + return result def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None: @@ -193,17 +128,18 @@ def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None: :param feedback: Dictionary containing feedback kind. """ + self._summary._feedback = feedback if feedback['kind'] == FeedbackKind.Positive: - self.ld_client.track( + self._ld_client.track( '$ld:ai:feedback:user:positive', - self.context, + self._context, self.__get_track_data(), 1, ) elif feedback['kind'] == FeedbackKind.Negative: - self.ld_client.track( + self._ld_client.track( '$ld:ai:feedback:user:negative', - self.context, + self._context, self.__get_track_data(), 1, ) @@ -212,26 +148,62 @@ def track_success(self) -> None: """ Track a successful AI generation. """ - self.ld_client.track( - '$ld:ai:generation', self.context, self.__get_track_data(), 1 + self._summary._success = True + self._ld_client.track( + '$ld:ai:generation', self._context, self.__get_track_data(), 1 + ) + self._ld_client.track( + '$ld:ai:generation:success', self._context, self.__get_track_data(), 1 + ) + + def track_error(self) -> None: + """ + Track an unsuccessful AI generation attempt. + """ + self._summary._success = False + self._ld_client.track( + '$ld:ai:generation', self._context, self.__get_track_data(), 1 + ) + self._ld_client.track( + '$ld:ai:generation:error', self._context, self.__get_track_data(), 1 ) def track_openai_metrics(self, func): """ Track OpenAI-specific operations. + This function will track the duration of the operation, the token + usage, and the success or error status. 
+ + If the provided function throws, then this method will also throw. + + In the case the provided function throws, this function will record the + duration and an error. + + A failed operation will not have any token usage data. + :param func: Function to track. :return: Result of the tracked function. """ - result = self.track_duration_of(func) - if result.usage: - self.track_tokens(OpenAITokenUsage(result.usage)) + try: + result = self.track_duration_of(func) + self.track_success() + if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'): + self.track_tokens(_openai_to_token_usage(result.usage.to_dict())) + except Exception: + self.track_error() + raise + return result def track_bedrock_converse_metrics(self, res: dict) -> dict: """ Track AWS Bedrock conversation operations. + + This function will track the duration of the operation, the token + usage, and the success or error status. + :param res: Response dictionary from Bedrock. :return: The original response dictionary. """ @@ -239,39 +211,74 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict: if status_code == 200: self.track_success() elif status_code >= 400: - # Potentially add error tracking in the future. - pass + self.track_error() if res.get('metrics', {}).get('latencyMs'): self.track_duration(res['metrics']['latencyMs']) if res.get('usage'): - self.track_tokens(BedrockTokenUsage(res['usage'])) + self.track_tokens(_bedrock_to_token_usage(res['usage'])) return res - def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None: + def track_tokens(self, tokens: TokenUsage) -> None: """ Track token usage metrics. :param tokens: Token usage data from either custom, OpenAI, or Bedrock sources. 
""" - token_metrics = tokens.to_metrics() - if token_metrics.total > 0: - self.ld_client.track( + self._summary._usage = tokens + if tokens.total > 0: + self._ld_client.track( '$ld:ai:tokens:total', - self.context, + self._context, self.__get_track_data(), - token_metrics.total, + tokens.total, ) - if token_metrics.input > 0: - self.ld_client.track( + if tokens.input > 0: + self._ld_client.track( '$ld:ai:tokens:input', - self.context, + self._context, self.__get_track_data(), - token_metrics.input, + tokens.input, ) - if token_metrics.output > 0: - self.ld_client.track( + if tokens.output > 0: + self._ld_client.track( '$ld:ai:tokens:output', - self.context, + self._context, self.__get_track_data(), - token_metrics.output, + tokens.output, ) + + def get_summary(self) -> LDAIMetricSummary: + """ + Get the current summary of AI metrics. + + :return: Summary of AI metrics. + """ + return self._summary + + +def _bedrock_to_token_usage(data: dict) -> TokenUsage: + """ + Convert a Bedrock usage dictionary to a TokenUsage object. + + :param data: Dictionary containing Bedrock usage data. + :return: TokenUsage object containing usage data. + """ + return TokenUsage( + total=data.get('totalTokens', 0), + input=data.get('inputTokens', 0), + output=data.get('outputTokens', 0), + ) + + +def _openai_to_token_usage(data: dict) -> TokenUsage: + """ + Convert an OpenAI usage dictionary to a TokenUsage object. + + :param data: Dictionary containing OpenAI usage data. + :return: TokenUsage object containing usage data. + """ + return TokenUsage( + total=data.get('total_tokens', 0), + input=data.get('prompt_tokens', 0), + output=data.get('completion_tokens', 0), + ) diff --git a/pyproject.toml b/pyproject.toml index a0a36d7..4225b10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "launchdarkly-server-sdk-ai" -version = "0.5.0" +version = "0.6.0" description = "LaunchDarkly SDK for AI" authors = ["LaunchDarkly "] license = "Apache-2.0"