From b2ff524a85f0b988c770d8e53dc6d1c0705acb6b Mon Sep 17 00:00:00 2001 From: yaythomas Date: Fri, 28 Nov 2025 12:14:24 -0800 Subject: [PATCH] feat: add WaitForCallbackContext to submitter BREAKING CHANGE: wait_for_callback submitter signature changed from submitter(callback_id: str) to submitter(callback_id: str, context: WaitForCallbackContext) The WaitForCallbackContext provides access to a logger, enabling submitter functions to log operations consistently with other SDK operations like step and wait_for_condition. This change aligns the wait_for_callback API with other context-aware operations in the SDK, improving consistency and extensibility. - Add WaitForCallbackContext type with logger field - Update wait_for_callback_handler to pass context to submitter - Update all callback tests to use new submitter signature - Add test coverage for context parameter validation --- .../context.py | 3 +- .../operation/callback.py | 15 ++++++--- src/aws_durable_execution_sdk_python/types.py | 5 +++ tests/operation/callback_test.py | 31 +++++++++++++++---- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 20c3659..30df7bd 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -47,6 +47,7 @@ BatchResult, LoggerInterface, StepContext, + WaitForCallbackContext, WaitForConditionCheckContext, ) from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol @@ -489,7 +490,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: def wait_for_callback( self, - submitter: Callable[[str], None], + submitter: Callable[[str, WaitForCallbackContext], None], name: str | None = None, config: WaitForCallbackConfig | None = None, ) -> Any: diff --git a/src/aws_durable_execution_sdk_python/operation/callback.py b/src/aws_durable_execution_sdk_python/operation/callback.py index 4fe2a1e..e7bc064 100644 --- a/src/aws_durable_execution_sdk_python/operation/callback.py +++ b/src/aws_durable_execution_sdk_python/operation/callback.py @@ -10,6 +10,7 @@ CallbackOptions, OperationUpdate, ) +from aws_durable_execution_sdk_python.types import WaitForCallbackContext if TYPE_CHECKING: from collections.abc import Callable @@ -23,7 +24,11 @@ CheckpointedResult, ExecutionState, ) - from aws_durable_execution_sdk_python.types import Callback, DurableContext + from aws_durable_execution_sdk_python.types import ( + Callback, + DurableContext, + StepContext, + ) def create_callback_handler( @@ -85,7 +90,7 @@ def create_callback_handler( def wait_for_callback_handler( context: DurableContext, - submitter: Callable[[str], None], + submitter: Callable[[str, WaitForCallbackContext], None], name: str | None = None, config: WaitForCallbackConfig | None = None, ) -> Any: @@ -98,8 +103,10 @@ def wait_for_callback_handler( name=f"{name_with_space}create callback id", config=config ) - def submitter_step(step_context): # noqa: ARG001 - return submitter(callback.callback_id) + def submitter_step(step_context: StepContext): + return submitter( + callback.callback_id, WaitForCallbackContext(logger=step_context.logger) + ) step_config = ( StepConfig( diff --git a/src/aws_durable_execution_sdk_python/types.py b/src/aws_durable_execution_sdk_python/types.py index 9163c86..9181be9 100644 --- a/src/aws_durable_execution_sdk_python/types.py +++ b/src/aws_durable_execution_sdk_python/types.py @@ -57,6 +57,11 @@ class StepContext(OperationContext): pass +@dataclass(frozen=True) +class WaitForCallbackContext(OperationContext): + """Context provided to waitForCallback submitter functions.""" + + @dataclass(frozen=True) class WaitForConditionCheckContext(OperationContext): pass diff --git a/tests/operation/callback_test.py b/tests/operation/callback_test.py index b3a1802..688704e 100644 --- a/tests/operation/callback_test.py +++ b/tests/operation/callback_test.py @@ -303,13 +303,18 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id(): def capture_step_call(func, name, config=None): # Execute the step callable to verify submitter is called correctly step_context = Mock(spec=StepContext) + step_context.logger = Mock() func(step_context) mock_context.step.side_effect = capture_step_call wait_for_callback_handler(mock_context, mock_submitter, "test") - mock_submitter.assert_called_once_with("callback_test_id") + # Verify submitter was called with callback_id and WaitForCallbackContext + assert mock_submitter.call_count == 1 + call_args = mock_submitter.call_args[0] + assert call_args[0] == "callback_test_id" + assert hasattr(call_args[1], "logger") def test_create_callback_handler_with_none_operation_in_result(): @@ -350,6 +355,7 @@ def test_wait_for_callback_handler_with_none_callback_id(): def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) + step_context.logger = Mock() return func(step_context) mock_context.step.side_effect = execute_step @@ -357,7 +363,11 @@ def execute_step(func, name, config=None): result = wait_for_callback_handler(mock_context, mock_submitter, "test") assert result == "result_with_none_id" - mock_submitter.assert_called_once_with(None) + # Verify submitter was called with None callback_id and WaitForCallbackContext + assert mock_submitter.call_count == 1 + call_args = mock_submitter.call_args[0] + assert call_args[0] is None + assert hasattr(call_args[1], "logger") def test_wait_for_callback_handler_with_empty_string_callback_id(): @@ -371,6 +381,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id(): def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) + step_context.logger = Mock() return func(step_context) mock_context.step.side_effect = execute_step @@ -378,7 +389,11 @@ def execute_step(func, name, config=None): result = wait_for_callback_handler(mock_context, mock_submitter, "test") assert result == "result_with_empty_id" - mock_submitter.assert_called_once_with("") + # Verify submitter was called with empty string callback_id and WaitForCallbackContext + assert mock_submitter.call_count == 1 + call_args = mock_submitter.call_args[0] + assert call_args[0] == "" # noqa: PLC1901 - explicitly testing empty string, not just falsey + assert hasattr(call_args[1], "logger") def test_wait_for_callback_handler_with_large_data(): @@ -585,12 +600,13 @@ def test_wait_for_callback_handler_submitter_exception_handling(): mock_callback.result.return_value = "exception_result" mock_context.create_callback.return_value = mock_callback - def failing_submitter(callback_id): + def failing_submitter(callback_id, context): msg = "Submitter failed" raise ValueError(msg) def step_side_effect(func, name, config=None): step_context = Mock(spec=StepContext) + step_context.logger = Mock() func(step_context) mock_context.step.side_effect = step_side_effect @@ -775,12 +791,14 @@ def test_callback_lifecycle_complete_flow(): assert callback_id == "lifecycle_cb123" - def mock_submitter(cb_id): + def mock_submitter(cb_id, context): assert cb_id == "lifecycle_cb123" + assert hasattr(context, "logger") return "submitted" def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) + step_context.logger = Mock() return func(step_context) mock_context.step.side_effect = execute_step @@ -889,7 +907,7 @@ def test_callback_with_complex_submitter(): submission_log = [] - def complex_submitter(callback_id): + def complex_submitter(callback_id, context): submission_log.append(f"received_id: {callback_id}") if callback_id == "complex_cb789": submission_log.append("api_call_success") @@ -901,6 +919,7 @@ def complex_submitter(callback_id): def execute_step(func, name, config): step_context = Mock(spec=StepContext) + step_context.logger = Mock() return func(step_context) mock_context.step.side_effect = execute_step