Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed bug where customer provided excluded region was not always being honored during certain transient failures. See [PR 43602](https://github.com/Azure/azure-sdk-for-python/pull/43602)

#### Other Changes
* Enhanced logging to ensure when a region is marked unavailable we have the proper context. See [PR 43602](https://github.com/Azure/azure-sdk-for-python/pull/43602)

### 4.14.0 (2025-10-13)
This version and all future versions will require Python 3.9+.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ def ShouldRetry(self, exception): # pylint: disable=unused-argument
self.failover_retry_count += 1

if self.request.location_endpoint_to_route:
context = self.__class__.__name__
if _OperationType.IsReadOnlyOperation(self.request.operation_type):
# Mark current read endpoint as unavailable
self.global_endpoint_manager.mark_endpoint_unavailable_for_read(
self.request.location_endpoint_to_route,
True)
True, context)
else:
self.global_endpoint_manager.mark_endpoint_unavailable_for_write(
self.request.location_endpoint_to_route,
True)
True, context)

# set the refresh_needed flag to ensure that endpoint list is
# refreshed with new writable and readable locations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def _resolve_service_endpoint(
) -> str:
return self.location_cache.resolve_service_endpoint(request)

def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache):
self.location_cache.mark_endpoint_unavailable_for_read(endpoint, refresh_cache)
def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache, context: str):
self.location_cache.mark_endpoint_unavailable_for_read(endpoint, refresh_cache, context)

def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache):
self.location_cache.mark_endpoint_unavailable_for_write(endpoint, refresh_cache)
def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache, context: str):
self.location_cache.mark_endpoint_unavailable_for_write(endpoint, refresh_cache, context)

def get_ordered_write_locations(self):
return self.location_cache.get_ordered_write_locations()
Expand All @@ -96,14 +96,15 @@ def force_refresh_on_startup(self, database_account):
def update_location_cache(self):
self.location_cache.update_location_cache()

def _mark_endpoint_unavailable(self, endpoint: str):
def _mark_endpoint_unavailable(self, endpoint: str, context: str):
"""Marks an endpoint as unavailable for the appropriate operations.
:param str endpoint: The endpoint to mark as unavailable.
:param str context: The context for marking the endpoint as unavailable.
"""
write_endpoints = self.location_cache.get_all_write_endpoints()
self.mark_endpoint_unavailable_for_read(endpoint, False)
self.mark_endpoint_unavailable_for_read(endpoint, False, context)
if endpoint in write_endpoints:
self.mark_endpoint_unavailable_for_write(endpoint, False)
self.mark_endpoint_unavailable_for_write(endpoint, False, context)

def refresh_endpoint_list(self, database_account, **kwargs):
if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms:
Expand Down Expand Up @@ -159,7 +160,7 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
self.location_cache.mark_endpoint_available(locational_endpoint)
return database_account, locational_endpoint
except (exceptions.CosmosHttpResponseError, AzureError):
self._mark_endpoint_unavailable(locational_endpoint)
self._mark_endpoint_unavailable(locational_endpoint, "_GetDatabaseAccount")
raise

def _endpoints_health_check(self, **kwargs):
Expand Down Expand Up @@ -194,7 +195,7 @@ def _endpoints_health_check(self, **kwargs):
success_count += 1
self.location_cache.mark_endpoint_available(endpoint)
except (exceptions.CosmosHttpResponseError, AzureError):
self._mark_endpoint_unavailable(endpoint)
self._mark_endpoint_unavailable(endpoint, "_endpoints_health_check")

finally:
# after the health check for that endpoint setting the timeouts back to their original values
Expand Down
109 changes: 66 additions & 43 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def get_write_regional_routing_context(self):
def get_read_regional_routing_context(self):
return self.get_read_regional_routing_contexts()[0].get_primary()

def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache):
self.mark_endpoint_unavailable(endpoint, EndpointOperationType.ReadType, refresh_cache)
def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache, context="Unknown"):
self.mark_endpoint_unavailable(endpoint, EndpointOperationType.ReadType, refresh_cache, context)

def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache):
self.mark_endpoint_unavailable(endpoint, EndpointOperationType.WriteType, refresh_cache)
def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache, context="Unknown"):
self.mark_endpoint_unavailable(endpoint, EndpointOperationType.WriteType, refresh_cache, context)

