10
10
import pytest
11
11
12
12
import azure .cosmos .cosmos_client as cosmos_client
13
+ from azure .cosmos .container import ContainerProxy
13
14
import test_config
15
+ from _fault_injection_transport import FaultInjectionTransport
16
+ from test_fault_injection_transport import TestFaultInjectionTransport
17
+ from typing import List , Callable
18
+ from azure .core .rest import HttpRequest
14
19
15
20
try :
16
21
from unittest .mock import Mock
@@ -30,8 +35,32 @@ def reset(self):
30
35
def emit (self , record ):
31
36
self .messages .append (record )
32
37
33
-
34
-
38
+ CONFIG = test_config .TestConfig
39
+ L1 = "Location1"
40
+ L2 = "Location2"
41
+ L1_URL = test_config .TestConfig .local_host
42
+ L2_URL = L1_URL .replace ("localhost" , "127.0.0.1" )
43
+ URL_TO_LOCATIONS = {
44
+ L1_URL : L1 ,
45
+ L2_URL : L2 }
46
+
47
+
48
+ def create_logger (name : str , mock_handler : MockHandler , level : int = logging .INFO ) -> logging .Logger :
49
+ logger = logging .getLogger (name )
50
+ logger .addHandler (mock_handler )
51
+ logger .setLevel (level )
52
+
53
+ return logger
54
+
55
+ def get_locations_list (msg : str ) -> List [str ]:
56
+ msg = msg .replace (' ' , '' )
57
+ msg = msg .replace ('\' ' , '' )
58
+ # Find the substring between the first '[' and the last ']'
59
+ start = msg .find ('[' ) + 1
60
+ end = msg .rfind (']' )
61
+ # Extract the substring and convert it to a list using ast.literal_eval
62
+ msg = msg [start :end ]
63
+ return msg .split (',' )
35
64
36
65
@pytest .mark .cosmosEmulator
37
66
class TestCosmosHttpLogger (unittest .TestCase ):
@@ -54,12 +83,8 @@ def setUpClass(cls):
54
83
"tests." )
55
84
cls .mock_handler_default = MockHandler ()
56
85
cls .mock_handler_diagnostic = MockHandler ()
57
- cls .logger_default = logging .getLogger ("testloggerdefault" )
58
- cls .logger_default .addHandler (cls .mock_handler_default )
59
- cls .logger_default .setLevel (logging .INFO )
60
- cls .logger_diagnostic = logging .getLogger ("testloggerdiagnostic" )
61
- cls .logger_diagnostic .addHandler (cls .mock_handler_diagnostic )
62
- cls .logger_diagnostic .setLevel (logging .INFO )
86
+ cls .logger_default = create_logger ("testloggerdefault" , cls .mock_handler_default )
87
+ cls .logger_diagnostic = create_logger ("testloggerdiagnostic" , cls .mock_handler_diagnostic )
63
88
cls .client_default = cosmos_client .CosmosClient (cls .host , cls .masterKey ,
64
89
consistency_level = "Session" ,
65
90
connection_policy = cls .connectionPolicy ,
@@ -136,6 +161,65 @@ def test_cosmos_http_logging_policy(self):
136
161
137
162
self .mock_handler_diagnostic .reset ()
138
163
164
+ def test_client_settings (self ):
165
+ # Test data
166
+ all_locations = [L1 , L2 ]
167
+ client_excluded_locations = [L1 ]
168
+ multiple_write_locations = True
169
+
170
+ # Client setup
171
+ mock_handler = MockHandler ()
172
+ logger = create_logger ("test_logger_client_settings" , mock_handler )
173
+
174
+ custom_transport = FaultInjectionTransport ()
175
+ is_get_account_predicate : Callable [[HttpRequest ], bool ] = lambda \
176
+ r : FaultInjectionTransport .predicate_is_database_account_call (r )
177
+ emulator_as_multi_write_region_account_transformation = \
178
+ lambda r , inner : FaultInjectionTransport .transform_topology_mwr (
179
+ first_region_name = L1 ,
180
+ second_region_name = L2 ,
181
+ inner = inner ,
182
+ first_region_url = L1_URL ,
183
+ second_region_url = L2_URL ,
184
+ )
185
+ custom_transport .add_response_transformation (
186
+ is_get_account_predicate ,
187
+ emulator_as_multi_write_region_account_transformation )
188
+
189
+ initialized_objects = TestFaultInjectionTransport .setup_method_with_custom_transport (
190
+ custom_transport ,
191
+ default_endpoint = CONFIG .host ,
192
+ key = CONFIG .masterKey ,
193
+ database_id = CONFIG .TEST_DATABASE_ID ,
194
+ container_id = CONFIG .TEST_SINGLE_PARTITION_CONTAINER_ID ,
195
+ preferred_locations = all_locations ,
196
+ excluded_locations = client_excluded_locations ,
197
+ multiple_write_locations = multiple_write_locations ,
198
+ custom_logger = logger
199
+ )
200
+ mock_handler .reset ()
201
+
202
+ # create an item
203
+ id_value : str = str (uuid .uuid4 ())
204
+ document_definition = {'id' : id_value , 'pk' : id_value }
205
+ container : ContainerProxy = initialized_objects ["col" ]
206
+ container .create_item (body = document_definition )
207
+
208
+ # Verify endpoint locations
209
+ messages_split = mock_handler .messages [1 ].message .split ("\n " )
210
+ for message in messages_split :
211
+ if "Client Preferred Regions:" in message :
212
+ locations = get_locations_list (message )
213
+ assert all_locations == locations
214
+ elif "Client Excluded Regions:" in message :
215
+ locations = get_locations_list (message )
216
+ assert client_excluded_locations == locations
217
+ elif "Client Account Read Regions:" in message :
218
+ locations = get_locations_list (message )
219
+ assert all_locations == locations
220
+ elif "Client Account Write Regions:" in message :
221
+ locations = get_locations_list (message )
222
+ assert all_locations == locations
139
223
140
224
if __name__ == "__main__" :
141
225
unittest .main ()
0 commit comments