|
15 | 15 | """Test the topology module's Server Selection Spec implementation."""
|
16 | 16 |
|
17 | 17 | import os
|
| 18 | +import threading |
| 19 | + |
18 | 20 | from pymongo.common import clean_node, HEARTBEAT_FREQUENCY
|
19 |
| -from pymongo.pool import PoolOptions |
20 | 21 | from pymongo.read_preferences import ReadPreference
|
21 | 22 | from pymongo.settings import TopologySettings
|
22 | 23 | from pymongo.topology import Topology
|
23 |
| -from test import unittest |
| 24 | +from test import client_context, IntegrationTest, unittest |
24 | 25 | from test.utils_selection_tests import (
|
25 | 26 | get_addresses,
|
26 | 27 | get_topology_settings_dict,
|
27 | 28 | make_server_description)
|
28 |
| -from test.utils import TestCreator |
| 29 | +from test.utils import TestCreator, rs_client, OvertCommandListener |
29 | 30 |
|
30 | 31 |
|
31 | 32 | # Location of JSON test specifications.
|
@@ -106,5 +107,74 @@ def tests(self, scenario_def):
|
106 | 107 |
|
107 | 108 | CustomTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
|
108 | 109 |
|
| 110 | + |
| 111 | +class FinderThread(threading.Thread): |
| 112 | + def __init__(self, collection, iterations): |
| 113 | + super(FinderThread, self).__init__() |
| 114 | + self.daemon = True |
| 115 | + self.collection = collection |
| 116 | + self.iterations = iterations |
| 117 | + self.passed = False |
| 118 | + |
| 119 | + def run(self): |
| 120 | + for _ in range(self.iterations): |
| 121 | + self.collection.find_one({}) |
| 122 | + self.passed = True |
| 123 | + |
| 124 | + |
| 125 | +class TestProse(IntegrationTest): |
| 126 | + def frequencies(self, client, listener): |
| 127 | + coll = client.test.test |
| 128 | + N_FINDS = 10 |
| 129 | + N_THREADS = 10 |
| 130 | + threads = [FinderThread(coll, N_FINDS) for _ in range(N_THREADS)] |
| 131 | + for thread in threads: |
| 132 | + thread.start() |
| 133 | + for thread in threads: |
| 134 | + thread.join() |
| 135 | + for thread in threads: |
| 136 | + self.assertTrue(thread.passed) |
| 137 | + |
| 138 | + events = listener.results['started'] |
| 139 | + self.assertEqual(len(events), N_FINDS * N_THREADS) |
| 140 | + nodes = client.nodes |
| 141 | + self.assertEqual(len(nodes), 2) |
| 142 | + freqs = {address: 0 for address in nodes} |
| 143 | + for event in events: |
| 144 | + freqs[event.connection_id] += 1 |
| 145 | + for address in freqs: |
| 146 | + freqs[address] = freqs[address]/float(len(events)) |
| 147 | + return freqs |
| 148 | + |
| 149 | + @client_context.require_failCommand_appName |
| 150 | + @client_context.require_multiple_mongoses |
| 151 | + def test_load_balancing(self): |
| 152 | + listener = OvertCommandListener() |
| 153 | + client = rs_client(client_context.mongos_seeds(), |
| 154 | + appName='loadBalancingTest', |
| 155 | + event_listeners=[listener]) |
| 156 | + self.addCleanup(client.close) |
| 157 | + # Delay find commands on |
| 158 | + delay_finds = { |
| 159 | + 'configureFailPoint': 'failCommand', |
| 160 | + 'mode': {'times': 10000}, |
| 161 | + 'data': { |
| 162 | + 'failCommands': ['find'], |
| 163 | + 'blockConnection': True, |
| 164 | + 'blockTimeMS': 500, |
| 165 | + 'appName': 'loadBalancingTest', |
| 166 | + }, |
| 167 | + } |
| 168 | + with self.fail_point(delay_finds): |
| 169 | + nodes = client_context.client.nodes |
| 170 | + self.assertEqual(len(nodes), 1) |
| 171 | + delayed_server = next(iter(nodes)) |
| 172 | + freqs = self.frequencies(client, listener) |
| 173 | + self.assertLessEqual(freqs[delayed_server], 0.20) |
| 174 | + listener.reset() |
| 175 | + freqs = self.frequencies(client, listener) |
| 176 | + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.05) |
| 177 | + |
| 178 | + |
109 | 179 | if __name__ == "__main__":
|
110 | 180 | unittest.main()
|
0 commit comments