Skip to content

bpo-36842: Fix reference leak in tests by running out-of-proc #13556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions Lib/test/audit-tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
"""This script contains the actual auditing tests.

It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.

"""

import contextlib
import sys


class TestHook:
"""Used in standard hook tests to collect any logged events.

Should be used in a with block to ensure that it has no impact
after the test completes.
"""

def __init__(self, raise_on_events=None, exc_type=RuntimeError):
self.raise_on_events = raise_on_events or ()
self.exc_type = exc_type
self.seen = []
self.closed = False

def __enter__(self, *a):
sys.addaudithook(self)
return self

def __exit__(self, *a):
self.close()

def close(self):
self.closed = True

@property
def seen_events(self):
return [i[0] for i in self.seen]

def __call__(self, event, args):
if self.closed:
return
self.seen.append((event, args))
if event in self.raise_on_events:
raise self.exc_type("saw event " + event)


class TestFinalizeHook:
"""Used in the test_finalize_hooks function to ensure that hooks
are correctly cleaned up, that they are notified about the cleanup,
and are unable to prevent it.
"""

def __init__(self):
print("Created", id(self), file=sys.stdout, flush=True)

def __call__(self, event, args):
# Avoid recursion when we call id() below
if event == "builtins.id":
return

print(event, id(self), file=sys.stdout, flush=True)

if event == "cpython._PySys_ClearAuditHooks":
raise RuntimeError("Should be ignored")
elif event == "cpython.PyInterpreterState_Clear":
raise RuntimeError("Should be ignored")


# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
if x != y:
raise AssertionError(f"{x!r} should equal {y!r}")


def assertIn(el, series):
if el not in series:
raise AssertionError(f"{el!r} should be in {series!r}")


def assertNotIn(el, series):
if el in series:
raise AssertionError(f"{el!r} should not be in {series!r}")


def assertSequenceEqual(x, y):
if len(x) != len(y):
raise AssertionError(f"{x!r} should equal {y!r}")
if any(ix != iy for ix, iy in zip(x, y)):
raise AssertionError(f"{x!r} should equal {y!r}")


@contextlib.contextmanager
def assertRaises(ex_type):
try:
yield
assert False, f"expected {ex_type}"
except BaseException as ex:
if isinstance(ex, AssertionError):
raise
assert type(ex) is ex_type, f"{ex} should be {ex_type}"


def test_basic():
with TestHook() as hook:
sys.audit("test_event", 1, 2, 3)
assertEqual(hook.seen[0][0], "test_event")
assertEqual(hook.seen[0][1], (1, 2, 3))


def test_block_add_hook():
# Raising an exception should prevent a new hook from being added,
# but will not propagate out.
with TestHook(raise_on_events="sys.addaudithook") as hook1:
with TestHook() as hook2:
sys.audit("test_event")
assertIn("test_event", hook1.seen_events)
assertNotIn("test_event", hook2.seen_events)


def test_block_add_hook_baseexception():
# Raising BaseException will propagate out when adding a hook
with assertRaises(BaseException):
with TestHook(
raise_on_events="sys.addaudithook", exc_type=BaseException
) as hook1:
# Adding this next hook should raise BaseException
with TestHook() as hook2:
pass


def test_finalize_hooks():
sys.addaudithook(TestFinalizeHook())


def test_pickle():
import pickle

class PicklePrint:
def __reduce_ex__(self, p):
return str, ("Pwned!",)

payload_1 = pickle.dumps(PicklePrint())
payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))

# Before we add the hook, ensure our malicious pickle loads
assertEqual("Pwned!", pickle.loads(payload_1))

with TestHook(raise_on_events="pickle.find_class") as hook:
with assertRaises(RuntimeError):
# With the hook enabled, loading globals is not allowed
pickle.loads(payload_1)
# pickles with no globals are okay
pickle.loads(payload_2)


def test_monkeypatch():
class A:
pass

class B:
pass

class C(A):
pass

a = A()

with TestHook() as hook:
# Catch name changes
C.__name__ = "X"
# Catch type changes
C.__bases__ = (B,)
# Ensure bypassing __setattr__ is still caught
type.__dict__["__bases__"].__set__(C, (B,))
# Catch attribute replacement
C.__init__ = B.__init__
# Catch attribute addition
C.new_attr = 123
# Catch class changes
a.__class__ = B

actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"]
assertSequenceEqual(
[(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
)


def test_open():
# SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
try:
import ssl

load_dh_params = ssl.create_default_context().load_dh_params
except ImportError:
load_dh_params = None

# Try a range of "open" functions.
# All of them should fail
with TestHook(raise_on_events={"open"}) as hook:
for fn, *args in [
(open, sys.argv[2], "r"),
(open, sys.executable, "rb"),
(open, 3, "wb"),
(open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
(load_dh_params, sys.argv[2]),
]:
if not fn:
continue
with assertRaises(RuntimeError):
fn(*args)

actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
assertSequenceEqual(
[
i
for i in [
(sys.argv[2], "r"),
(sys.executable, "r"),
(3, "w"),
(sys.argv[2], "w"),
(sys.argv[2], "rb") if load_dh_params else None,
]
if i is not None
],
actual_mode,
)
assertSequenceEqual([], actual_flag)


def test_cantrace():
traced = []

def trace(frame, event, *args):
if frame.f_code == TestHook.__call__.__code__:
traced.append(event)

old = sys.settrace(trace)
try:
with TestHook() as hook:
# No traced call
eval("1")

# No traced call
hook.__cantrace__ = False
eval("2")

# One traced call
hook.__cantrace__ = True
eval("3")

# Two traced calls (writing to private member, eval)
hook.__cantrace__ = 1
eval("4")

# One traced call (writing to private member)
hook.__cantrace__ = 0
finally:
sys.settrace(old)

assertSequenceEqual(["call"] * 4, traced)


if __name__ == "__main__":
from test.libregrtest.setup import suppress_msvcrt_asserts
suppress_msvcrt_asserts(False)

test = sys.argv[1]
globals()[test]()
47 changes: 26 additions & 21 deletions Lib/test/libregrtest/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,7 @@ def setup_tests(ns):
if ns.threshold is not None:
gc.set_threshold(ns.threshold)

try:
import msvcrt
except ImportError:
pass
else:
msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
msvcrt.SEM_NOGPFAULTERRORBOX|
msvcrt.SEM_NOOPENFILEERRORBOX)
try:
msvcrt.CrtSetReportMode
except AttributeError:
# release build
pass
else:
for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
if ns.verbose and ns.verbose >= 2:
msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
else:
msvcrt.CrtSetReportMode(m, 0)
suppress_msvcrt_asserts(ns.verbose and ns.verbose >= 2)

support.use_resources = ns.use_resources

Expand All @@ -114,6 +94,31 @@ def _test_audit_hook(name, args):
sys.addaudithook(_test_audit_hook)


def suppress_msvcrt_asserts(verbose):
try:
import msvcrt
except ImportError:
return

msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
msvcrt.SEM_NOGPFAULTERRORBOX|
msvcrt.SEM_NOOPENFILEERRORBOX)
try:
msvcrt.CrtSetReportMode
except AttributeError:
# release build
return

for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
if verbose:
msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
else:
msvcrt.CrtSetReportMode(m, 0)



def replace_stdout():
"""Set stdout encoder error handler to backslashreplace (as stderr error
handler) to avoid UnicodeEncodeError when printing a traceback"""
Expand Down
Loading