Skip to content

Commit 6a39435

Browse files
FabianMeiswinkeltvaron3kushagraThapar
authored
Gateway mode fault injection in python (#39623)
* change default read timeout * fix tests * Add read timeout tests for database account calls * fix timeout retry policy * Fixed the timeout logic * Fixed the timeout retry policy * Mock tests for timeout and failover retry policy * Create test_dummy.py * Update test_dummy.py * Update test_dummy.py * Update test_dummy.py * Iterating on fault injection tooling * Refactoring to have FaultInjectionTransport in its own file * Update test_dummy.py * Reafctoring FaultInjectionTransport * Iterating on tests * Prettifying tests * small refactoring * Adding MM topology on Emulator * Adding cross region retry tests * fix mypy errors * remove async await * refactor and fix tests * Fix refactoring * Fix tests * fix tests * add more tests * add more tests * Add tests * fix tests * fix tests * fix test * fix test * fix tests * fix async in test * initial sync version of fault injection * add all sync tests * add new error and fix logs * fix test --------- Co-authored-by: tvaron3 <[email protected]> Co-authored-by: Kushagra Thapar <[email protected]> Co-authored-by: Kushagra Thapar <[email protected]> Co-authored-by: Tomas Varon Saldarriaga <[email protected]>
1 parent 7102446 commit 6a39435

7 files changed

+1504
-4
lines changed

sdk/cosmos/azure-cosmos/pytest.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[pytest]
2+
markers =
3+
cosmosEmulator: marks tests as depending on the Cosmos DB Emulator.
4+
cosmosLong: marks tests to be run on a Cosmos DB live account.
5+
cosmosQuery: marks tests running queries on Cosmos DB live account.
6+
cosmosSplit: marks tests where there are partition splits on a Cosmos DB live account.
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) 2014 Microsoft Corporation
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so, subject to the following conditions:
10+
11+
# The above copyright notice and this permission notice shall be included in all
12+
# copies or substantial portions of the Software.
13+
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
# SOFTWARE.
21+
22+
"""RequestTransport allowing injection of faults between SDK and Cosmos Gateway
23+
"""
24+
25+
import json
26+
import logging
27+
import sys
28+
from time import sleep
29+
from typing import Callable, Optional, Any, Dict, List, MutableMapping
30+
31+
from azure.core.pipeline.transport import HttpRequest, HttpResponse
32+
from azure.core.pipeline.transport._requests_basic import RequestsTransport, RequestsTransportResponse
33+
from requests import Session
34+
35+
from azure.cosmos import documents
36+
37+
import test_config
38+
from azure.cosmos.exceptions import CosmosHttpResponseError
39+
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
40+
41+
class FaultInjectionTransport(RequestsTransport):
    """Synchronous ``RequestsTransport`` that lets tests inject faults into outgoing requests.

    Three kinds of hooks can be registered, each guarded by a predicate on the request:

    * faults - the first matching fault factory produces an exception that ``send`` raises
      instead of sending the request;
    * request transformations - every matching transformation rewrites the request in turn;
    * response transformations - the first matching transformation may replace the response
      that came back from the real transport.
    """

    logger = logging.getLogger('azure.cosmos.fault_injection_transport')
    logger.setLevel(logging.DEBUG)

    def __init__(self, *, session: Optional[Session] = None, loop=None, session_owner: bool = True, **config):
        """Create a transport with no faults or transformations registered.

        :param session: Optional ``requests`` session, forwarded to ``RequestsTransport``.
        :param loop: Forwarded unchanged to the base class.
            NOTE(review): presumably kept for signature parity with an async variant - confirm.
        :param session_owner: Forwarded to ``RequestsTransport``.
        :param config: Any further keyword arguments, forwarded to ``RequestsTransport``.
        """
        # Each entry is a dict of the form {"predicate": callable, "apply": callable}.
        self.faults: List[Dict[str, Any]] = []
        self.requestTransformations: List[Dict[str, Any]] = []
        self.responseTransformations: List[Dict[str, Any]] = []
        super().__init__(session=session, loop=loop, session_owner=session_owner, **config)

    def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Exception]):
        """Register a fault.

        :param predicate: Decides, per request, whether the fault applies.
        :param fault_factory: Called with the matching request; the exception it returns is
            raised by ``send``.
        """
        self.faults.append({"predicate": predicate, "apply": fault_factory})

    def add_response_transformation(
            self,
            predicate: Callable[[HttpRequest], bool],
            response_transformation: Callable[
                [HttpRequest, Callable[[], RequestsTransportResponse]], RequestsTransportResponse]):
        """Register a response transformation.

        :param predicate: Decides, per request, whether the transformation applies.
        :param response_transformation: Called with the request and a zero-argument callable
            returning the real response; its return value becomes the result of ``send``.
        """
        self.responseTransformations.append({
            "predicate": predicate,
            "apply": response_transformation})

    @staticmethod
    def __first_item(iterable, condition=lambda x: True):
        """
        Returns the first item in the `iterable` that satisfies the `condition`.

        If no item satisfies the condition, it returns None.
        """
        return next((x for x in iterable if condition(x)), None)

    def send(self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs) -> HttpResponse:
        """Send ``request``, applying any registered faults and transformations.

        :param request: The outgoing request.
        :param proxies: Forwarded to ``RequestsTransport.send``.
        :param kwargs: Forwarded to ``RequestsTransport.send``.
        :returns: The real response, or the result of the first matching response transformation.
            In both cases the (possibly transformed) request is stored under the ``"_request"``
            key of the response headers.
        :raises Exception: Whatever the first matching fault factory returns.
        """
        logger = FaultInjectionTransport.logger
        logger.info("--> FaultInjectionTransport.Send %s %s", request.method, request.url)

        # find the first fault Factory with matching predicate if any
        first_fault_factory = FaultInjectionTransport.__first_item(
            self.faults, lambda f: f["predicate"](request))
        if first_fault_factory:
            logger.info("--> FaultInjectionTransport.ApplyFaultInjection")
            injected_error = first_fault_factory["apply"](request)
            logger.info("Found to-be-injected error %s", injected_error)
            raise injected_error

        # apply the chain of request transformations with matching predicates if any
        # Each predicate is evaluated against the request as rewritten by the transformations
        # before it. (Previously the predicate was mistakenly called with itself as argument.)
        for current_transformation in self.requestTransformations:
            if current_transformation["predicate"](request):
                logger.info("--> FaultInjectionTransport.ApplyRequestTransformation")
                request = current_transformation["apply"](request)

        first_response_transformation = FaultInjectionTransport.__first_item(
            self.responseTransformations, lambda f: f["predicate"](request))

        # The real request is always sent before any response transformation runs.
        logger.info("--> FaultInjectionTransport.BeforeGetResponseTask")
        real_response = super().send(request, proxies=proxies, **kwargs)
        logger.info("<-- FaultInjectionTransport.AfterGetResponseTask")

        if first_response_transformation:
            logger.info("Invoking response transformation")
            response = first_response_transformation["apply"](request, lambda: real_response)
            response.headers["_request"] = request
            logger.info("Received response transformation result with status code %s", response.status_code)
            return response

        logger.info("Sending request to %s", request.url)
        response = real_response
        response.headers["_request"] = request
        logger.info("Received response with status code %s", response.status_code)
        return response

    @staticmethod
    def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool:
        """Return True when ``id_value`` occurs anywhere in the request URL."""
        return id_value in r.url

    @staticmethod
    def predicate_targets_region(r: HttpRequest, region_endpoint: str) -> bool:
        """Return True when the request URL starts with ``region_endpoint``."""
        return r.url.startswith(region_endpoint)

    @staticmethod
    def print_call_stack():
        """Print file, line and function of every frame on the current call stack."""
        print("Call stack:")
        frame = sys._getframe()
        while frame:
            print(f"File: {frame.f_code.co_filename}, Line: {frame.f_lineno}, Function: {frame.f_code.co_name}")
            frame = frame.f_back

    @staticmethod
    def predicate_req_payload_contains_id(r: HttpRequest, id_value: str) -> bool:
        """Return True when the request body contains the text ``"id":"<id_value>"``.

        The match is textual and expects no whitespace around the colon.
        Requests without a body never match.
        """
        body = r.body
        if body is None:
            return False

        # A bytes payload cannot be searched with a str needle; decode it first.
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8", errors="replace")

        return '"id":"{}"'.format(id_value) in body

    @staticmethod
    def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool:
        """Return True when ``id_value`` is found in the request URL or in its payload."""
        return (FaultInjectionTransport.predicate_url_contains_id(r, id_value)
                or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value))

    @staticmethod
    def predicate_is_database_account_call(r: HttpRequest) -> bool:
        """Return True when the request headers mark a ``Read`` of the ``databaseaccount`` resource."""
        is_db_account_read = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount'
                              and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read')

        return is_db_account_read

    @staticmethod
    def predicate_is_document_operation(r: HttpRequest) -> bool:
        """Return True when the request headers mark the resource type as ``docs``."""
        is_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs')

        return is_document_operation

    @staticmethod
    def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool:
        """Return True for a write operation whose URL contains ``uri_prefix``.

        Whether the operation type header denotes a write is decided by
        ``documents._OperationType.IsWriteOperation``.
        """
        is_write_document_operation = documents._OperationType.IsWriteOperation(
            str(r.headers.get('x-ms-thinclient-proxy-operation-type')))

        return is_write_document_operation and uri_prefix in r.url

    @staticmethod
    def error_after_delay(delay_in_ms: int, error: Exception) -> Exception:
        """Sleep for ``delay_in_ms`` milliseconds, then return ``error`` unchanged."""
        sleep(delay_in_ms / 1000.0)
        return error

    @staticmethod
    def error_write_forbidden() -> Exception:
        """Build a ``CosmosHttpResponseError`` with status code 403 and sub-status code 3."""
        return CosmosHttpResponseError(
            status_code=403,
            message="Injected error disallowing writes in this region.",
            response=None,
            sub_status_code=3,
        )

    @staticmethod
    def error_region_down() -> Exception:
        """Build a ``ServiceRequestError`` representing an unavailable region."""
        return ServiceRequestError(
            message="Injected region down.",
        )

    @staticmethod
    def error_service_response() -> Exception:
        """Build a ``ServiceResponseError``."""
        return ServiceResponseError(
            message="Injected Service Response Error.",
        )

    @staticmethod
    def _transform_topology(
            first_region_name: str,
            second_region_name: str,
            inner: Callable[[], RequestsTransportResponse],
            multiple_write_locations: bool) -> RequestsTransportResponse:
        """Shared implementation of the account-topology transformations.

        Responses that are not a successful (200, non-empty) database account read are
        returned untouched. Otherwise the first readable and writable location are renamed to
        ``first_region_name`` and a second region pointing at ``test_config.TestConfig.local_host``
        is appended: to the readable locations always, and to the writable locations (together
        with ``enableMultipleWriteLocations = True``) only when ``multiple_write_locations`` is set.
        """
        response = inner()
        if not FaultInjectionTransport.predicate_is_database_account_call(response.request):
            return response

        data = response.body()
        if response.status_code != 200 or not data:
            return response

        result = json.loads(data.decode("utf-8"))
        readable_locations = result["readableLocations"]
        writable_locations = result["writableLocations"]
        readable_locations[0]["name"] = first_region_name
        writable_locations[0]["name"] = first_region_name
        readable_locations.append(
            {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host})
        if multiple_write_locations:
            writable_locations.append(
                {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host})
            result["enableMultipleWriteLocations"] = True

        FaultInjectionTransport.logger.info("Transformed Account Topology: %s", result)
        request: HttpRequest = response.request
        return FaultInjectionTransport.MockHttpResponse(request, 200, result)

    @staticmethod
    def transform_topology_swr_mrr(
            write_region_name: str,
            read_region_name: str,
            inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse:
        """Rewrite a database account response to a single-write, multi-read topology.

        :param write_region_name: Name given to the first readable and first writable location.
        :param read_region_name: Name of the additional readable location that is appended.
        :param inner: Zero-argument callable returning the real response.
        :returns: A ``MockHttpResponse`` with the rewritten topology, or the real response when
            it is not a successful database account read.
        """
        return FaultInjectionTransport._transform_topology(
            write_region_name, read_region_name, inner, multiple_write_locations=False)

    @staticmethod
    def transform_topology_mwr(
            first_region_name: str,
            second_region_name: str,
            inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse:
        """Rewrite a database account response to a multi-write topology with two regions.

        :param first_region_name: Name given to the first readable and first writable location.
        :param second_region_name: Name of the region appended to both location lists.
        :param inner: Zero-argument callable returning the real response.
        :returns: A ``MockHttpResponse`` with the rewritten topology, or the real response when
            it is not a successful database account read.
        """
        return FaultInjectionTransport._transform_topology(
            first_region_name, second_region_name, inner, multiple_write_locations=True)

    class MockHttpResponse(RequestsTransportResponse):
        """In-memory response carrying a JSON payload instead of a real ``requests`` response."""

        def __init__(self, request: HttpRequest, status_code: int, content: Optional[Dict[str, Any]]):
            """Build the response without calling the base class initializer.

            :param request: The request this response answers.
            :param status_code: HTTP status code to report.
            :param content: JSON-serializable payload; when falsy the body stays empty.
            """
            self.request: HttpRequest = request
            # This is actually never None, and set by all implementations after the call to
            # __init__ of this class. This class is also a legacy impl, so it's risky to change it
            # for low benefits. The new "rest" implementation does define correctly status_code
            # as non-optional.
            self.status_code: int = status_code
            self.headers: MutableMapping[str, str] = {}
            self.reason: Optional[str] = None
            self.content_type: Optional[str] = None
            self.block_size: int = 4096  # Default to same as R
            self.content: Optional[Dict[str, Any]] = None
            self.json_text: str = ""
            self.bytes: bytes = b""
            if content:
                self.content = content
                self.json_text = json.dumps(content)
                self.bytes = self.json_text.encode("utf-8")

        def body(self) -> bytes:
            """Return the UTF-8 encoded JSON payload (empty when no content was given)."""
            return self.bytes

        def text(self, encoding: Optional[str] = None) -> str:
            """Return the JSON payload as text; ``encoding`` is ignored."""
            return self.json_text

        def load_body(self) -> None:
            """Do nothing; the body is already held in memory."""
            return

0 commit comments

Comments
 (0)