[Cosmos] Service Request and Response retries (#39396)
* implementation

* Update _retry_utility_async.py

* changelog, versions, fixes

* fixes

* remove fake logic, count fix

* Update _service_request_retry_policy.py

* Update _retry_utility_async.py

* retry utilities fixing

* Update _retry_utility.py

* additional enhancements

* Update setup.py

* Update _retry_utility_async.py

* add tests, remove previous retry logic for ServiceRequestExceptions

* clean up with finally

* tests

* retry utilities

* disable tests

* add logging to policies

* GetDatabaseAccount Fix

* Update _base.py

* retry utilities fixes

* Update _retry_utility.py

* retry utilities part 34

* Update _service_request_retry_policy.py

* remove extra logs

* policy updates

* Update _service_response_retry_policy.py

* Update _service_response_retry_policy.py

* policies updates and update operation types

* trying out fixes

* Update sdk/cosmos/azure-cosmos/CHANGELOG.md

Co-authored-by: Abhijeet Mohanty <[email protected]>

* Update sdk/cosmos/azure-cosmos/CHANGELOG.md

Co-authored-by: Abhijeet Mohanty <[email protected]>

* Skipped proxy test for debugging

* annotation fix

* Fixed some test cases

* test fixes

* Update test_service_retry_policies_async.py

* Fixed some mocking behavior

* fixed pylint issues

* Added aiohttp minimum dependency

* Updated changelog and setup.py

* Updated changelog

---------

Co-authored-by: tvaron3 <[email protected]>
Co-authored-by: Abhijeet Mohanty <[email protected]>
Co-authored-by: Kushagra Thapar <[email protected]>
4 people authored Jan 26, 2025
1 parent 8f3af53 commit db08404
Showing 19 changed files with 861 additions and 70 deletions.
12 changes: 12 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -1,5 +1,17 @@
## Release History

### 4.9.1b2 (2025-01-24)

#### Features Added
* Added new cross-regional retry logic for `ServiceRequestError` and `ServiceResponseError` exceptions. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).

#### Bugs Fixed
* Fixed `KeyError` raised by the location cache when the most preferred location is not present in the cached regions. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).
* Fixed cross-region retries on `CosmosClient` initialization. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).

#### Other Changes
* This release requires aiohttp version 3.10.11 and above. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).

### 4.9.1b1 (2024-12-13)

#### Features Added
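For context, a minimal usage sketch (not part of this diff) of the cross-regional retry behavior described in the 4.9.1b2 entries above: a client configured with `preferred_locations` gives the new `ServiceRequestError`/`ServiceResponseError` retry policies a list of regions to fall back to. The account endpoint, credential, and database/container names below are placeholders.

```python
# Sketch only: placeholder endpoint, credential, and resource names.
from azure.cosmos import CosmosClient

client = CosmosClient(
    url="https://myaccount.documents.azure.com:443/",  # placeholder account endpoint
    credential="<account-key>",                        # placeholder key or token credential
    preferred_locations=["West US 2", "East US"],      # regions the retry policies can fall back to
)

container = client.get_database_client("mydb").get_container_client("mycontainer")

# A read that hits a transport-level failure (e.g. a dropped connection) in the first
# preferred region can now be retried against the next preferred region.
item = container.read_item(item="item-id", partition_key="partition-key-value")
```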
7 changes: 7 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
@@ -116,6 +116,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
path: str,
resource_id: Optional[str],
resource_type: str,
operation_type: str,
options: Mapping[str, Any],
partition_key_range_id: Optional[str] = None,
) -> Dict[str, Any]:
@@ -127,6 +128,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
:param str path:
:param str resource_id:
:param str resource_type:
:param str operation_type:
:param dict options:
:param str partition_key_range_id:
:return: The HTTP request headers.
Expand Down Expand Up @@ -323,6 +325,11 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
if resource_type != 'dbs' and options.get("containerRID"):
headers[http_constants.HttpHeaders.IntendedCollectionRID] = options["containerRID"]

if resource_type == "":
resource_type = http_constants.ResourceType.DatabaseAccount
headers[http_constants.HttpHeaders.ThinClientProxyResourceType] = resource_type
headers[http_constants.HttpHeaders.ThinClientProxyOperationType] = operation_type

return headers


31 changes: 21 additions & 10 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
@@ -2038,7 +2038,8 @@ def PatchItem(
if options is None:
options = {}

headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, options)
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type,
documents._OperationType.Patch, options)
# Patch will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(resource_type, documents._OperationType.Patch)
request_data = {}
@@ -2126,7 +2127,8 @@ def _Batch(
) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]:
initial_headers = self.default_headers.copy()
base._populate_batch_headers(initial_headers)
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", options)
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
documents._OperationType.Batch, options)
request_params = RequestObject("docs", documents._OperationType.Batch)
return cast(
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
@@ -2185,7 +2187,8 @@ def DeleteAllItemsByPartitionKey(
# Specified url to perform background operation to delete all items by partition key
path = '{}{}/{}'.format(path, "operations", "partitionkeydelete")
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", options)
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
"partitionkey", documents._OperationType.Delete, options)
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
_, last_response_headers = self.__Post(
path=path,
@@ -2353,7 +2356,8 @@ def ExecuteStoredProcedure(

path = base.GetPathFromLink(sproc_link)
sproc_id = base.GetResourceIdOrFullNameFromLink(sproc_link)
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs", options)
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs",
documents._OperationType.ExecuteJavaScript, options)

# ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation
request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript)
@@ -2550,7 +2554,8 @@ def GetDatabaseAccount(
if url_connection is None:
url_connection = self.url_connection

headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", {})
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "",
documents._OperationType.Read, {})
request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection)
result, last_response_headers = self.__Get("", request_params, headers, **kwargs)
self.last_response_headers = last_response_headers
@@ -2615,7 +2620,8 @@ def Create(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create,
options)
# Create will use WriteEndpoint since it uses POST operation

request_params = RequestObject(typ, documents._OperationType.Create)
@@ -2659,7 +2665,8 @@ def Upsert(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert,
options)
headers[http_constants.HttpHeaders.IsUpsert] = True

# Upsert will use WriteEndpoint since it uses POST operation
@@ -2703,7 +2710,8 @@ def Replace(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace,
options)
# Replace will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(typ, documents._OperationType.Replace)
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
@@ -2744,7 +2752,7 @@ def Read(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
# Read will use ReadEndpoint since it uses GET operation
request_params = RequestObject(typ, documents._OperationType.Read)
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
@@ -2782,7 +2790,8 @@ def DeleteResource(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete,
options)
# Delete will use WriteEndpoint since it uses DELETE operation
request_params = RequestObject(typ, documents._OperationType.Delete)
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
@@ -3027,6 +3036,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
path,
resource_id,
resource_type,
request_params.operation_type,
options,
partition_key_range_id
)
@@ -3064,6 +3074,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
path,
resource_id,
resource_type,
documents._OperationType.SqlQuery,
options,
partition_key_range_id
)
sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
@@ -24,9 +24,10 @@
"""

import threading

from urllib.parse import urlparse

from azure.core.exceptions import AzureError

from . import _constants as constants
from . import exceptions
from ._location_cache import LocationCache
@@ -134,14 +135,14 @@ def _GetDatabaseAccount(self, **kwargs):
# specified (by creating a locational endpoint) and keep eating the exception
# until we get the database account and return None at the end, if we are not able
# to get that info from any endpoints
except exceptions.CosmosHttpResponseError:
except (exceptions.CosmosHttpResponseError, AzureError):
for location_name in self.PreferredLocations:
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
try:
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
self._database_account_cache = database_account
return database_account
except exceptions.CosmosHttpResponseError:
except (exceptions.CosmosHttpResponseError, AzureError):
pass
raise

17 changes: 8 additions & 9 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
@@ -159,23 +159,22 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement

should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations

if most_preferred_location:
if self.available_read_endpoint_by_locations:
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
# For reads, we can always refresh in background as we can alternate to
# other available read endpoints
return True
else:
if most_preferred_location and most_preferred_location in self.available_read_endpoint_by_locations:
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
# For reads, we can always refresh in background as we can alternate to
# other available read endpoints
return True
else:
return True

if not self.can_use_multiple_write_locations():
if self.is_endpoint_unavailable(self.write_endpoints[0], EndpointOperationType.WriteType):
# Since most preferred write endpoint is unavailable, we can only refresh in background if
# we have an alternate write endpoint
return True
return should_refresh
if most_preferred_location:
if most_preferred_location and most_preferred_location in self.available_write_endpoint_by_locations:
most_preferred_write_endpoint = self.available_write_endpoint_by_locations[most_preferred_location]
if most_preferred_write_endpoint:
should_refresh |= most_preferred_write_endpoint != self.write_endpoints[0]
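A toy illustration (not SDK code) of the `KeyError` fixed above: before this change, the most preferred location was used to index `available_read_endpoint_by_locations` directly, which raised when that region was not among the cached ones; the lookup is now guarded by a membership check. The region names and endpoint below are made up.

```python
# Toy reproduction of the guarded lookup; region names and endpoint are made up.
available_read_endpoint_by_locations = {
    "East US": "https://myaccount-eastus.documents.azure.com:443/",
}
most_preferred_location = "West US 2"  # preferred by the client but not among the cached regions

# Old behavior: available_read_endpoint_by_locations[most_preferred_location] -> KeyError
if most_preferred_location and most_preferred_location in available_read_endpoint_by_locations:
    most_preferred_read_endpoint = available_read_endpoint_by_locations[most_preferred_location]
else:
    most_preferred_read_endpoint = None  # signal an endpoint refresh instead of raising
```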
64 changes: 58 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
@@ -25,10 +25,11 @@
import time
from typing import Optional

from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError
from requests.exceptions import ( # pylint: disable=networking-import-outside-azure-core-transport
ReadTimeout, ConnectTimeout) # pylint: disable=networking-import-outside-azure-core-transport
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import RetryPolicy
from azure.core.pipeline.transport._base import HttpRequest

from . import exceptions
from . import _endpoint_discovery_retry_policy
@@ -38,6 +39,8 @@
from . import _gone_retry_policy
from . import _timeout_failover_retry_policy
from . import _container_recreate_retry_policy
from . import _service_request_retry_policy, _service_response_retry_policy
from .documents import _OperationType
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes


@@ -78,8 +81,14 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy(
client.connection_policy, global_endpoint_manager, *args
)
service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy(
client.connection_policy, global_endpoint_manager, *args,
)
service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy(
client.connection_policy, global_endpoint_manager, *args,
)
# HttpRequest we would need to modify for Container Recreate Retry Policy
request: Optional[HttpRequest] = None
request = None
if args and len(args) > 3:
# Reference HttpRequest instance in args
request = args[3]
@@ -188,6 +197,16 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
if kwargs['timeout'] <= 0:
raise exceptions.CosmosClientTimeoutError()

except ServiceRequestError as e:
_handle_service_request_retries(client, service_request_retry_policy, e, *args)

except ServiceResponseError as e:
if e.exc_type == ReadTimeout:
_handle_service_response_retries(request, client, service_response_retry_policy, e, *args)
elif e.exc_type == ConnectTimeout:
_handle_service_request_retries(client, service_request_retry_policy, e, *args)
else:
raise

def ExecuteFunction(function, *args, **kwargs):
"""Stub method so that it can be used for mocking purposes as well.
@@ -198,6 +217,31 @@ def ExecuteFunction(function, *args, **kwargs):
"""
return function(*args, **kwargs)

def _has_read_retryable_headers(request_headers):
if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)):
return True
return False

def _handle_service_request_retries(client, request_retry_policy, exception, *args):
# we resolve the request endpoint to the next preferred region
# once we are out of preferred regions we stop retrying
retry_policy = request_retry_policy
if not retry_policy.ShouldRetry():
if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
client.session.clear_session_token(client.last_response_headers)
raise exception

def _handle_service_response_retries(request, client, response_retry_policy, exception, *args):
if _has_read_retryable_headers(request.headers):
# we resolve the request endpoint to the next preferred region
# once we are out of preferred regions we stop retrying
retry_policy = response_retry_policy
if not retry_policy.ShouldRetry():
if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
client.session.clear_session_token(client.last_response_headers)
raise exception
else:
raise exception

def _configure_timeout(request: PipelineRequest, absolute: Optional[int], per_request: int) -> None:
if absolute is not None:
@@ -243,7 +287,6 @@ def send(self, request):
start_time = time.time()
try:
_configure_timeout(request, absolute_timeout, per_request_timeout)

response = self.next.send(request)
if self.is_retry(retry_settings, response):
retry_active = self.increment(retry_settings, response=response)
@@ -262,8 +305,17 @@ def send(self, request):
raise
except ServiceRequestError as err:
# the request ran into a socket timeout or failed to establish a new connection
# since request wasn't sent, we retry up to however many connection retries are configured (default 3)
if retry_settings['connect'] > 0:
# since request wasn't sent, raise exception immediately to be dealt with in client retry policies
raise err
except ServiceResponseError as err:
retry_error = err
if err.exc_type == ReadTimeout:
if _has_read_retryable_headers(request.http_request.headers):
# raise exception immediately to be dealt with in client retry policies
raise err
elif err.exc_type == ConnectTimeout:
raise err
if self._is_method_retryable(retry_settings, request.http_request):
retry_active = self.increment(retry_settings, response=request, error=err)
if retry_active:
self.sleep(retry_settings, request.context.transport)
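To summarize the new flow in `_retry_utility.py`, here is a simplified, standalone sketch (not the SDK's actual code) of how the handlers above classify transport errors: `ServiceRequestError` and `ConnectTimeout`-backed `ServiceResponseError` may be retried in the next preferred region because the request never reached the service, while a `ReadTimeout`-backed `ServiceResponseError` is only retried for read operations, since a write may already have been applied. The helper name and its boolean parameter are illustrative, not part of this PR.

```python
# Simplified sketch of the classification used by the handlers above; not SDK code.
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
from requests.exceptions import ConnectTimeout, ReadTimeout


def may_retry_in_next_region(error: Exception, is_read_operation: bool) -> bool:
    """Return True if the error is safe to retry against the next preferred region."""
    if isinstance(error, ServiceRequestError):
        return True  # request was never sent, so retrying cannot duplicate work
    if isinstance(error, ServiceResponseError):
        if error.exc_type == ConnectTimeout:
            return True  # connection was never established
        if error.exc_type == ReadTimeout:
            return is_read_operation  # retrying a timed-out write could apply it twice
    return False  # anything else surfaces to the caller
```

In the actual policies, this decision is paired with `ShouldRetry()` bookkeeping so that retries stop once the preferred regions are exhausted.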