From 9d1ae79dc93a7008fd7673a56b762e998df62370 Mon Sep 17 00:00:00 2001 From: Carla Leal Ramos Date: Thu, 21 May 2026 19:26:01 +0000 Subject: [PATCH] feat: Make BedrockModel._format_request and _convert_non_streaming_to_streaming public --- src/strands/models/bedrock.py | 76 ++++++++++++++++++++++++---- tests/strands/models/test_bedrock.py | 76 +++++++++++++++++++++++++++- 2 files changed, 140 insertions(+), 12 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4cd6f7fbc..eddde3ab7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -234,6 +234,28 @@ def get_config(self) -> BedrockConfig: """ return resolve_config_metadata(self.config, self.config.get("model_id", "")) + def format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format a Bedrock converse stream request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + + Returns: + A Bedrock converse stream request. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return self._format_request(messages, tool_specs, system_prompt_content, tool_choice) + def _format_request( self, messages: Messages, @@ -243,6 +265,9 @@ def _format_request( ) -> dict[str, Any]: """Format a Bedrock converse stream request. + .. deprecated:: + Use :meth:`format_request` instead. This will be removed in September 2026. + Args: messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. @@ -252,6 +277,12 @@ def _format_request( Returns: A Bedrock converse stream request. """ + warnings.warn( + "_format_request is on the deprecation path, use format_request instead. " + "This will be removed in September 2026.", + DeprecationWarning, + stacklevel=2, + ) if not tool_specs: has_tool_content = any( any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages @@ -830,7 +861,9 @@ async def count_tokens( if system_prompt and system_prompt_content is None: system_prompt_content = [{"text": system_prompt}] - request = self._format_request(messages, tool_specs, system_prompt_content) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + request = self._format_request(messages, tool_specs, system_prompt_content) converse_input: dict[str, Any] = {} if "messages" in request: converse_input["messages"] = request["messages"] @@ -852,13 +885,9 @@ async def count_tokens( logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens except Exception as e: - if ( - isinstance(e, ClientError) - and e.response.get("Error", {}).get("Code") == "AccessDeniedException" - ): + if isinstance(e, ClientError) and e.response.get("Error", {}).get("Code") == "AccessDeniedException": logger.warning( - "model_id=<%s> | bedrock:CountTokens permission denied," - " falling back to heuristic estimation: %s", + "model_id=<%s> | bedrock:CountTokens permission denied, falling back to heuristic estimation: %s", model_id, e, ) @@ -964,7 +993,9 @@ def _stream( """ try: logger.debug("formatting request") - request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -988,8 +1019,10 @@ def _stream( else: response = self.client.converse(**request) - for event in self._convert_non_streaming_to_streaming(response): - callback(event) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + for event in self._convert_non_streaming_to_streaming(response): + callback(event) if ( "trace" in response @@ -1044,15 +1077,38 @@ def _stream( callback() logger.debug("finished streaming response from model") + def convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: + """Convert a non-streaming response to the streaming format. + + Args: + response: The non-streaming response from the Bedrock model. + + Returns: + An iterable of response events in the streaming format. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield from self._convert_non_streaming_to_streaming(response) + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. + .. deprecated:: + Use :meth:`convert_non_streaming_to_streaming` instead. This will be removed in September 2026. + Args: response: The non-streaming response from the Bedrock model. Returns: An iterable of response events in the streaming format. """ + warnings.warn( + "_convert_non_streaming_to_streaming is on the deprecation path, " + "use convert_non_streaming_to_streaming instead. " + "This will be removed in September 2026.", + DeprecationWarning, + stacklevel=2, + ) # Yield messageStart event yield {"messageStart": {"role": response["output"]["message"]["role"]}} diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 319b5574f..236098b27 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -24,6 +24,13 @@ from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec +pytestmark = [ + pytest.mark.filterwarnings("ignore:_format_request is on the deprecation path:DeprecationWarning"), + pytest.mark.filterwarnings( + "ignore:_convert_non_streaming_to_streaming is on the deprecation path:DeprecationWarning" + ), +] + FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID @@ -2424,14 +2431,16 @@ def test_tool_choice_supported_no_warning(model, messages, tool_spec, captured_w tool_choice = {"auto": {}} model._format_request(messages, [tool_spec], tool_choice=tool_choice) - assert len(captured_warnings) == 0 + non_deprecation_warnings = [w for w in captured_warnings if not issubclass(w.category, DeprecationWarning)] + assert len(non_deprecation_warnings) == 0 def test_tool_choice_none_no_warning(model, messages, captured_warnings): """Test that None toolChoice doesn't emit warning.""" model._format_request(messages, tool_choice=None) - assert len(captured_warnings) == 0 + non_deprecation_warnings = [w for w in captured_warnings if not issubclass(w.category, DeprecationWarning)] + assert len(non_deprecation_warnings) == 0 def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings): @@ -3620,3 +3629,66 @@ def test_format_request_cache_tools_string_backward_compat(model, messages, mode exp_cache_point = {"cachePoint": {"type": cache_type}} assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_delegates_to_private(model, messages): + """Test that format_request delegates to _format_request.""" + with unittest.mock.patch.object(model, "_format_request", wraps=model._format_request) as mock_private: + result = model.format_request(messages) + mock_private.assert_called_once_with(messages, None, None, None) + assert result == model.format_request(messages) + + +def test_format_request_passes_all_arguments(model, messages): + """Test that format_request passes all arguments to _format_request.""" + tool_specs = [{"name": "test_tool", "description": "A test tool", "inputSchema": {"json": {}}}] + system_prompt_content = [{"text": "system prompt"}] + tool_choice = {"auto": {}} + + with unittest.mock.patch.object(model, "_format_request", wraps=model._format_request) as mock_private: + model.format_request(messages, tool_specs, system_prompt_content, tool_choice) + mock_private.assert_called_once_with(messages, tool_specs, system_prompt_content, tool_choice) + + +def test_convert_non_streaming_to_streaming_delegates_to_private(model): + """Test that convert_non_streaming_to_streaming delegates to _convert_non_streaming_to_streaming.""" + response = { + "output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}}, + "stopReason": "end_turn", + } + with unittest.mock.patch.object( + model, "_convert_non_streaming_to_streaming", wraps=model._convert_non_streaming_to_streaming + ) as mock_private: + result = list(model.convert_non_streaming_to_streaming(response)) + mock_private.assert_called_once_with(response) + assert len(result) > 0 + + +def test_convert_non_streaming_to_streaming_passes_all_arguments(model): + """Test that convert_non_streaming_to_streaming passes the response to _convert_non_streaming_to_streaming.""" + response = { + "output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}}, + "stopReason": "end_turn", + } + with unittest.mock.patch.object( + model, "_convert_non_streaming_to_streaming", wraps=model._convert_non_streaming_to_streaming + ) as mock_private: + list(model.convert_non_streaming_to_streaming(response)) + call_args = mock_private.call_args + assert call_args.args[0] is response + + +def test_format_request_private_emits_deprecation_warning(model, messages): + """Test that _format_request emits a DeprecationWarning when called directly.""" + with pytest.warns(DeprecationWarning, match="_format_request is on the deprecation path"): + model._format_request(messages) + + +def test_convert_non_streaming_to_streaming_private_emits_deprecation_warning(model): + """Test that _convert_non_streaming_to_streaming emits a DeprecationWarning when called directly.""" + response = { + "output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}}, + "stopReason": "end_turn", + } + with pytest.warns(DeprecationWarning, match="_convert_non_streaming_to_streaming is on the deprecation path"): + list(model._convert_non_streaming_to_streaming(response))