Skip to content

Commit 9ddc416

Browse files
authored
bpo-36842: Fix reference leak in tests by running out-of-proc (GH-13556)
1 parent d8b7551 commit 9ddc416

File tree

3 files changed

+323
-230
lines changed

3 files changed

+323
-230
lines changed

Lib/test/audit-tests.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""This script contains the actual auditing tests.
2+
3+
It should not be imported directly, but should be run by the test_audit
4+
module with arguments identifying each test.
5+
6+
"""
7+
8+
import contextlib
9+
import sys
10+
11+
12+
class TestHook:
13+
"""Used in standard hook tests to collect any logged events.
14+
15+
Should be used in a with block to ensure that it has no impact
16+
after the test completes.
17+
"""
18+
19+
def __init__(self, raise_on_events=None, exc_type=RuntimeError):
20+
self.raise_on_events = raise_on_events or ()
21+
self.exc_type = exc_type
22+
self.seen = []
23+
self.closed = False
24+
25+
def __enter__(self, *a):
26+
sys.addaudithook(self)
27+
return self
28+
29+
def __exit__(self, *a):
30+
self.close()
31+
32+
def close(self):
33+
self.closed = True
34+
35+
@property
36+
def seen_events(self):
37+
return [i[0] for i in self.seen]
38+
39+
def __call__(self, event, args):
40+
if self.closed:
41+
return
42+
self.seen.append((event, args))
43+
if event in self.raise_on_events:
44+
raise self.exc_type("saw event " + event)
45+
46+
47+
class TestFinalizeHook:
48+
"""Used in the test_finalize_hooks function to ensure that hooks
49+
are correctly cleaned up, that they are notified about the cleanup,
50+
and are unable to prevent it.
51+
"""
52+
53+
def __init__(self):
54+
print("Created", id(self), file=sys.stdout, flush=True)
55+
56+
def __call__(self, event, args):
57+
# Avoid recursion when we call id() below
58+
if event == "builtins.id":
59+
return
60+
61+
print(event, id(self), file=sys.stdout, flush=True)
62+
63+
if event == "cpython._PySys_ClearAuditHooks":
64+
raise RuntimeError("Should be ignored")
65+
elif event == "cpython.PyInterpreterState_Clear":
66+
raise RuntimeError("Should be ignored")
67+
68+
69+
# Simple helpers, since we are not in unittest here
70+
def assertEqual(x, y):
71+
if x != y:
72+
raise AssertionError(f"{x!r} should equal {y!r}")
73+
74+
75+
def assertIn(el, series):
76+
if el not in series:
77+
raise AssertionError(f"{el!r} should be in {series!r}")
78+
79+
80+
def assertNotIn(el, series):
81+
if el in series:
82+
raise AssertionError(f"{el!r} should not be in {series!r}")
83+
84+
85+
def assertSequenceEqual(x, y):
86+
if len(x) != len(y):
87+
raise AssertionError(f"{x!r} should equal {y!r}")
88+
if any(ix != iy for ix, iy in zip(x, y)):
89+
raise AssertionError(f"{x!r} should equal {y!r}")
90+
91+
92+
@contextlib.contextmanager
93+
def assertRaises(ex_type):
94+
try:
95+
yield
96+
assert False, f"expected {ex_type}"
97+
except BaseException as ex:
98+
if isinstance(ex, AssertionError):
99+
raise
100+
assert type(ex) is ex_type, f"{ex} should be {ex_type}"
101+
102+
103+
def test_basic():
104+
with TestHook() as hook:
105+
sys.audit("test_event", 1, 2, 3)
106+
assertEqual(hook.seen[0][0], "test_event")
107+
assertEqual(hook.seen[0][1], (1, 2, 3))
108+
109+
110+
def test_block_add_hook():
111+
# Raising an exception should prevent a new hook from being added,
112+
# but will not propagate out.
113+
with TestHook(raise_on_events="sys.addaudithook") as hook1:
114+
with TestHook() as hook2:
115+
sys.audit("test_event")
116+
assertIn("test_event", hook1.seen_events)
117+
assertNotIn("test_event", hook2.seen_events)
118+
119+
120+
def test_block_add_hook_baseexception():
121+
# Raising BaseException will propagate out when adding a hook
122+
with assertRaises(BaseException):
123+
with TestHook(
124+
raise_on_events="sys.addaudithook", exc_type=BaseException
125+
) as hook1:
126+
# Adding this next hook should raise BaseException
127+
with TestHook() as hook2:
128+
pass
129+
130+
131+
def test_finalize_hooks():
132+
sys.addaudithook(TestFinalizeHook())
133+
134+
135+
def test_pickle():
136+
import pickle
137+
138+
class PicklePrint:
139+
def __reduce_ex__(self, p):
140+
return str, ("Pwned!",)
141+
142+
payload_1 = pickle.dumps(PicklePrint())
143+
payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))
144+
145+
# Before we add the hook, ensure our malicious pickle loads
146+
assertEqual("Pwned!", pickle.loads(payload_1))
147+
148+
with TestHook(raise_on_events="pickle.find_class") as hook:
149+
with assertRaises(RuntimeError):
150+
# With the hook enabled, loading globals is not allowed
151+
pickle.loads(payload_1)
152+
# pickles with no globals are okay
153+
pickle.loads(payload_2)
154+
155+
156+
def test_monkeypatch():
157+
class A:
158+
pass
159+
160+
class B:
161+
pass
162+
163+
class C(A):
164+
pass
165+
166+
a = A()
167+
168+
with TestHook() as hook:
169+
# Catch name changes
170+
C.__name__ = "X"
171+
# Catch type changes
172+
C.__bases__ = (B,)
173+
# Ensure bypassing __setattr__ is still caught
174+
type.__dict__["__bases__"].__set__(C, (B,))
175+
# Catch attribute replacement
176+
C.__init__ = B.__init__
177+
# Catch attribute addition
178+
C.new_attr = 123
179+
# Catch class changes
180+
a.__class__ = B
181+
182+
actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"]
183+
assertSequenceEqual(
184+
[(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
185+
)
186+
187+
188+
def test_open():
189+
# SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
190+
try:
191+
import ssl
192+
193+
load_dh_params = ssl.create_default_context().load_dh_params
194+
except ImportError:
195+
load_dh_params = None
196+
197+
# Try a range of "open" functions.
198+
# All of them should fail
199+
with TestHook(raise_on_events={"open"}) as hook:
200+
for fn, *args in [
201+
(open, sys.argv[2], "r"),
202+
(open, sys.executable, "rb"),
203+
(open, 3, "wb"),
204+
(open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
205+
(load_dh_params, sys.argv[2]),
206+
]:
207+
if not fn:
208+
continue
209+
with assertRaises(RuntimeError):
210+
fn(*args)
211+
212+
actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
213+
actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
214+
assertSequenceEqual(
215+
[
216+
i
217+
for i in [
218+
(sys.argv[2], "r"),
219+
(sys.executable, "r"),
220+
(3, "w"),
221+
(sys.argv[2], "w"),
222+
(sys.argv[2], "rb") if load_dh_params else None,
223+
]
224+
if i is not None
225+
],
226+
actual_mode,
227+
)
228+
assertSequenceEqual([], actual_flag)
229+
230+
231+
def test_cantrace():
232+
traced = []
233+
234+
def trace(frame, event, *args):
235+
if frame.f_code == TestHook.__call__.__code__:
236+
traced.append(event)
237+
238+
old = sys.settrace(trace)
239+
try:
240+
with TestHook() as hook:
241+
# No traced call
242+
eval("1")
243+
244+
# No traced call
245+
hook.__cantrace__ = False
246+
eval("2")
247+
248+
# One traced call
249+
hook.__cantrace__ = True
250+
eval("3")
251+
252+
# Two traced calls (writing to private member, eval)
253+
hook.__cantrace__ = 1
254+
eval("4")
255+
256+
# One traced call (writing to private member)
257+
hook.__cantrace__ = 0
258+
finally:
259+
sys.settrace(old)
260+
261+
assertSequenceEqual(["call"] * 4, traced)
262+
263+
264+
if __name__ == "__main__":
265+
from test.libregrtest.setup import suppress_msvcrt_asserts
266+
suppress_msvcrt_asserts(False)
267+
268+
test = sys.argv[1]
269+
globals()[test]()

Lib/test/libregrtest/setup.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -83,27 +83,7 @@ def setup_tests(ns):
8383
if ns.threshold is not None:
8484
gc.set_threshold(ns.threshold)
8585

86-
try:
87-
import msvcrt
88-
except ImportError:
89-
pass
90-
else:
91-
msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
92-
msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
93-
msvcrt.SEM_NOGPFAULTERRORBOX|
94-
msvcrt.SEM_NOOPENFILEERRORBOX)
95-
try:
96-
msvcrt.CrtSetReportMode
97-
except AttributeError:
98-
# release build
99-
pass
100-
else:
101-
for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
102-
if ns.verbose and ns.verbose >= 2:
103-
msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
104-
msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
105-
else:
106-
msvcrt.CrtSetReportMode(m, 0)
86+
suppress_msvcrt_asserts(ns.verbose and ns.verbose >= 2)
10787

10888
support.use_resources = ns.use_resources
10989

@@ -114,6 +94,31 @@ def _test_audit_hook(name, args):
11494
sys.addaudithook(_test_audit_hook)
11595

11696

97+
def suppress_msvcrt_asserts(verbose):
98+
try:
99+
import msvcrt
100+
except ImportError:
101+
return
102+
103+
msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
104+
msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
105+
msvcrt.SEM_NOGPFAULTERRORBOX|
106+
msvcrt.SEM_NOOPENFILEERRORBOX)
107+
try:
108+
msvcrt.CrtSetReportMode
109+
except AttributeError:
110+
# release build
111+
return
112+
113+
for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
114+
if verbose:
115+
msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
116+
msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
117+
else:
118+
msvcrt.CrtSetReportMode(m, 0)
119+
120+
121+
117122
def replace_stdout():
118123
"""Set stdout encoder error handler to backslashreplace (as stderr error
119124
handler) to avoid UnicodeEncodeError when printing a traceback"""

0 commit comments

Comments
 (0)