|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | | -"""Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" |
| 15 | +"""Shared helper methods for pymongo, bson, and gridfs test suites.""" |
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
18 | 18 | import asyncio
|
19 | | -import base64 |
20 | | -import gc |
21 | | -import multiprocessing |
22 | | -import os |
23 | | -import signal |
24 | | -import socket |
25 | | -import subprocess |
26 | | -import sys |
27 | 19 | import threading
|
28 | | -import time |
29 | 20 | import traceback
|
30 | | -import unittest |
31 | | -import warnings |
32 | | -from inspect import iscoroutinefunction |
33 | | - |
34 | | -from pymongo._asyncio_task import create_task |
35 | | - |
36 | | -try: |
37 | | - import ipaddress |
38 | | - |
39 | | - HAVE_IPADDRESS = True |
40 | | -except ImportError: |
41 | | - HAVE_IPADDRESS = False |
42 | 21 | from functools import wraps
|
43 | | -from typing import Any, Callable, Dict, Generator, Optional, no_type_check |
44 | | -from unittest import SkipTest |
| 22 | +from typing import Optional, no_type_check |
45 | 23 |
|
46 | | -from bson.son import SON |
47 | | -from pymongo import common, message |
| 24 | +from bson import SON |
| 25 | +from pymongo import common |
| 26 | +from pymongo._asyncio_task import create_task |
48 | 27 | from pymongo.read_preferences import ReadPreference
|
49 | | -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] |
50 | | -from pymongo.synchronous.uri_parser import parse_uri |
51 | | - |
52 | | -if HAVE_SSL: |
53 | | - import ssl |
54 | 28 |
|
55 | 29 | _IS_SYNC = False
|
56 | 30 |
|
57 | | -# Enable debug output for uncollectable objects. PyPy does not have set_debug. |
58 | | -if hasattr(gc, "set_debug"): |
59 | | - gc.set_debug( |
60 | | - gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) |
61 | | - ) |
62 | | - |
63 | | -# The host and port of a single mongod or mongos, or the seed host |
64 | | -# for a replica set. |
65 | | -host = os.environ.get("DB_IP", "localhost") |
66 | | -port = int(os.environ.get("DB_PORT", 27017)) |
67 | | -IS_SRV = "mongodb+srv" in host |
68 | | - |
69 | | -db_user = os.environ.get("DB_USER", "user") |
70 | | -db_pwd = os.environ.get("DB_PASSWORD", "password") |
71 | | - |
72 | | -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") |
73 | | -CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) |
74 | | -CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) |
75 | | - |
76 | | -TLS_OPTIONS: Dict = {"tls": True} |
77 | | -if CLIENT_PEM: |
78 | | - TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM |
79 | | -if CA_PEM: |
80 | | - TLS_OPTIONS["tlsCAFile"] = CA_PEM |
81 | | - |
82 | | -COMPRESSORS = os.environ.get("COMPRESSORS") |
83 | | -MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") |
84 | | -TEST_LOADBALANCER = bool(os.environ.get("TEST_LOAD_BALANCER")) |
85 | | -SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") |
86 | | -MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") |
87 | | - |
88 | | -if TEST_LOADBALANCER: |
89 | | - res = parse_uri(SINGLE_MONGOS_LB_URI or "") |
90 | | - host, port = res["nodelist"][0] |
91 | | - db_user = res["username"] or db_user |
92 | | - db_pwd = res["password"] or db_pwd |
93 | | - |
94 | | - |
95 | | -# Shared KMS data. |
96 | | -LOCAL_MASTER_KEY = base64.b64decode( |
97 | | - b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" |
98 | | - b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" |
99 | | -) |
100 | | -AWS_CREDS = { |
101 | | - "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), |
102 | | - "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), |
103 | | -} |
104 | | -AWS_CREDS_2 = { |
105 | | - "accessKeyId": os.environ.get("FLE_AWS_KEY2", ""), |
106 | | - "secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""), |
107 | | -} |
108 | | -AZURE_CREDS = { |
109 | | - "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), |
110 | | - "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), |
111 | | - "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), |
112 | | -} |
113 | | -GCP_CREDS = { |
114 | | - "email": os.environ.get("FLE_GCP_EMAIL", ""), |
115 | | - "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), |
116 | | -} |
117 | | -KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} |
118 | | - |
119 | | -# Ensure Evergreen metadata doesn't result in truncation |
120 | | -os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000") |
121 | | - |
122 | | - |
123 | | -def is_server_resolvable(): |
124 | | - """Returns True if 'server' is resolvable.""" |
125 | | - socket_timeout = socket.getdefaulttimeout() |
126 | | - socket.setdefaulttimeout(1) |
127 | | - try: |
128 | | - try: |
129 | | - socket.gethostbyname("server") |
130 | | - return True |
131 | | - except OSError: |
132 | | - return False |
133 | | - finally: |
134 | | - socket.setdefaulttimeout(socket_timeout) |
135 | | - |
136 | | - |
137 | | -def _create_user(authdb, user, pwd=None, roles=None, **kwargs): |
138 | | - cmd = SON([("createUser", user)]) |
139 | | - # X509 doesn't use a password |
140 | | - if pwd: |
141 | | - cmd["pwd"] = pwd |
142 | | - cmd["roles"] = roles or ["root"] |
143 | | - cmd.update(**kwargs) |
144 | | - return authdb.command(cmd) |
145 | | - |
146 | 31 |
|
147 | 32 | async def async_repl_set_step_down(client, **kwargs):
|
148 | 33 | """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
|
@@ -237,133 +122,10 @@ def __del__(self):
|
237 | 122 | raise Exception(msg)
|
238 | 123 |
|
239 | 124 |
|
240 | | -def _all_users(db): |
241 | | - return {u["user"] for u in db.command("usersInfo").get("users", [])} |
242 | | - |
243 | | - |
244 | | -def sanitize_cmd(cmd): |
245 | | - cp = cmd.copy() |
246 | | - cp.pop("$clusterTime", None) |
247 | | - cp.pop("$db", None) |
248 | | - cp.pop("$readPreference", None) |
249 | | - cp.pop("lsid", None) |
250 | | - if MONGODB_API_VERSION: |
251 | | - # Stable API parameters |
252 | | - cp.pop("apiVersion", None) |
253 | | - # OP_MSG encoding may move the payload type one field to the |
254 | | - # end of the command. Do the same here. |
255 | | - name = next(iter(cp)) |
256 | | - try: |
257 | | - identifier = message._FIELD_MAP[name] |
258 | | - docs = cp.pop(identifier) |
259 | | - cp[identifier] = docs |
260 | | - except KeyError: |
261 | | - pass |
262 | | - return cp |
263 | | - |
264 | | - |
265 | | -def sanitize_reply(reply): |
266 | | - cp = reply.copy() |
267 | | - cp.pop("$clusterTime", None) |
268 | | - cp.pop("operationTime", None) |
269 | | - return cp |
270 | | - |
271 | | - |
272 | | -def print_thread_tracebacks() -> None: |
273 | | - """Print all Python thread tracebacks.""" |
274 | | - for thread_id, frame in sys._current_frames().items(): |
275 | | - sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n") |
276 | | - traceback.print_stack(frame, file=sys.stderr) |
277 | | - |
278 | | - |
279 | | -def print_thread_stacks(pid: int) -> None: |
280 | | - """Print all C-level thread stacks for a given process id.""" |
281 | | - if sys.platform == "darwin": |
282 | | - cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"'] |
283 | | - else: |
284 | | - cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"'] |
285 | | - |
286 | | - try: |
287 | | - res = subprocess.run( |
288 | | - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8" |
289 | | - ) |
290 | | - except Exception as exc: |
291 | | - sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}") |
292 | | - else: |
293 | | - sys.stderr.write(res.stdout) |
294 | | - |
295 | | - |
296 | 125 | # Global knobs to speed up the test suite.
|
297 | 126 | global_knobs = client_knobs(events_queue_frequency=0.05)
|
298 | 127 |
|
299 | 128 |
|
300 | | -def _get_executors(topology): |
301 | | - executors = [] |
302 | | - for server in topology._servers.values(): |
303 | | - # Some MockMonitor do not have an _executor. |
304 | | - if hasattr(server._monitor, "_executor"): |
305 | | - executors.append(server._monitor._executor) |
306 | | - if hasattr(server._monitor, "_rtt_monitor"): |
307 | | - executors.append(server._monitor._rtt_monitor._executor) |
308 | | - executors.append(topology._Topology__events_executor) |
309 | | - if topology._srv_monitor: |
310 | | - executors.append(topology._srv_monitor._executor) |
311 | | - |
312 | | - return [e for e in executors if e is not None] |
313 | | - |
314 | | - |
315 | | -def print_running_topology(topology): |
316 | | - running = [e for e in _get_executors(topology) if not e._stopped] |
317 | | - if running: |
318 | | - print( |
319 | | - "WARNING: found Topology with running threads:\n" |
320 | | - f" Threads: {running}\n" |
321 | | - f" Topology: {topology}\n" |
322 | | - f" Creation traceback:\n{topology._settings._stack}" |
323 | | - ) |
324 | | - |
325 | | - |
326 | | -def test_cases(suite): |
327 | | - """Iterator over all TestCases within a TestSuite.""" |
328 | | - for suite_or_case in suite._tests: |
329 | | - if isinstance(suite_or_case, unittest.TestCase): |
330 | | - # unittest.TestCase |
331 | | - yield suite_or_case |
332 | | - else: |
333 | | - # unittest.TestSuite |
334 | | - yield from test_cases(suite_or_case) |
335 | | - |
336 | | - |
337 | | -# Helper method to workaround https://bugs.python.org/issue21724 |
338 | | -def clear_warning_registry(): |
339 | | - """Clear the __warningregistry__ for all modules.""" |
340 | | - for _, module in list(sys.modules.items()): |
341 | | - if hasattr(module, "__warningregistry__"): |
342 | | - module.__warningregistry__ = {} # type:ignore[attr-defined] |
343 | | - |
344 | | - |
345 | | -class SystemCertsPatcher: |
346 | | - def __init__(self, ca_certs): |
347 | | - if ( |
348 | | - ssl.OPENSSL_VERSION.lower().startswith("libressl") |
349 | | - and sys.platform == "darwin" |
350 | | - and not _ssl.IS_PYOPENSSL |
351 | | - ): |
352 | | - raise SkipTest( |
353 | | - "LibreSSL on OSX doesn't support setting CA certificates " |
354 | | - "using SSL_CERT_FILE environment variable." |
355 | | - ) |
356 | | - self.original_certs = os.environ.get("SSL_CERT_FILE") |
357 | | - # Tell OpenSSL where CA certificates live. |
358 | | - os.environ["SSL_CERT_FILE"] = ca_certs |
359 | | - |
360 | | - def disable(self): |
361 | | - if self.original_certs is None: |
362 | | - os.environ.pop("SSL_CERT_FILE") |
363 | | - else: |
364 | | - os.environ["SSL_CERT_FILE"] = self.original_certs |
365 | | - |
366 | | - |
367 | 129 | if _IS_SYNC:
|
368 | 130 | PARENT = threading.Thread
|
369 | 131 | else:
|
|
0 commit comments