def perform_on_database_account_read(self, database_account):
self.update_location_cache(
Expand Down Expand Up @@ -185,42 +185,60 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> list[str
excluded_locations = list(self.connection_policy.ExcludedLocations)
else:
excluded_locations = []
for excluded_location in request.excluded_locations_circuit_breaker:
if excluded_location not in excluded_locations:
excluded_locations.append(excluded_location)

return excluded_locations

def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> list[RegionalRoutingContext]:
# Get configured excluded locations
excluded_locations = self._get_configured_excluded_locations(request)
return self._get_applicable_regional_routing_contexts(
request,
self.get_read_regional_routing_contexts(),
self.account_locations_by_read_endpoints,
self.get_write_regional_routing_contexts()[0] # Fallback to primary write region
)

# If excluded locations were configured, return filtered regional endpoints by excluded locations.
if excluded_locations:
return _get_applicable_regional_routing_contexts(
self.get_read_regional_routing_contexts(),
self.account_locations_by_read_endpoints,
self.get_write_regional_routing_contexts()[0],
excluded_locations,
request.resource_type)
def _get_applicable_write_regional_routing_contexts(self, request: RequestObject) -> list[RegionalRoutingContext]:
return self._get_applicable_regional_routing_contexts(
request,
self.get_write_regional_routing_contexts(),
self.account_locations_by_write_endpoints,
self.default_regional_routing_context # Fallback to default global endpoint
)

# Else, return all regional endpoints
return self.get_read_regional_routing_contexts()
def _get_applicable_regional_routing_contexts(
self,
request: RequestObject,
regional_routing_contexts: list[RegionalRoutingContext],
location_name_by_endpoint: Mapping[str, str],
fallback_regional_routing_context: RegionalRoutingContext
) -> list[RegionalRoutingContext]:
user_excluded_locations = self._get_configured_excluded_locations(request)
circuit_breaker_excluded_locations = request.excluded_locations_circuit_breaker or []

if not user_excluded_locations and not circuit_breaker_excluded_locations:
return regional_routing_contexts

applicable_contexts = []
last_resort_contexts = []

for context in regional_routing_contexts:
location = location_name_by_endpoint.get(context.get_primary())
if location in user_excluded_locations:
# For metadata calls, user-excluded locations are added at the end as a last resort.
if base.IsMasterResource(request.resource_type):
last_resort_contexts.append(context)
elif location in circuit_breaker_excluded_locations:
last_resort_contexts.append(context)
else:
applicable_contexts.append(context)

def _get_applicable_write_regional_routing_contexts(self, request: RequestObject) -> list[RegionalRoutingContext]:
# Get configured excluded locations
excluded_locations = self._get_configured_excluded_locations(request)
# Append circuit-breaker-excluded locations (and user-excluded for metadata) to be used as a last resort.
applicable_contexts.extend(last_resort_contexts)

# If excluded locations were configured, return filtered regional endpoints by excluded locations.
if excluded_locations:
return _get_applicable_regional_routing_contexts(
self.get_write_regional_routing_contexts(),
self.account_locations_by_write_endpoints,
self.default_regional_routing_context,
excluded_locations,
request.resource_type)
# If all preferred locations are excluded, use the fallback endpoint.
if not applicable_contexts:
applicable_contexts.append(fallback_regional_routing_context)

# Else, return all regional endpoints
return self.get_write_regional_routing_contexts()
return applicable_contexts

def resolve_service_endpoint(self, request):
if request.location_endpoint_to_route:
Expand All @@ -238,14 +256,17 @@ def resolve_service_endpoint(self, request):
# For non-document resource types in case of client can use multiple write locations
# or when client cannot use multiple write locations, flip-flop between the
# first and the second writable region in DatabaseAccount (for manual failover)
if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations:
location_index = min(location_index % 2, len(self.account_write_locations) - 1)
write_location = self.account_write_locations[location_index]
if (self.account_write_regional_routing_contexts_by_location
and write_location in self.account_write_regional_routing_contexts_by_location):
write_regional_routing_context = (
self.account_write_regional_routing_contexts_by_location)[write_location]
return write_regional_routing_context.get_primary()
if self.connection_policy.EnableEndpointDiscovery:
# Get the list of applicable write locations, which respects excluded locations.
applicable_contexts = self._get_applicable_write_regional_routing_contexts(request)
if not applicable_contexts:
# if all write locations are excluded, fall back to the default endpoint
return self.default_regional_routing_context.get_primary()

# For single-master writes, flip-flop between the first and second *applicable*
# regions for manual failover.
index = min(location_index % 2, len(applicable_contexts) - 1)
return applicable_contexts[index].get_primary()
# if endpoint discovery is off for reads it should use passed in endpoint
return self.default_regional_routing_context.get_primary()

Expand Down Expand Up @@ -317,10 +338,12 @@ def is_endpoint_unavailable(self, endpoint: str, expected_available_operation: s
return True

def mark_endpoint_unavailable(
self, unavailable_endpoint: str, unavailable_operation_type: EndpointOperationType, refresh_cache: bool):
logger.warning("Marking %s unavailable for %s ",
self, unavailable_endpoint: str, unavailable_operation_type: EndpointOperationType, refresh_cache: bool,
context: str):
logger.warning("Marking %s unavailable for %s. Source: %s",
unavailable_endpoint,
unavailable_operation_type)
unavailable_operation_type,
context)
unavailability_info = (
self.location_unavailability_info_by_endpoint[unavailable_endpoint]
if unavailable_endpoint in self.location_unavailability_info_by_endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper,
self.failover_retry_count = 0
self.connection_policy = connection_policy
self.request = args[0] if args else None

if self.request:
if _OperationType.IsReadOnlyOperation(self.request.operation_type):
self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts)
self.total_retries = len(
self.global_endpoint_manager.location_cache._get_applicable_read_regional_routing_contexts(
self.request))
else:
self.total_retries = len(self.global_endpoint_manager.location_cache.write_regional_routing_contexts)
self.total_retries = len(
self.global_endpoint_manager.location_cache._get_applicable_write_regional_routing_contexts(
self.request))


def ShouldRetry(self): # pylint: disable=too-many-return-statements
"""Returns true if the request should retry based on preferred regions and retries already done.
Expand Down Expand Up @@ -72,10 +78,11 @@ def resolve_next_region_service_endpoint(self):
return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)

def mark_endpoint_unavailable(self, unavailable_endpoint):
context = self.__class__.__name__
if _OperationType.IsReadOnlyOperation(self.request.operation_type):
self.global_endpoint_manager.mark_endpoint_unavailable_for_read(unavailable_endpoint, True)
self.global_endpoint_manager.mark_endpoint_unavailable_for_read(unavailable_endpoint, True, context)
else:
self.global_endpoint_manager.mark_endpoint_unavailable_for_write(unavailable_endpoint, True)
self.global_endpoint_manager.mark_endpoint_unavailable_for_write(unavailable_endpoint, True, context)

def update_location_cache(self):
self.global_endpoint_manager.update_location_cache()
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def _resolve_service_endpoint(
) -> str:
return self.location_cache.resolve_service_endpoint(request)

def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache):
self.location_cache.mark_endpoint_unavailable_for_read(endpoint, refresh_cache)
def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache, context: str):
self.location_cache.mark_endpoint_unavailable_for_read(endpoint, refresh_cache, context)

def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache):
self.location_cache.mark_endpoint_unavailable_for_write(endpoint, refresh_cache)
def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache, context: str):
self.location_cache.mark_endpoint_unavailable_for_write(endpoint, refresh_cache, context)

def get_ordered_write_locations(self):
return self.location_cache.get_ordered_write_locations()
Expand All @@ -100,14 +100,14 @@ async def force_refresh_on_startup(self, database_account):
def update_location_cache(self):
self.location_cache.update_location_cache()

def _mark_endpoint_unavailable(self, endpoint: str):
def _mark_endpoint_unavailable(self, endpoint: str, context: str):
"""Marks an endpoint as unavailable for the appropriate operations.
:param str endpoint: The endpoint to mark as unavailable.
"""
write_endpoints = self.location_cache.get_all_write_endpoints()
self.mark_endpoint_unavailable_for_read(endpoint, False)
self.mark_endpoint_unavailable_for_read(endpoint, False, context)
if endpoint in write_endpoints:
self.mark_endpoint_unavailable_for_write(endpoint, False)
self.mark_endpoint_unavailable_for_write(endpoint, False, context)

async def refresh_endpoint_list(self, database_account, **kwargs):
if self.refresh_task and self.refresh_task.done():
Expand Down Expand Up @@ -151,7 +151,7 @@ async def _database_account_check(self, endpoint: str, **kwargs: dict[str, Any])
await self.client._GetDatabaseAccountCheck(endpoint, **kwargs)
self.location_cache.mark_endpoint_available(endpoint)
except (exceptions.CosmosHttpResponseError, AzureError):
self._mark_endpoint_unavailable(endpoint)
self._mark_endpoint_unavailable(endpoint,"_database_account_check")

async def _endpoints_health_check(self, **kwargs):
"""Gets the database account for each endpoint.
Expand Down Expand Up @@ -200,7 +200,7 @@ async def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
self.location_cache.mark_endpoint_available(locational_endpoint)
return database_account, locational_endpoint
except (exceptions.CosmosHttpResponseError, AzureError):
self._mark_endpoint_unavailable(locational_endpoint)
self._mark_endpoint_unavailable(locational_endpoint,"_GetDatabaseAccount")
raise

async def _GetDatabaseAccountStub(self, endpoint, **kwargs):
Expand Down
Loading
Loading