From b50187bcfaba59bef7e704d2a65e45f564e5c6ee Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 16 Jun 2026 02:20:17 -0500 Subject: [PATCH 1/2] Fix metadata timeout propagation - propagate per-call read_timeout, connection_timeout, and timeout (operation deadline) options across query setup metadata calls (container read, query plan, /pkranges) in sync and async paths - extend test coverage for timeout propagation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 72 +++ .../azure-cosmos/azure/cosmos/_constants.py | 10 +- .../azure/cosmos/_cosmos_client_connection.py | 17 +- .../aio/execution_dispatcher.py | 9 +- .../aio/hybrid_search_aggregator.py | 4 + .../execution_dispatcher.py | 9 +- .../hybrid_search_aggregator.py | 4 + .../azure/cosmos/aio/_container.py | 13 +- .../aio/_cosmos_client_connection_async.py | 19 +- .../azure-cosmos/azure/cosmos/container.py | 11 +- .../tests/test_container_rid_header_unit.py | 68 ++- .../test_metadata_timeout_propagation.py | 207 ++++++++ .../tests/test_timeout_propagation_unit.py | 462 ++++++++++++++++++ 14 files changed, 846 insertions(+), 60 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index b8b1c451ad10..61139283695e 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,6 +7,7 @@ #### Breaking Changes #### Bugs Fixed +* Fixed per-call `read_timeout`, `connection_timeout`, and `timeout` (operation deadline) being dropped on the metadata calls a query makes before its first page. #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 297e62d69d47..9a1517e3dfc8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -124,6 +124,11 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]: options[Constants.Kwargs.READ_TIMEOUT] = kwargs[Constants.Kwargs.READ_TIMEOUT] if Constants.Kwargs.TIMEOUT in kwargs: options[Constants.Kwargs.TIMEOUT] = kwargs[Constants.Kwargs.TIMEOUT] + # Copy (not pop) so connection_timeout stays in kwargs for the page fetch + # and is also placed in options, where the container read, partition-key + # ranges, and query plan calls read it. + if Constants.Kwargs.CONNECTION_TIMEOUT in kwargs: + options[Constants.Kwargs.CONNECTION_TIMEOUT] = kwargs[Constants.Kwargs.CONNECTION_TIMEOUT] options[Constants.OperationStartTime] = time.time() @@ -1082,6 +1087,70 @@ def _build_properties_cache(properties: dict[str, Any], container_link: str) -> "partitionKey": properties.get("partitionKey", None), "container_link": container_link } +# The per-call timeout keys a caller can set on a single request. Listed once +# here so format_pk_range_options and the hybrid-search fetch forward the same +# set. +_PER_CALL_TIMEOUT_OPTION_KEYS: Tuple[str, ...] = ( + Constants.Kwargs.READ_TIMEOUT, + Constants.Kwargs.CONNECTION_TIMEOUT, + Constants.Kwargs.TIMEOUT, +) + +# The operation deadline is checked as elapsed = now - OperationStartTime, and +# OperationStartTime defaults to the current time when it is missing. So timeout +# and OperationStartTime must be carried together onto the metadata setup calls; +# otherwise a setup call measures the deadline from its own start instead of the +# operation's start. This adds OperationStartTime to the three timeout keys above. +_PER_CALL_DEADLINE_OPTION_KEYS: Tuple[str, ...] = _PER_CALL_TIMEOUT_OPTION_KEYS + ( + Constants.OperationStartTime, +) + + +def _carry_per_call_timeout_options(source: Mapping[str, Any], destination: dict[str, Any]) -> None: + """Copy the per-call timeouts and the operation start time from source into destination. + + Copies read_timeout, connection_timeout, timeout, and OperationStartTime. Only + keys present in source are copied, so a timeout the caller did not set stays + absent and the request uses the client default instead of None. + + :param source: The request options to read the timeouts from. + :type source: ~collections.abc.Mapping[str, typing.Any] + :param destination: The options dict to copy the timeouts into. + :type destination: dict[str, typing.Any] + :return: None + :rtype: None + """ + for key in _PER_CALL_DEADLINE_OPTION_KEYS: + if key in source: + destination[key] = source[key] + + +def _copy_per_call_timeouts_to_kwargs( + options: Optional[Mapping[str, Any]], + kwargs: dict[str, Any] +) -> None: + """Copy the per-call timeouts and the operation start time from options into kwargs. + + Moves read_timeout, connection_timeout, timeout, and OperationStartTime from + the request options into the kwargs the request layer reads. A value is copied + only when it is set (not None), so an unset timeout falls back to the client + default instead of None; setdefault keeps any value already in kwargs. + + :param options: The request options to read the timeouts from (may be None or empty). + :type options: ~collections.abc.Mapping[str, typing.Any] or None + :param kwargs: The kwargs dict to copy the timeouts into; mutated in place. + :type kwargs: dict[str, typing.Any] + :return: None + :rtype: None + """ + if not options: + return + for key in _PER_CALL_DEADLINE_OPTION_KEYS: + value = options.get(key) + if value is not None: + kwargs.setdefault(key, value) + + def format_pk_range_options(query_options: Mapping[str, Any]) -> dict[str, Any]: """Formats the partition key range options to be used internally from the query ones. :param dict query_options: The query options being used. @@ -1094,4 +1163,7 @@ def format_pk_range_options(query_options: Mapping[str, Any]) -> dict[str, Any]: pk_range_options[Constants.ContainerRID] = query_options[Constants.ContainerRID] if "excludedLocations" in query_options: pk_range_options["excludedLocations"] = query_options["excludedLocations"] + # Keep the per-call timeouts so the partition-key ranges fetch uses them + # instead of the client default. + _carry_per_call_timeout_options(query_options, pk_range_options) return pk_range_options diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 73ba0649a859..e9d6c0d2f578 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -116,15 +116,17 @@ class _Constants: class Kwargs: """Keyword arguments used in the azure-cosmos package""" + # Whether to retry write operations if they fail. Used either at client level or request level. RETRY_WRITE: Literal["retry_write"] = "retry_write" - """Whether to retry write operations if they fail. Used either at client level or request level.""" EXCLUDED_LOCATIONS: Literal["excludedLocations"] = "excludedLocations" + # Availability strategy config. Used either at client level or request level. AVAILABILITY_STRATEGY: Literal["availabilityStrategy"] = "availabilityStrategy" - """Availability strategy config. Used either at client level or request level""" + # Socket read timeout in seconds. Used either at client level or request level. READ_TIMEOUT: Literal["read_timeout"] = "read_timeout" - """Socket read timeout in seconds. Used either at client level or request level.""" + # Absolute timeout in seconds for the combined HTTP request and response processing. TIMEOUT: Literal["timeout"] = "timeout" - """Absolute timeout in seconds for the combined HTTP request and response processing.""" + # Socket connect (handshake) timeout in seconds. Used either at client level or request level. + CONNECTION_TIMEOUT: Literal["connection_timeout"] = "connection_timeout" class UserAgentFeatureFlags(IntEnum): """ diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index a5d01a7a12c9..4393db5e51b1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3244,18 +3244,11 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma """ if options is None: options = {} - read_timeout = options.get("read_timeout") - if read_timeout is not None: - # we currently have a gap where kwargs are not getting passed correctly down the pipeline. In order to make - # absolute time out work, we are passing read_timeout via kwargs as a temporary fix - kwargs.setdefault("read_timeout", read_timeout) - - operation_start_time = options.get(Constants.OperationStartTime) - if operation_start_time is not None: - kwargs.setdefault(Constants.OperationStartTime, operation_start_time) - timeout = options.get("timeout") - if timeout is not None: - kwargs.setdefault("timeout", timeout) + # Copy the per-call timeouts and the operation start time out of options into + # kwargs, where _Request reads them. A value is copied only when set, so + # an unset timeout falls back to the client/policy default instead of + # None; setdefault keeps any explicit kwarg the caller already placed. + base._copy_per_call_timeouts_to_kwargs(options, kwargs) # Execution context injects this via request options; keep kwargs fallback # for compatibility with call paths that still thread internal values there. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index e380e56a2a8e..ce5ca26d2f93 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -35,6 +35,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes from ..._constants import _Constants as Constants +from ... import _base # pylint: disable=protected-access @@ -67,11 +68,17 @@ def __init__(self, client, resource_link, query, options, fetch_function, async def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" + # read_timeout is forwarded as-is (None when the caller did not set it) to + # keep its existing behavior. It is set before the helper, so the helper's + # setdefault leaves it unchanged and only adds connection_timeout, timeout, + # and OperationStartTime when the caller set them. + query_plan_kwargs = {"read_timeout": self._options.get('read_timeout')} + _base._copy_per_call_timeouts_to_kwargs(self._options, query_plan_kwargs) query_plan = await self._client._GetQueryPlanThroughGateway( query_to_use, self._resource_link, self._options.get('excludedLocations'), - read_timeout=self._options.get('read_timeout') + **query_plan_kwargs ) query_execution_info = _PartitionedQueryExecutionInfo(query_plan) qe_info = getattr(query_execution_info, "_query_execution_info", None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py index e1a1393238cb..e2947c4e3b87 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py @@ -10,6 +10,7 @@ _FULL_TEXT_SCORE_SCOPE_KEY, _FULL_TEXT_SCORE_SCOPE_LOCAL, _FULL_TEXT_SCORE_SCOPE_DEFAULT from azure.cosmos._routing import routing_range from azure.cosmos import exceptions +from azure.cosmos import _base from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ -297,6 +298,9 @@ async def _get_target_partition_key_range(self, target_all_ranges): feed_options = {} if Constants.ContainerRID in self._options: feed_options[Constants.ContainerRID] = self._options[Constants.ContainerRID] + # This path calls _ReadPartitionKeyRanges directly and skips + # format_pk_range_options, so copy the per-call timeouts here too. + _base._carry_per_call_timeout_options(self._options, feed_options) return [item async for item in self._client._ReadPartitionKeyRanges( collection_link=self._resource_link, feed_options=feed_options)] query_ranges = self._partitioned_query_ex_info.get_query_ranges() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 99a22f670e29..aacd538d7a61 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -34,6 +34,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.http_constants import StatusCodes, SubStatusCodes from .._constants import _Constants as Constants +from .. import _base # pylint: disable=protected-access @@ -97,11 +98,17 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" + # read_timeout is forwarded as-is (None when the caller did not set it) to + # keep its existing behavior. It is set before the helper, so the helper's + # setdefault leaves it unchanged and only adds connection_timeout, timeout, + # and OperationStartTime when the caller set them. + query_plan_kwargs = {"read_timeout": self._options.get('read_timeout')} + _base._copy_per_call_timeouts_to_kwargs(self._options, query_plan_kwargs) query_plan = self._client._GetQueryPlanThroughGateway( query_to_use, self._resource_link, self._options.get('excludedLocations'), - read_timeout=self._options.get('read_timeout') + **query_plan_kwargs ) query_execution_info = _PartitionedQueryExecutionInfo(query_plan) qe_info = getattr(query_execution_info, "_query_execution_info", None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py index f85d3eeb60bf..a39b54ad54fb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py @@ -8,6 +8,7 @@ from azure.cosmos._execution_context import document_producer from azure.cosmos._routing import routing_range from azure.cosmos import exceptions +from azure.cosmos import _base from .._constants import _Constants as Constants # pylint: disable=protected-access @@ -454,6 +455,9 @@ def _get_target_partition_key_range(self, target_all_ranges): feed_options = {} if Constants.ContainerRID in self._options: feed_options[Constants.ContainerRID] = self._options[Constants.ContainerRID] + # This path calls _ReadPartitionKeyRanges directly and skips + # format_pk_range_options, so copy the per-call timeouts here too. + _base._carry_per_call_timeout_options(self._options, feed_options) return list(self._client._ReadPartitionKeyRanges( collection_link=self._resource_link, feed_options=feed_options)) query_ranges = self._partitioned_query_ex_info.get_query_ranges() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index b4d760b5c5f4..dc1119710414 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -39,7 +39,8 @@ from .. import _utils as utils from .._availability_strategy_config import _validate_request_hedging_strategy from .._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput, - build_options as _build_options, GenerateGuidId, validate_cache_staleness_value) + build_options as _build_options, _copy_per_call_timeouts_to_kwargs, + GenerateGuidId, validate_cache_staleness_value) from .._change_feed.feed_range_internal import FeedRangeInternalEpk from .._cosmos_responses import CosmosDict, CosmosList, CosmosAsyncItemPaged @@ -102,13 +103,9 @@ async def _get_properties_with_options(self, options: Optional[dict[str, Any]] = if options: if "excludedLocations" in options: kwargs['excluded_locations'] = options['excludedLocations'] - if Constants.OperationStartTime in options: - kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime] - if Constants.Kwargs.TIMEOUT in options: - kwargs[Constants.Kwargs.TIMEOUT] = options[Constants.Kwargs.TIMEOUT] - if Constants.Kwargs.READ_TIMEOUT in options: - kwargs[Constants.Kwargs.READ_TIMEOUT] = options[Constants.Kwargs.READ_TIMEOUT] - + # Forward the per-call timeouts and the operation start time so the + # container read honors them instead of the client/policy default. + _copy_per_call_timeouts_to_kwargs(options, kwargs) return await self._get_properties(**kwargs) async def _get_properties(self, **kwargs: Any) -> dict[str, Any]: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 1f9f2ca87369..0a1a1af77144 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -3038,20 +3038,11 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, if options is None: options = {} - read_timeout = options.get("read_timeout") - if read_timeout is not None: - # we currently have a gap where kwargs are not getting passed correctly down the pipeline. In order to make - # absolute time out work, we are passing read_timeout via kwargs as a temporary fix - kwargs.setdefault("read_timeout", read_timeout) - - operation_start_time = options.get(Constants.OperationStartTime) - if operation_start_time is not None: - # we need to set operation_state in kwargs as thats where it is looked at while sending the request - kwargs.setdefault(Constants.OperationStartTime, operation_start_time) - timeout = options.get("timeout") - if timeout is not None: - # we need to set operation_state in kwargs as that's where it is looked at while sending the request - kwargs.setdefault("timeout", timeout) + # Copy the per-call timeouts and the operation start time out of options into + # kwargs, where _Request reads them. A value is copied only when set, so + # an unset timeout falls back to the client/policy default instead of + # None; setdefault keeps any explicit kwarg the caller already placed. + base._copy_per_call_timeouts_to_kwargs(options, kwargs) # The capture dict can arrive via two upstream paths: # 1. The query execution context puts it into ``options`` (the diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 665ae2e4f869..4972cbe10c2d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -35,7 +35,7 @@ from . import _utils as utils from ._availability_strategy_config import _validate_request_hedging_strategy from ._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput, build_options, - GenerateGuidId, validate_cache_staleness_value) + _copy_per_call_timeouts_to_kwargs, GenerateGuidId, validate_cache_staleness_value) from ._change_feed.feed_range_internal import FeedRangeInternalEpk from ._constants import _Constants as Constants, TimeoutScope from ._cosmos_client_connection import CosmosClientConnection @@ -103,12 +103,9 @@ def _get_properties_with_options(self, options: Optional[dict[str, Any]] = None) if options: if "excludedLocations" in options: kwargs['excluded_locations'] = options['excludedLocations'] - if Constants.OperationStartTime in options: - kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime] - if Constants.Kwargs.TIMEOUT in options: - kwargs[Constants.Kwargs.TIMEOUT] = options[Constants.Kwargs.TIMEOUT] - if Constants.Kwargs.READ_TIMEOUT in options: - kwargs[Constants.Kwargs.READ_TIMEOUT] = options[Constants.Kwargs.READ_TIMEOUT] + # Forward the per-call timeouts and the operation start time so the + # container read honors them instead of the client/policy default. + _copy_per_call_timeouts_to_kwargs(options, kwargs) return self._get_properties(**kwargs) def _get_properties(self, **kwargs: Any) -> dict[str, Any]: diff --git a/sdk/cosmos/azure-cosmos/tests/test_container_rid_header_unit.py b/sdk/cosmos/azure-cosmos/tests/test_container_rid_header_unit.py index 04214d6c3f9a..15000902f0ed 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_container_rid_header_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_container_rid_header_unit.py @@ -35,6 +35,22 @@ ] +def _wire_mock_change_feed_sidecars(kwargs, etag): + """Populate optional internal sidecars used by the /pkranges drain loop. + + The production path wires `_internal_response_status_capture` in + `_synchronized_request` / `_asynchronous_request`. These unit-test mocks call + `_ReadPartitionKeyRanges` directly, so they must seed the same sidecar to + keep `evaluate_drain_page` on its normal path. + """ + status_capture = kwargs.get("_internal_response_status_capture") + if status_capture is not None: + status_capture[0] = http_constants.StatusCodes.NOT_MODIFIED + response_hook = kwargs.get("response_hook") + if response_hook: + response_hook({"etag": etag}, None) + + class CapturingMockClient: """A mock CosmosClientConnection that records the feed_options passed to _ReadPartitionKeyRanges so tests can assert on them.""" @@ -52,10 +68,7 @@ def _ReadPartitionKeyRanges( ): self.captured_feed_options = dict(feed_options) if feed_options else {} self.call_count += 1 - # Invoke the response_hook if provided (the cache uses it to capture etag) - response_hook = kwargs.get("response_hook") - if response_hook: - response_hook({"etag": "test-etag-1"}, None) + _wire_mock_change_feed_sidecars(kwargs, "test-etag-1") return iter(self.partition_key_ranges) @@ -92,6 +105,39 @@ def test_format_pk_range_options_excludedLocations_passes_through(self): assert result["containerRID"] == CONTAINER_RID assert result["excludedLocations"] == ["West US"] + def test_format_pk_range_options_timeouts_pass_through(self): + """The per-call timeouts (read_timeout, connection_timeout, timeout) + must survive sanitization so the partition-key ranges fetch uses the + caller's values instead of the client default.""" + result = _base.format_pk_range_options({ + "containerRID": CONTAINER_RID, + "excludedLocations": ["West US"], + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + "somethingElse": 123, + }) + assert result["containerRID"] == CONTAINER_RID + assert result["excludedLocations"] == ["West US"] + assert result["read_timeout"] == 30 + assert result["connection_timeout"] == 0.5 + assert result["timeout"] == 2 + # Unknown keys are still stripped. + assert "somethingElse" not in result + + def test_format_pk_range_options_timeouts_absent_not_added(self): + """A timeout the caller did not set must stay absent, so the request + uses the client default instead of being given None. Only the + timeout(s) actually present survive.""" + result = _base.format_pk_range_options({ + "containerRID": CONTAINER_RID, + "read_timeout": 30, + }) + assert result["read_timeout"] == 30 + assert "connection_timeout" not in result + assert "timeout" not in result + + # ----- PartitionKeyRangeCache ----- def test_initial_load_passes_containerRID(self): @@ -192,6 +238,7 @@ def __init__(self): def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): self.captured_feed_options = dict(feed_options) if feed_options else {} self.call_count += 1 + _wire_mock_change_feed_sidecars(kwargs, "test-etag-1") # First call: initial load returns original ranges # Second call: incremental returns split ranges with unknown parents, # raising _IncrementalMergeFailed (caught → retry incremental, count 0→1) @@ -235,6 +282,7 @@ def __init__(self): def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): self.captured_feed_options = dict(feed_options) if feed_options else {} self.call_count += 1 + _wire_mock_change_feed_sidecars(kwargs, "test-etag-1") # Call 1: initial load → good ranges # Call 2: incremental → split ranges (unresolvable) → retry (count 0→1) # Call 3: incremental retry → split ranges again (still unresolvable) @@ -314,9 +362,7 @@ def test_full_load_removes_stale_if_none_match_header(self): class HeaderCapturingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): captured_headers.append(dict(kwargs.get('headers', {}))) - response_hook = kwargs.get("response_hook") - if response_hook: - response_hook({"etag": "etag-full"}, None) + _wire_mock_change_feed_sidecars(kwargs, "etag-full") return iter(PARTITION_KEY_RANGES) client = HeaderCapturingClient() @@ -343,9 +389,7 @@ def test_full_load_with_incomplete_ranges_surfaces_503(self): class IncompleteRangesClient: """Returns ranges with a gap — CompleteRoutingMap will return None.""" def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - response_hook = kwargs.get("response_hook") - if response_hook: - response_hook({"etag": "etag-incomplete"}, None) + _wire_mock_change_feed_sidecars(kwargs, "etag-incomplete") # Gap: missing the range covering 3F-7F return iter([ {"id": "0", "minInclusive": "", "maxExclusive": "3F"}, @@ -375,9 +419,7 @@ def test_incremental_fallback_to_full_load_succeeds(self): class FallbackClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count[0] += 1 - response_hook = kwargs.get("response_hook") - if response_hook: - response_hook({"etag": f"etag-{call_count[0]}"}, None) + _wire_mock_change_feed_sidecars(kwargs, f"etag-{call_count[0]}") if call_count[0] == 1: # First call: incremental — return a child whose parent doesn't exist diff --git a/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation.py b/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation.py new file mode 100644 index 000000000000..e3cd800707b6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation.py @@ -0,0 +1,207 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Fault-injection tests for per-call timeout propagation to the metadata +"setup" calls a query makes before its first page: the container read, the +``/pkranges`` partition-key-ranges fetch, and the query-plan fetch. + +They run end to end against a live account (or the emulator), using a recording +transport that captures the ``connection_timeout`` / ``read_timeout`` the SDK +hands to the wire for each request, plus an injected delay to exercise the +operation deadline: + +* **Cold start** -- a per-call ``read_timeout`` / ``connection_timeout`` set on + ``query_items`` reaches the container read, the query plan, and ``/pkranges``, + while the forced-short account probe does not inherit them. +* **Post-split** -- after the routing map is force-refreshed (the ``410 Gone`` + path a partition split triggers), the re-issued ``/pkranges`` fetch still + carries the caller's per-call timeouts. +* **Operation deadline** -- with the query-plan fetch delayed past a tight + ``timeout``, the query raises ``CosmosClientTimeoutError`` during the setup + phase, before the ``/pkranges`` fan-out is issued. + +The recording transport observes ``connection_timeout`` / ``read_timeout`` as +they reach the transport ``send`` call, so a regression that drops a per-call +timeout on a setup call surfaces here as the client default instead of the +value the test set. +""" + +import re +import unittest +from time import sleep +from urllib.parse import urlparse + +import pytest + +import test_config +from azure.cosmos import CosmosClient, exceptions +from azure.cosmos.http_constants import HttpHeaders + +from _fault_injection_transport import FaultInjectionTransport + + +def _classify_request(request): + """Bucket an outgoing request by URL/header so the test can assert on the + metadata setup calls independently of the page fetch.""" + url = request.url or "" + headers = request.headers or {} + if "/pkranges" in url: + return "pkranges" + if "/docs" in url: + flag = headers.get(HttpHeaders.IsQueryPlanRequest) + if flag is not None and str(flag).lower() == "true": + return "query_plan" + return "page_fetch" + if re.search(r"/colls/[^/]+/?$", urlparse(url).path): + return "container_read" + if "/dbs/" not in url: + return "account_probe" + return "other" + + +def _is_query_plan(request): + return _classify_request(request) == "query_plan" + + +class _RecordingFaultTransport(FaultInjectionTransport): + """Records the per-request ``connection_timeout`` / ``read_timeout`` handed + to the transport, and optionally sleeps before a matching request so the + operation deadline can be exercised without depending on real latency.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.records = [] + self._delays = [] + + def add_delay(self, predicate, seconds): + self._delays.append((predicate, seconds)) + + def send(self, request, *, proxies=None, **kwargs): + self.records.append({ + "kind": _classify_request(request), + "url": request.url, + "connection_timeout": kwargs.get("connection_timeout"), + "read_timeout": kwargs.get("read_timeout"), + }) + for predicate, seconds in self._delays: + if predicate(request): + sleep(seconds) + break + return super().send(request, proxies=proxies, **kwargs) + + def records_for(self, kind): + return [r for r in self.records if r["kind"] == kind] + + +@pytest.mark.cosmosEmulator +class TestMetadataTimeoutPropagation(unittest.TestCase): + """End-to-end propagation of per-call timeouts to the metadata setup calls.""" + + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + # A cross-partition aggregate forces the cross-partition pipeline: an + # optimistic attempt, then a query-plan fetch, then a /pkranges fan-out. + QUERY = "SELECT VALUE COUNT(1) FROM c" + + # Per-call values that are deliberately not the client/policy defaults, so a + # match proves the caller's value -- not the default -- reached the wire. + READ_TIMEOUT = 33 + CONNECTION_TIMEOUT = 7 + + def _cold_client(self, transport): + # A fresh client => cold container-properties and routing caches, so the + # setup calls actually go on the wire. + return CosmosClient(self.host, self.master_key, transport=transport) + + def test_cold_start_per_call_timeouts_reach_setup_calls(self): + transport = _RecordingFaultTransport() + client = self._cold_client(transport) + try: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + list(container.query_items( + query=self.QUERY, + enable_cross_partition_query=True, + read_timeout=self.READ_TIMEOUT, + connection_timeout=self.CONNECTION_TIMEOUT, + )) + finally: + client.close() + + for kind in ("container_read", "query_plan", "pkranges"): + recs = transport.records_for(kind) + self.assertTrue(recs, "expected at least one {} request on a cold client".format(kind)) + for r in recs: + self.assertEqual( + r["read_timeout"], self.READ_TIMEOUT, + "{} dropped the per-call read_timeout (got {})".format(kind, r["read_timeout"])) + self.assertEqual( + r["connection_timeout"], self.CONNECTION_TIMEOUT, + "{} dropped the per-call connection_timeout (got {})".format(kind, r["connection_timeout"])) + + # The forced-short failover probe must never inherit a caller's per-call + # values -- that is what keeps a generous read_timeout from slowing + # regional failover. + for r in transport.records_for("account_probe"): + self.assertNotEqual(r["read_timeout"], self.READ_TIMEOUT) + self.assertNotEqual(r["connection_timeout"], self.CONNECTION_TIMEOUT) + + def test_post_split_pkranges_refresh_carries_per_call_timeouts(self): + transport = _RecordingFaultTransport() + client = self._cold_client(transport) + try: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + # Warm the caches with default timeouts. + list(container.query_items(query=self.QUERY, enable_cross_partition_query=True)) + + # Simulate the post-split refresh: clearing the routing map forces the + # next query to re-issue /pkranges, exactly as the 410-Gone path does. + client.client_connection._routing_map_provider.clear_cache() # pylint: disable=protected-access + transport.records.clear() + + list(container.query_items( + query=self.QUERY, + enable_cross_partition_query=True, + read_timeout=44, + connection_timeout=9, + )) + finally: + client.close() + + pkranges = transport.records_for("pkranges") + self.assertTrue(pkranges, "a /pkranges refresh should occur after the routing cache is cleared") + for r in pkranges: + self.assertEqual(r["read_timeout"], 44, + "the post-split /pkranges refresh dropped the per-call read_timeout") + self.assertEqual(r["connection_timeout"], 9, + "the post-split /pkranges refresh dropped the per-call connection_timeout") + + def test_operation_deadline_halts_setup_phase(self): + transport = _RecordingFaultTransport() + # Delay the query-plan fetch past the operation deadline. With timeout=1 + # and a 2s query plan, the setup phase exceeds the budget, so the query + # raises CosmosClientTimeoutError before it issues the /pkranges fan-out. + transport.add_delay(_is_query_plan, 2.0) + client = self._cold_client(transport) + try: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + list(container.query_items(query=self.QUERY, enable_cross_partition_query=True, timeout=1)) + finally: + client.close() + + # The deadline tripped during the setup phase: the /pkranges fan-out + # (which follows the query plan) never went out. + self.assertEqual(transport.records_for("pkranges"), [], + "the operation deadline should halt the query during the setup phase, " + "before the /pkranges fan-out is issued") + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py b/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py new file mode 100644 index 000000000000..fb8203491d6b --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py @@ -0,0 +1,462 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Unit tests for per-call timeout propagation to the container read, query +plan, and partition-key ranges metadata calls. + +These tests are deliberately mock-light and do not touch the network, so they +validate the propagation logic in isolation: + +* ``build_options`` carries ``connection_timeout`` into the options dict. +* ``_carry_per_call_timeout_options`` forwards only the timeouts actually set, + plus the operation start time (``OperationStartTime``) when present. +* ``format_pk_range_options`` carries the timers and the operation start time. +* The query-plan dispatcher (sync + async) forwards ``connection_timeout``, + ``timeout`` and ``OperationStartTime`` only when set -- never as ``None`` -- + so an unset value cannot override the client default in the request layer. +* The container read (``_get_properties_with_options``) copies the timeouts and + the operation start time from options into the kwargs it hands down. +""" + +import unittest + +import pytest + +from azure.cosmos import _base +from azure.cosmos._constants import _Constants as Constants +from azure.cosmos.container import ContainerProxy +from azure.cosmos._execution_context.execution_dispatcher import _ProxyQueryExecutionContext +from azure.cosmos._execution_context.aio.execution_dispatcher import ( + _ProxyQueryExecutionContext as _AsyncProxyQueryExecutionContext, +) +from azure.cosmos._execution_context.hybrid_search_aggregator import _HybridSearchContextAggregator +from azure.cosmos._execution_context.aio.hybrid_search_aggregator import ( + _HybridSearchContextAggregator as _AsyncHybridSearchContextAggregator, +) + + +class _StopBeforePipeline(Exception): + """Raised by the recording client's query-plan stub to short-circuit + ``_create_execution_context_with_query_plan`` right after the gateway call, + before it tries to build a real pipelined execution context.""" + + +class _RecordingQueryPlanClient: + """Minimal stand-in for CosmosClientConnection that records the kwargs the + query-plan dispatcher forwards to ``_GetQueryPlanThroughGateway``.""" + + def __init__(self): + self.captured_kwargs = None + + def _GetQueryPlanThroughGateway(self, query, resource_link, excluded_locations=None, **kwargs): + self.captured_kwargs = dict(kwargs) + raise _StopBeforePipeline() + + +class _AsyncRecordingQueryPlanClient: + """Async counterpart of :class:`_RecordingQueryPlanClient`.""" + + def __init__(self): + self.captured_kwargs = None + + async def _GetQueryPlanThroughGateway(self, query, resource_link, excluded_locations=None, **kwargs): + self.captured_kwargs = dict(kwargs) + raise _StopBeforePipeline() + + +def _noop_fetch(_options): + return [], {} + + +class TestBuildOptionsConnectionTimeout(unittest.TestCase): + """build_options copies connection_timeout into the options dict (like + read_timeout and timeout) while leaving it in kwargs for the page fetch.""" + + def test_connection_timeout_copied_into_options(self): + kwargs = {"connection_timeout": 0.5, "read_timeout": 30, "timeout": 2} + options = _base.build_options(kwargs) + assert options[Constants.Kwargs.CONNECTION_TIMEOUT] == 0.5 + assert options[Constants.Kwargs.READ_TIMEOUT] == 30 + assert options[Constants.Kwargs.TIMEOUT] == 2 + + def test_connection_timeout_stays_in_kwargs(self): + # A copy (not a pop): the page fetch consumes connection_timeout from + # kwargs, so build_options must not remove it. + kwargs = {"connection_timeout": 0.5} + _base.build_options(kwargs) + assert kwargs["connection_timeout"] == 0.5 + + def test_connection_timeout_absent_not_added(self): + options = _base.build_options({}) + assert Constants.Kwargs.CONNECTION_TIMEOUT not in options + + +class TestCarryPerCallTimeoutOptions(unittest.TestCase): + """The shared helper that copies the per-call timeouts into an options dict.""" + + def test_carries_only_present_keys(self): + destination = {} + _base._carry_per_call_timeout_options( + {"read_timeout": 30, "timeout": 2, "unrelated": 9}, destination + ) + assert destination == {"read_timeout": 30, "timeout": 2} + + def test_empty_source_leaves_destination_untouched(self): + destination = {"containerRID": "rid"} + _base._carry_per_call_timeout_options({}, destination) + assert destination == {"containerRID": "rid"} + + def test_keys_are_the_three_per_call_timeouts(self): + assert _base._PER_CALL_TIMEOUT_OPTION_KEYS == ( + Constants.Kwargs.READ_TIMEOUT, + Constants.Kwargs.CONNECTION_TIMEOUT, + Constants.Kwargs.TIMEOUT, + ) + + +def _make_sync_ctx(client, options): + return _ProxyQueryExecutionContext( + client, + "dbs/db/colls/coll", + "SELECT * FROM c", + options, + _noop_fetch, + None, + None, + "docs", + ) + + +def _make_async_ctx(client, options): + return _AsyncProxyQueryExecutionContext( + client, + "dbs/db/colls/coll", + "SELECT * FROM c", + options, + _noop_fetch, + None, + None, + "docs", + ) + + +class TestQueryPlanDispatcherForwarding(unittest.TestCase): + """The sync query-plan dispatcher forwards the per-call timeouts, and does + not pass connection_timeout or timeout as None when the caller left them + unset.""" + + def test_forwards_all_three_when_set(self): + client = _RecordingQueryPlanClient() + ctx = _make_sync_ctx(client, {"read_timeout": 30, "connection_timeout": 0.5, "timeout": 2}) + with pytest.raises(_StopBeforePipeline): + ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs == { + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + } + + def test_omits_connection_timeout_and_timeout_when_unset(self): + # connection_timeout and timeout must not be forwarded as None: they go + # straight to the request as kwargs, where None would override the client + # default. read_timeout is still passed as-is. + client = _RecordingQueryPlanClient() + ctx = _make_sync_ctx(client, {}) + with pytest.raises(_StopBeforePipeline): + ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs == {"read_timeout": None} + assert "connection_timeout" not in client.captured_kwargs + assert "timeout" not in client.captured_kwargs + + def test_forwards_only_connection_timeout_when_only_it_is_set(self): + client = _RecordingQueryPlanClient() + ctx = _make_sync_ctx(client, {"connection_timeout": 0.5}) + with pytest.raises(_StopBeforePipeline): + ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs == {"read_timeout": None, "connection_timeout": 0.5} + + +class TestAsyncQueryPlanDispatcherForwarding(unittest.IsolatedAsyncioTestCase): + """The async query-plan dispatcher has the same contract as the sync one.""" + + async def test_forwards_all_three_when_set(self): + client = _AsyncRecordingQueryPlanClient() + ctx = _make_async_ctx(client, {"read_timeout": 30, "connection_timeout": 0.5, "timeout": 2}) + with pytest.raises(_StopBeforePipeline): + await ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs == { + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + } + + async def test_omits_connection_timeout_and_timeout_when_unset(self): + client = _AsyncRecordingQueryPlanClient() + ctx = _make_async_ctx(client, {}) + with pytest.raises(_StopBeforePipeline): + await ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs == {"read_timeout": None} + + +class TestDeadlineAnchorCarry(unittest.TestCase): + """The carry must move OperationStartTime as well as the three timeouts, so the + /pkranges and query-plan setup calls measure the deadline from the operation's + start instead of their own.""" + + def test_deadline_keys_extend_timeout_keys_with_anchor(self): + assert _base._PER_CALL_DEADLINE_OPTION_KEYS == ( + Constants.Kwargs.READ_TIMEOUT, + Constants.Kwargs.CONNECTION_TIMEOUT, + Constants.Kwargs.TIMEOUT, + Constants.OperationStartTime, + ) + + def test_helper_carries_operation_start_time(self): + destination = {} + _base._carry_per_call_timeout_options( + {Constants.OperationStartTime: 123.0, "read_timeout": 30}, destination + ) + assert destination[Constants.OperationStartTime] == 123.0 + assert destination["read_timeout"] == 30 + + def test_helper_omits_operation_start_time_when_absent(self): + destination = {} + _base._carry_per_call_timeout_options({"timeout": 2}, destination) + assert Constants.OperationStartTime not in destination + + def test_format_pk_range_options_carries_anchor_and_timers(self): + options = { + Constants.ContainerRID: "rid", + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + Constants.OperationStartTime: 123.0, + } + pk = _base.format_pk_range_options(options) + assert pk[Constants.ContainerRID] == "rid" + assert pk["read_timeout"] == 30 + assert pk["connection_timeout"] == 0.5 + assert pk["timeout"] == 2 + assert pk[Constants.OperationStartTime] == 123.0 + + def test_format_pk_range_options_omits_unset(self): + pk = _base.format_pk_range_options({Constants.ContainerRID: "rid"}) + assert Constants.OperationStartTime not in pk + assert "read_timeout" not in pk + assert "timeout" not in pk + + +class TestQueryPlanDeadlineAnchorSync(unittest.TestCase): + """The sync query-plan dispatcher forwards OperationStartTime when set (so the + deadline is measured from the shared start) and omits it when unset (so the + request layer default is not overwritten).""" + + def test_forwards_operation_start_time_when_set(self): + client = _RecordingQueryPlanClient() + ctx = _make_sync_ctx(client, {"timeout": 2, Constants.OperationStartTime: 123.0}) + with pytest.raises(_StopBeforePipeline): + ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs["timeout"] == 2 + assert client.captured_kwargs[Constants.OperationStartTime] == 123.0 + + def test_omits_operation_start_time_when_unset(self): + client = _RecordingQueryPlanClient() + ctx = _make_sync_ctx(client, {"timeout": 2}) + with pytest.raises(_StopBeforePipeline): + ctx._create_execution_context_with_query_plan() + assert Constants.OperationStartTime not in client.captured_kwargs + + +class TestQueryPlanDeadlineAnchorAsync(unittest.IsolatedAsyncioTestCase): + """The async query-plan dispatcher has the same contract as the sync one.""" + + async def test_forwards_operation_start_time_when_set(self): + client = _AsyncRecordingQueryPlanClient() + ctx = _make_async_ctx(client, {"timeout": 2, Constants.OperationStartTime: 123.0}) + with pytest.raises(_StopBeforePipeline): + await ctx._create_execution_context_with_query_plan() + assert client.captured_kwargs["timeout"] == 2 + assert client.captured_kwargs[Constants.OperationStartTime] == 123.0 + + async def test_omits_operation_start_time_when_unset(self): + client = _AsyncRecordingQueryPlanClient() + ctx = _make_async_ctx(client, {"timeout": 2}) + with pytest.raises(_StopBeforePipeline): + await ctx._create_execution_context_with_query_plan() + assert Constants.OperationStartTime not in client.captured_kwargs + + +class TestContainerReadForwarding(unittest.TestCase): + """``_get_properties_with_options`` copies ``connection_timeout`` and + ``OperationStartTime`` (plus ``read_timeout`` / ``timeout`` / + ``excludedLocations``) from the options dict into the kwargs it hands to the + container read. Exercised without a live client by stubbing ``_get_properties`` + to capture the kwargs.""" + + @staticmethod + def _capture(options): + proxy = ContainerProxy.__new__(ContainerProxy) + captured = {} + proxy._get_properties = lambda **kwargs: captured.update(kwargs) or {} + proxy._get_properties_with_options(options) + return captured + + def test_forwards_all_timers_and_anchor(self): + captured = self._capture({ + Constants.Kwargs.CONNECTION_TIMEOUT: 0.5, + Constants.Kwargs.READ_TIMEOUT: 30, + Constants.Kwargs.TIMEOUT: 2, + Constants.OperationStartTime: 123.0, + "excludedLocations": ["West US"], + }) + assert captured[Constants.Kwargs.CONNECTION_TIMEOUT] == 0.5 + assert captured[Constants.Kwargs.READ_TIMEOUT] == 30 + assert captured[Constants.Kwargs.TIMEOUT] == 2 + assert captured[Constants.OperationStartTime] == 123.0 + assert captured["excluded_locations"] == ["West US"] + + def test_omits_connection_timeout_when_unset(self): + captured = self._capture({Constants.Kwargs.READ_TIMEOUT: 30}) + assert Constants.Kwargs.CONNECTION_TIMEOUT not in captured + assert captured[Constants.Kwargs.READ_TIMEOUT] == 30 + + +class _RecordingReadPKRangesClient: + """Captures the ``feed_options`` the hybrid all-ranges path hands to + ``_ReadPartitionKeyRanges``.""" + + def __init__(self): + self.captured_feed_options = None + + def _ReadPartitionKeyRanges(self, collection_link, feed_options=None, **kwargs): # noqa: N802 + self.captured_feed_options = feed_options + return [] + + +class _AsyncRecordingReadPKRangesClient: + """Async counterpart for hybrid all-ranges feed-options capture.""" + + def __init__(self): + self.captured_feed_options = None + + def _ReadPartitionKeyRanges(self, collection_link, feed_options=None, **kwargs): # noqa: N802 + self.captured_feed_options = feed_options + + async def _empty_async_iter(): + if False: + yield None + + return _empty_async_iter() + + +class TestHybridAllRangesCarry(unittest.TestCase): + """The hybrid-search all-ranges ``/pkranges`` fetch builds ``feed_options`` by + hand instead of using ``format_pk_range_options``, so it must still carry the + timeouts and ``OperationStartTime`` through the shared helper.""" + + def _capture_feed_options(self, options): + agg = _HybridSearchContextAggregator.__new__(_HybridSearchContextAggregator) + client = _RecordingReadPKRangesClient() + agg._client = client + agg._resource_link = "dbs/db/colls/coll" + agg._options = options + agg._get_target_partition_key_range(target_all_ranges=True) + return client.captured_feed_options + + def test_all_ranges_feed_options_carries_timers_and_anchor(self): + fo = self._capture_feed_options({ + Constants.ContainerRID: "rid", + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + Constants.OperationStartTime: 123.0, + }) + assert fo[Constants.ContainerRID] == "rid" + assert fo["read_timeout"] == 30 + assert fo["connection_timeout"] == 0.5 + assert fo["timeout"] == 2 + assert fo[Constants.OperationStartTime] == 123.0 + + def test_all_ranges_feed_options_omits_unset_timers(self): + fo = self._capture_feed_options({Constants.ContainerRID: "rid"}) + assert fo[Constants.ContainerRID] == "rid" + assert "read_timeout" not in fo + assert "connection_timeout" not in fo + assert "timeout" not in fo + assert Constants.OperationStartTime not in fo + + +class TestAsyncHybridAllRangesCarry(unittest.IsolatedAsyncioTestCase): + """Async hybrid-search all-ranges `/pkranges` carry has the same contract as sync.""" + + async def _capture_feed_options(self, options): + agg = _AsyncHybridSearchContextAggregator.__new__(_AsyncHybridSearchContextAggregator) + client = _AsyncRecordingReadPKRangesClient() + agg._client = client + agg._resource_link = "dbs/db/colls/coll" + agg._options = options + await agg._get_target_partition_key_range(target_all_ranges=True) + return client.captured_feed_options + + async def test_all_ranges_feed_options_carries_timers_and_anchor(self): + fo = await self._capture_feed_options({ + Constants.ContainerRID: "rid", + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + Constants.OperationStartTime: 123.0, + }) + assert fo[Constants.ContainerRID] == "rid" + assert fo["read_timeout"] == 30 + assert fo["connection_timeout"] == 0.5 + assert fo["timeout"] == 2 + assert fo[Constants.OperationStartTime] == 123.0 + + async def test_all_ranges_feed_options_omits_unset_timers(self): + fo = await self._capture_feed_options({Constants.ContainerRID: "rid"}) + assert fo[Constants.ContainerRID] == "rid" + assert "read_timeout" not in fo + assert "connection_timeout" not in fo + assert "timeout" not in fo + assert Constants.OperationStartTime not in fo + + +class TestCopyPerCallTimeoutsToKwargs(unittest.TestCase): + """The shared helper that copies the per-call timeouts and the operation start time + into kwargs, used by the container read, the page fetch, and the query-plan + dispatcher.""" + + def test_copies_present_values(self): + kwargs = {} + _base._copy_per_call_timeouts_to_kwargs( + {"read_timeout": 30, "connection_timeout": 0.5, "timeout": 2, + Constants.OperationStartTime: 123.0}, kwargs) + assert kwargs == { + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + Constants.OperationStartTime: 123.0, + } + + def test_does_not_copy_none_values(self): + # A present-but-None timer must not be copied: forwarding None would make + # _Request's kwargs.pop(name, default) return None and override the default. + kwargs = {} + _base._copy_per_call_timeouts_to_kwargs({"read_timeout": None, "timeout": 2}, kwargs) + assert "read_timeout" not in kwargs + assert kwargs["timeout"] == 2 + + def test_setdefault_existing_kwarg_wins(self): + kwargs = {"read_timeout": 99} + _base._copy_per_call_timeouts_to_kwargs({"read_timeout": 30}, kwargs) + assert kwargs["read_timeout"] == 99 + + def test_none_or_empty_options_is_noop(self): + kwargs = {} + _base._copy_per_call_timeouts_to_kwargs(None, kwargs) + _base._copy_per_call_timeouts_to_kwargs({}, kwargs) + assert kwargs == {} + + +if __name__ == "__main__": + unittest.main() From d63f52b4a314ba2273ab25fc13aa40065cee9763 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 16 Jun 2026 17:04:07 -0500 Subject: [PATCH 2/2] cleaning up comments --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 25 +-- .../aio/execution_dispatcher.py | 9 +- .../execution_dispatcher.py | 9 +- .../azure/cosmos/aio/_retry_utility_async.py | 12 +- ...test_metadata_timeout_propagation_async.py | 164 ++++++++++++++++++ .../tests/test_retry_utility_deadline_unit.py | 151 ++++++++++++++++ .../tests/test_timeout_propagation_unit.py | 129 ++++++++++++-- 7 files changed, 459 insertions(+), 40 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation_async.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_retry_utility_deadline_unit.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 9a1517e3dfc8..f843bebdcdce 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -1087,39 +1087,40 @@ def _build_properties_cache(properties: dict[str, Any], container_link: str) -> "partitionKey": properties.get("partitionKey", None), "container_link": container_link } -# The per-call timeout keys a caller can set on a single request. Listed once -# here so format_pk_range_options and the hybrid-search fetch forward the same -# set. +# The three per-call timeout keys a caller can set on one request. The deadline +# tuple below adds OperationStartTime to these; the carry helpers iterate that +# 4-key tuple, not this one. _PER_CALL_TIMEOUT_OPTION_KEYS: Tuple[str, ...] = ( Constants.Kwargs.READ_TIMEOUT, Constants.Kwargs.CONNECTION_TIMEOUT, Constants.Kwargs.TIMEOUT, ) -# The operation deadline is checked as elapsed = now - OperationStartTime, and -# OperationStartTime defaults to the current time when it is missing. So timeout -# and OperationStartTime must be carried together onto the metadata setup calls; -# otherwise a setup call measures the deadline from its own start instead of the -# operation's start. This adds OperationStartTime to the three timeout keys above. +# timeout and OperationStartTime must travel together: the deadline is checked as +# now - OperationStartTime, which defaults to now when missing, so a metadata call +# without it would measure from its own start, not the operation's. _PER_CALL_DEADLINE_OPTION_KEYS: Tuple[str, ...] = _PER_CALL_TIMEOUT_OPTION_KEYS + ( Constants.OperationStartTime, ) -def _carry_per_call_timeout_options(source: Mapping[str, Any], destination: dict[str, Any]) -> None: +def _carry_per_call_timeout_options(source: Optional[Mapping[str, Any]], destination: dict[str, Any]) -> None: """Copy the per-call timeouts and the operation start time from source into destination. Copies read_timeout, connection_timeout, timeout, and OperationStartTime. Only keys present in source are copied, so a timeout the caller did not set stays - absent and the request uses the client default instead of None. + absent and the request uses the client default instead of None. A None or empty + source is a no-op. - :param source: The request options to read the timeouts from. - :type source: ~collections.abc.Mapping[str, typing.Any] + :param source: The request options to read the timeouts from (may be None or empty). + :type source: ~collections.abc.Mapping[str, typing.Any] or None :param destination: The options dict to copy the timeouts into. :type destination: dict[str, typing.Any] :return: None :rtype: None """ + if not source: + return for key in _PER_CALL_DEADLINE_OPTION_KEYS: if key in source: destination[key] = source[key] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index ce5ca26d2f93..e485a986535c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -68,11 +68,10 @@ def __init__(self, client, resource_link, query, options, fetch_function, async def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" - # read_timeout is forwarded as-is (None when the caller did not set it) to - # keep its existing behavior. It is set before the helper, so the helper's - # setdefault leaves it unchanged and only adds connection_timeout, timeout, - # and OperationStartTime when the caller set them. - query_plan_kwargs = {"read_timeout": self._options.get('read_timeout')} + # Forward the per-call timeouts and OperationStartTime only when the caller + # set them, so an unset value falls back to the client/policy default + # instead of overriding it with None. + query_plan_kwargs = {} _base._copy_per_call_timeouts_to_kwargs(self._options, query_plan_kwargs) query_plan = await self._client._GetQueryPlanThroughGateway( query_to_use, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index aacd538d7a61..67a6836e4757 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -98,11 +98,10 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" - # read_timeout is forwarded as-is (None when the caller did not set it) to - # keep its existing behavior. It is set before the helper, so the helper's - # setdefault leaves it unchanged and only adds connection_timeout, timeout, - # and OperationStartTime when the caller set them. - query_plan_kwargs = {"read_timeout": self._options.get('read_timeout')} + # Forward the per-call timeouts and OperationStartTime only when the caller + # set them, so an unset value falls back to the client/policy default + # instead of overriding it with None. + query_plan_kwargs = {} _base._copy_per_call_timeouts_to_kwargs(self._options, query_plan_kwargs) query_plan = self._client._GetQueryPlanThroughGateway( query_to_use, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index a6d44a699cb8..1555d846b92e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -140,11 +140,13 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg await _record_success_if_request_not_cancelled(args[0], global_endpoint_manager, pk_range_wrapper) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) - # Check timeout after successful execution - if timeout: - elapsed = time.time() - operation_start_time - if elapsed >= timeout: - raise exceptions.CosmosClientTimeoutError(error=last_error) + # Check the deadline after a successful call. Outside the if/else so it + # also covers the normal request path (if args), matching the sync loop: + # a call that succeeds after the deadline passed must still raise. + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) if not client.last_response_headers: client.last_response_headers = {} diff --git a/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation_async.py b/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation_async.py new file mode 100644 index 000000000000..32acf5a522a8 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_metadata_timeout_propagation_async.py @@ -0,0 +1,164 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Async fault-injection tests for per-call timeout propagation to the metadata +calls a query makes before its first page (container read, ``/pkranges``, query +plan). Async counterpart of ``test_metadata_timeout_propagation.py``, reusing its +request classifier so both suites bucket requests the same way. + +* Cold start -- a per-call ``read_timeout`` / ``connection_timeout`` reaches the + three metadata calls; the forced-short account probe does not inherit them. +* Post-split -- the re-issued ``/pkranges`` fetch still carries them. +* Operation deadline -- a delayed query plan makes the query raise + ``CosmosClientTimeoutError`` before the ``/pkranges`` fan-out goes out. +""" + +import asyncio +import unittest + +import pytest + +import test_config +from azure.cosmos.aio import CosmosClient +from azure.cosmos import exceptions + +from _fault_injection_transport_async import FaultInjectionTransportAsync +# Reuse the sync suite's request classifier so both buckets requests identically. +from test_metadata_timeout_propagation import _classify_request, _is_query_plan + + +class _RecordingFaultTransportAsync(FaultInjectionTransportAsync): + """Records the per-request ``connection_timeout`` / ``read_timeout`` handed + to the transport, and optionally awaits a delay before a matching request so + the operation deadline can be exercised without depending on real latency.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.records = [] + self._delays = [] + + def add_delay(self, predicate, seconds): + self._delays.append((predicate, seconds)) + + async def send(self, request, *, stream=False, proxies=None, **config): + self.records.append({ + "kind": _classify_request(request), + "url": request.url, + "connection_timeout": config.get("connection_timeout"), + "read_timeout": config.get("read_timeout"), + }) + for predicate, seconds in self._delays: + if predicate(request): + await asyncio.sleep(seconds) + break + return await super().send(request, stream=stream, proxies=proxies, **config) + + def records_for(self, kind): + return [r for r in self.records if r["kind"] == kind] + + +@pytest.mark.cosmosEmulator +class TestMetadataTimeoutPropagationAsync(unittest.IsolatedAsyncioTestCase): + """End-to-end propagation of per-call timeouts to the metadata setup calls + on the asynchronous client.""" + + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + # A cross-partition aggregate forces the full pipeline: a query-plan fetch + # then a /pkranges fan-out. The async query_items enables cross-partition on + # its own (no partition key) and does not accept enable_cross_partition_query, + # which would leak to the transport -- so, unlike the sync sibling, we omit it. + QUERY = "SELECT VALUE COUNT(1) FROM c" + + # Per-call values that are deliberately not the client/policy defaults, so a + # match proves the caller's value -- not the default -- reached the wire. + READ_TIMEOUT = 33 + CONNECTION_TIMEOUT = 7 + + def _cold_client(self, transport): + # A fresh client => cold container-properties and routing caches, so the + # setup calls actually go on the wire. + return CosmosClient(self.host, self.master_key, transport=transport) + + async def test_cold_start_per_call_timeouts_reach_setup_calls(self): + transport = _RecordingFaultTransportAsync() + async with self._cold_client(transport) as client: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + _ = [item async for item in container.query_items( + query=self.QUERY, + read_timeout=self.READ_TIMEOUT, + connection_timeout=self.CONNECTION_TIMEOUT, + )] + + for kind in ("container_read", "query_plan", "pkranges"): + recs = transport.records_for(kind) + self.assertTrue(recs, "expected at least one {} request on a cold client".format(kind)) + for r in recs: + self.assertEqual( + r["read_timeout"], self.READ_TIMEOUT, + "{} dropped the per-call read_timeout (got {})".format(kind, r["read_timeout"])) + self.assertEqual( + r["connection_timeout"], self.CONNECTION_TIMEOUT, + "{} dropped the per-call connection_timeout (got {})".format(kind, r["connection_timeout"])) + + # The forced-short failover probe must never inherit a caller's per-call + # values -- that is what keeps a generous read_timeout from slowing + # regional failover. + for r in transport.records_for("account_probe"): + self.assertNotEqual(r["read_timeout"], self.READ_TIMEOUT) + self.assertNotEqual(r["connection_timeout"], self.CONNECTION_TIMEOUT) + + async def test_post_split_pkranges_refresh_carries_per_call_timeouts(self): + transport = _RecordingFaultTransportAsync() + async with self._cold_client(transport) as client: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + # Warm the caches with default timeouts. + _ = [item async for item in container.query_items(query=self.QUERY)] + + # Simulate the post-split refresh: clearing the routing map forces the + # next query to re-issue /pkranges, exactly as the 410-Gone path does. + client.client_connection._routing_map_provider.clear_cache() # pylint: disable=protected-access + transport.records.clear() + + _ = [item async for item in container.query_items( + query=self.QUERY, + read_timeout=44, + connection_timeout=9, + )] + + pkranges = transport.records_for("pkranges") + self.assertTrue(pkranges, "a /pkranges refresh should occur after the routing cache is cleared") + for r in pkranges: + self.assertEqual(r["read_timeout"], 44, + "the post-split /pkranges refresh dropped the per-call read_timeout") + self.assertEqual(r["connection_timeout"], 9, + "the post-split /pkranges refresh dropped the per-call connection_timeout") + + async def test_operation_deadline_halts_setup_phase(self): + transport = _RecordingFaultTransportAsync() + # Delay the query-plan fetch past the operation deadline. With timeout=1 + # and a 2s query plan, the setup phase exceeds the budget, so the query + # raises CosmosClientTimeoutError before it issues the /pkranges fan-out. + transport.add_delay(_is_query_plan, 2.0) + async with self._cold_client(transport) as client: + container = client.get_database_client(self.TEST_DATABASE_ID).get_container_client( + self.TEST_CONTAINER_ID) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + _ = [item async for item in container.query_items( + query=self.QUERY, timeout=1)] + + # The deadline tripped during the setup phase: the /pkranges fan-out + # (which follows the query plan) never went out. + self.assertEqual(transport.records_for("pkranges"), [], + "the operation deadline should halt the query during the setup phase, " + "before the /pkranges fan-out is issued") + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_utility_deadline_unit.py b/sdk/cosmos/azure-cosmos/tests/test_retry_utility_deadline_unit.py new file mode 100644 index 000000000000..08767799e01e --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_utility_deadline_unit.py @@ -0,0 +1,151 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Unit tests for the operation-deadline check in the retry loop +(``_retry_utility.Execute`` / ``aio._retry_utility_async.ExecuteAsync``). + +The deadline check after a successful call must run for the normal request flow +(the ``if args`` path) in both the sync and async clients, so a call that +succeeds after ``timeout`` has already passed raises ``CosmosClientTimeoutError`` +instead of returning a late result. The async loop used to run that check only +on the no-args (change-feed callback) path; these tests check both now match. + +The tests are network-free and deterministic: a fake clock is advanced inside +the mocked ``ExecuteFunction`` so the before-call check passes (elapsed == 0) +while the after-call check sees the overrun, with no wall-clock sleeps. +""" + +import types +import unittest +from unittest import mock + +from azure.cosmos import documents, exceptions +from azure.cosmos import _retry_utility +from azure.cosmos.aio import _retry_utility_async +from azure.cosmos._constants import _Constants as Constants +from azure.cosmos._request_object import RequestObject +from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import ResourceType + + +class _FakeClient: + """Minimal stand-in for CosmosClientConnection sufficient for the retry + loop's policy construction and success-path bookkeeping (uses a real + ConnectionPolicy so every ``connection_policy.*`` access is a real value).""" + + def __init__(self): + self.connection_policy = documents.ConnectionPolicy() + self._container_properties_cache = {} + self.last_response_headers = {} + self._enable_diagnostics_logging = False + self.session = None + + def _UpdateSessionIfRequired(self, *_args, **_kwargs): + pass + + +def _make_gem(*, is_async): + """A global-endpoint-manager mock that keeps the retry loop on its simple + path: no circuit breaker, no per-partition failover, single write location. + Other methods the retry policies call are auto-stubbed by MagicMock; + record_success is awaited on the async path, so it is an AsyncMock there.""" + gem = mock.MagicMock() + gem.is_per_partition_automatic_failover_applicable.return_value = False + gem.is_circuit_breaker_applicable.return_value = False + gem.can_use_multiple_write_locations.return_value = False + if is_async: + gem.record_success = mock.AsyncMock() + return gem + + +def _make_args(client): + """Build the args tuple in the exact shape the real callers pass: + ``(request_params, connection_policy, pipeline_client, request)`` -- so all + the retry-policy constructors receive what they expect.""" + request_params = RequestObject(ResourceType.Document, _OperationType.Read, {}, None) + fake_request = types.SimpleNamespace(method="GET", headers={}, body=None) + return (request_params, client.connection_policy, None, fake_request) + + +class _Clock: + """A tiny mutable clock the mocked ExecuteFunction advances mid-call.""" + + def __init__(self, start): + self.now = start + + def __call__(self): + return self.now + + +_START = 1000.0 +_TIMEOUT = 5.0 + + +def _unused_request_fn(*_args, **_kwargs): + """Sentinel passed as the retry loop's ``function``; never invoked because + ``ExecuteFunction``/``ExecuteFunctionAsync`` is mocked in these tests.""" + return ([], {}) + + +class TestRetryUtilityPostSuccessDeadlineSync(unittest.TestCase): + """Sync ``Execute``: a successful call that overran the deadline raises; + a successful call within the deadline returns normally.""" + + def _run(self, advance_during_call): + clock = _Clock(_START) + client = _FakeClient() + gem = _make_gem(is_async=False) + args = _make_args(client) + + def _mock_execute(_function, *_a, **_k): + clock.now += advance_during_call + return ([], {}) + + kwargs = {Constants.OperationStartTime: _START, "timeout": _TIMEOUT} + with mock.patch.object(_retry_utility.time, "time", clock), \ + mock.patch.object(_retry_utility, "ExecuteFunction", _mock_execute): + return _retry_utility.Execute(client, gem, _unused_request_fn, *args, **kwargs) + + def test_successful_call_past_deadline_raises(self): + # Pre-check sees elapsed == 0 (passes); the call advances the clock past + # the deadline, so the post-success check raises. + with self.assertRaises(exceptions.CosmosClientTimeoutError): + self._run(advance_during_call=_TIMEOUT + 100.0) + + def test_successful_call_within_deadline_returns(self): + # No overrun: the post-success check must not raise on the happy path. + result = self._run(advance_during_call=0.0) + assert result == ([], {}) + + +class TestRetryUtilityPostSuccessDeadlineAsync(unittest.IsolatedAsyncioTestCase): + """Async ``ExecuteAsync`` must match the sync behavior on the ``if args`` + path -- the case the fix restores.""" + + async def _run(self, advance_during_call): + clock = _Clock(_START) + client = _FakeClient() + gem = _make_gem(is_async=True) + args = _make_args(client) + + async def _mock_execute_async(_function, *_a, **_k): + clock.now += advance_during_call + return ([], {}) + + kwargs = {Constants.OperationStartTime: _START, "timeout": _TIMEOUT} + with mock.patch.object(_retry_utility_async.time, "time", clock), \ + mock.patch.object(_retry_utility_async, "ExecuteFunctionAsync", _mock_execute_async): + return await _retry_utility_async.ExecuteAsync(client, gem, _unused_request_fn, *args, **kwargs) + + async def test_successful_call_past_deadline_raises(self): + with self.assertRaises(exceptions.CosmosClientTimeoutError): + await self._run(advance_during_call=_TIMEOUT + 100.0) + + async def test_successful_call_within_deadline_returns(self): + result = await self._run(advance_during_call=0.0) + assert result == ([], {}) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py b/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py index fb8203491d6b..e1cf43ead4c2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_timeout_propagation_unit.py @@ -11,20 +11,24 @@ * ``_carry_per_call_timeout_options`` forwards only the timeouts actually set, plus the operation start time (``OperationStartTime``) when present. * ``format_pk_range_options`` carries the timers and the operation start time. -* The query-plan dispatcher (sync + async) forwards ``connection_timeout``, - ``timeout`` and ``OperationStartTime`` only when set -- never as ``None`` -- - so an unset value cannot override the client default in the request layer. +* The query-plan dispatcher (sync + async) forwards the per-call timeouts and + ``OperationStartTime`` only when set -- never as ``None`` -- so an unset value + cannot override the client/policy default in the request layer. * The container read (``_get_properties_with_options``) copies the timeouts and the operation start time from options into the kwargs it hands down. """ import unittest +from unittest import mock import pytest from azure.cosmos import _base +from azure.cosmos import http_constants from azure.cosmos._constants import _Constants as Constants from azure.cosmos.container import ContainerProxy +from azure.cosmos._cosmos_client_connection import CosmosClientConnection as _SyncCosmosClientConnection +from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection as _AsyncCosmosClientConnection from azure.cosmos._execution_context.execution_dispatcher import _ProxyQueryExecutionContext from azure.cosmos._execution_context.aio.execution_dispatcher import ( _ProxyQueryExecutionContext as _AsyncProxyQueryExecutionContext, @@ -106,6 +110,14 @@ def test_empty_source_leaves_destination_untouched(self): _base._carry_per_call_timeout_options({}, destination) assert destination == {"containerRID": "rid"} + def test_none_source_is_noop(self): + # A None source must be a no-op rather than raising, mirroring the + # options->kwargs helper. Callers today never pass None, so this guards + # against a latent TypeError if a future caller does. + destination = {"containerRID": "rid"} + _base._carry_per_call_timeout_options(None, destination) + assert destination == {"containerRID": "rid"} + def test_keys_are_the_three_per_call_timeouts(self): assert _base._PER_CALL_TIMEOUT_OPTION_KEYS == ( Constants.Kwargs.READ_TIMEOUT, @@ -156,24 +168,22 @@ def test_forwards_all_three_when_set(self): "timeout": 2, } - def test_omits_connection_timeout_and_timeout_when_unset(self): - # connection_timeout and timeout must not be forwarded as None: they go - # straight to the request as kwargs, where None would override the client - # default. read_timeout is still passed as-is. + def test_omits_all_timeouts_when_unset(self): + # No timer is forwarded as None: each goes to the request as a kwarg, where + # None would override the client default. An unset timer is omitted so + # _Request falls back to the policy default. client = _RecordingQueryPlanClient() ctx = _make_sync_ctx(client, {}) with pytest.raises(_StopBeforePipeline): ctx._create_execution_context_with_query_plan() - assert client.captured_kwargs == {"read_timeout": None} - assert "connection_timeout" not in client.captured_kwargs - assert "timeout" not in client.captured_kwargs + assert client.captured_kwargs == {} def test_forwards_only_connection_timeout_when_only_it_is_set(self): client = _RecordingQueryPlanClient() ctx = _make_sync_ctx(client, {"connection_timeout": 0.5}) with pytest.raises(_StopBeforePipeline): ctx._create_execution_context_with_query_plan() - assert client.captured_kwargs == {"read_timeout": None, "connection_timeout": 0.5} + assert client.captured_kwargs == {"connection_timeout": 0.5} class TestAsyncQueryPlanDispatcherForwarding(unittest.IsolatedAsyncioTestCase): @@ -190,12 +200,12 @@ async def test_forwards_all_three_when_set(self): "timeout": 2, } - async def test_omits_connection_timeout_and_timeout_when_unset(self): + async def test_omits_all_timeouts_when_unset(self): client = _AsyncRecordingQueryPlanClient() ctx = _make_async_ctx(client, {}) with pytest.raises(_StopBeforePipeline): await ctx._create_execution_context_with_query_plan() - assert client.captured_kwargs == {"read_timeout": None} + assert client.captured_kwargs == {} class TestDeadlineAnchorCarry(unittest.TestCase): @@ -458,5 +468,98 @@ def test_none_or_empty_options_is_noop(self): assert kwargs == {} +def _make_query_feed_conn(connection_cls, get_fn): + """Build a connection stub so __QueryFeed can reach __Get without network or + header setup. Returns the bound (name-mangled) __QueryFeed to call.""" + conn = connection_cls.__new__(connection_cls) + conn.default_headers = {} + conn.availability_strategy = None + # Sync reads availability_strategy_executor; async reads + # availability_strategy_max_concurrency. Set both so one helper serves both. + conn.availability_strategy_executor = None + conn.availability_strategy_max_concurrency = None + # _UpdateSessionIfRequired runs after __Get on the async ReadFeed branch. + conn._UpdateSessionIfRequired = lambda *_args, **_kwargs: None + # __Get is name-mangled; both sync and async classes are named + # CosmosClientConnection, so the mangled attribute name is identical. + setattr(conn, "_CosmosClientConnection__Get", get_fn) + return getattr(conn, "_CosmosClientConnection__QueryFeed") + + +_QUERY_FEED_OPTIONS = { + "read_timeout": 30, + "connection_timeout": 0.5, + "timeout": 2, + Constants.OperationStartTime: 123.0, +} + + +def _assert_query_feed_timers_lifted(captured_kwargs): + assert captured_kwargs["read_timeout"] == 30 + assert captured_kwargs["connection_timeout"] == 0.5 + assert captured_kwargs["timeout"] == 2 + assert captured_kwargs[Constants.OperationStartTime] == 123.0 + + +class TestSyncQueryFeedLift(unittest.TestCase): + """__QueryFeed is where both the /pkranges fetch and the query plan copy the + per-call timeouts from options into the kwargs _Request reads. This drives + that copy through the ReadFeed branch into __Get, which the helper tests + above do not cover.""" + + def test_queryfeed_lifts_timers_from_options_into_get_kwargs(self): + captured_kwargs = {} + + def _fake_get(_path, _request_params, _headers, **kwargs): + captured_kwargs.update(kwargs) + return {}, {} + + query_feed = _make_query_feed_conn(_SyncCosmosClientConnection, _fake_get) + with mock.patch.object(_base, "GetHeaders", return_value={}), \ + mock.patch.object(_base, "set_session_token_header", return_value=None): + # query=None drives the ReadFeed branch, which ends in __Get. + query_feed( + "dbs/db/colls/coll/pkranges", + http_constants.ResourceType.PartitionKeyRange, + "rid", + lambda _r: [], + lambda _client, body: body, + None, + _QUERY_FEED_OPTIONS, + ) + + _assert_query_feed_timers_lifted(captured_kwargs) + + +class TestAsyncQueryFeedLift(unittest.IsolatedAsyncioTestCase): + """Async __QueryFeed makes the same options-to-kwargs copy as the sync one.""" + + async def test_queryfeed_lifts_timers_from_options_into_get_kwargs(self): + + captured_kwargs = {} + + async def _fake_get(_path, _request_params, _headers, **kwargs): + captured_kwargs.update(kwargs) + return {}, {} + + async def _noop_session_async(*_args, **_kwargs): + return None + + query_feed = _make_query_feed_conn(_AsyncCosmosClientConnection, _fake_get) + with mock.patch.object(_base, "GetHeaders", return_value={}), \ + mock.patch.object(_base, "set_session_token_header_async", new=_noop_session_async): + await query_feed( + "dbs/db/colls/coll/pkranges", + http_constants.ResourceType.PartitionKeyRange, + "rid", + lambda _r: [], + lambda _client, body: body, + None, + _QUERY_FEED_OPTIONS, + ) + + _assert_query_feed_timers_lifted(captured_kwargs) + + if __name__ == "__main__": unittest.main()