Skip to content

Commit fe1a174

Browse files
committed
ChatGPT suggestions
1 parent b3dee33 commit fe1a174

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

cuda_bindings/tests/run_python_code_safely.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,24 @@
88
from dataclasses import dataclass
99
from io import StringIO
1010

11+
PROCESS_KILLED = -9
12+
PROCESS_NO_RESULT = -999
13+
14+
15+
# Similar to https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess
16+
# (args, check_returncode() are intentionally not supported here.)
17+
@dataclass
18+
class CompletedProcess:
19+
returncode: int
20+
stdout: str
21+
stderr: str
22+
1123

1224
class Worker:
1325
def __init__(self, result_queue, func, args, kwargs):
1426
self.func = func
15-
self.args = args or ()
16-
self.kwargs = kwargs or {}
27+
self.args = () if args is None else args
28+
self.kwargs = {} if kwargs is None else kwargs
1729
self.result_queue = result_queue
1830

1931
def __call__(self):
@@ -44,17 +56,14 @@ def __call__(self):
4456
pass
4557

4658

47-
# Similar to https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess
48-
# (args, check_returncode() are intentionally not supported here.)
49-
@dataclass
50-
class CompletedProcess:
51-
returncode: int
52-
stdout: str
53-
stderr: str
59+
def run_in_spawned_child_process(func, *, args=None, kwargs=None, timeout=None):
60+
"""Run `func` in a spawned child process, capturing stdout/stderr.
5461
62+
The provided `func` must be defined at the top level of a module, and must
63+
be importable in the spawned child process. Lambdas, closures, or interactively
64+
defined functions (e.g., in Jupyter notebooks) will not work.
65+
"""
5566

56-
def run_in_spawned_child_process(func, *, args=None, kwargs=None, timeout=None):
57-
"""Run Python code in a spawned child process, capturing stdout/stderr/output."""
5867
ctx = multiprocessing.get_context("spawn")
5968
result_queue = ctx.Queue()
6069
process = ctx.Process(target=Worker(result_queue, func, args, kwargs))
@@ -66,7 +75,7 @@ def run_in_spawned_child_process(func, *, args=None, kwargs=None, timeout=None):
6675
process.terminate()
6776
process.join()
6877
return CompletedProcess(
69-
returncode=-9,
78+
returncode=PROCESS_KILLED,
7079
stdout="",
7180
stderr=f"Process timed out after {timeout} seconds and was terminated.",
7281
)
@@ -75,7 +84,7 @@ def run_in_spawned_child_process(func, *, args=None, kwargs=None, timeout=None):
7584
returncode, stdout, stderr = result_queue.get(timeout=1.0)
7685
except (queue.Empty, EOFError):
7786
return CompletedProcess(
78-
returncode=-999,
87+
returncode=PROCESS_NO_RESULT,
7988
stdout="",
8089
stderr="Process exited or crashed before returning results.",
8190
)

0 commit comments

Comments
 (0)