Skip to content

Commit 8d73d6e

Browse files
authored
[mypyc] Fix segfault when top level raises exception (#10586)
If an error occurs while running the initialization code, set the CPyModule_<name>_internal module pointer to NULL so future attempts to import that same module don't mistakenly think that the module is already initialized due to the fact that that module pointer is not NULL. Clearing that module pointer on error allows us to keep initializing it at the beginning of the function before running any top level code, which is necessary to prevent RecursionErrors when dealing with circular imports.
1 parent b562cc2 commit 8d73d6e

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

mypyc/codegen/emitmodule.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -888,15 +888,15 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
888888

889889
emitter.emit_lines('{} = PyModule_Create(&{}module);'.format(module_static, module_prefix),
890890
'if (unlikely({} == NULL))'.format(module_static),
891-
' return NULL;')
891+
' goto fail;')
892892
emitter.emit_line(
893893
'PyObject *modname = PyObject_GetAttrString((PyObject *){}, "__name__");'.format(
894894
module_static))
895895

896896
module_globals = emitter.static_name('globals', module_name)
897897
emitter.emit_lines('{} = PyModule_GetDict({});'.format(module_globals, module_static),
898898
'if (unlikely({} == NULL))'.format(module_globals),
899-
' return NULL;')
899+
' goto fail;')
900900

901901
# HACK: Manually instantiate generated classes here
902902
for cl in module.classes:
@@ -907,16 +907,19 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
907907
'(PyObject *){t}_template, NULL, modname);'
908908
.format(t=type_struct))
909909
emitter.emit_lines('if (unlikely(!{}))'.format(type_struct),
910-
' return NULL;')
910+
' goto fail;')
911911

912912
emitter.emit_lines('if (CPyGlobalsInit() < 0)',
913-
' return NULL;')
913+
' goto fail;')
914914

915915
self.generate_top_level_call(module, emitter)
916916

917917
emitter.emit_lines('Py_DECREF(modname);')
918918

919919
emitter.emit_line('return {};'.format(module_static))
920+
emitter.emit_lines('fail:',
921+
'{} = NULL;'.format(module_static),
922+
'return NULL;')
920923
emitter.emit_line('}')
921924

922925
def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None:
@@ -927,7 +930,7 @@ def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None:
927930
emitter.emit_lines(
928931
'char result = {}();'.format(emitter.native_function_name(fn.decl)),
929932
'if (result == 2)',
930-
' return NULL;',
933+
' goto fail;',
931934
)
932935
break
933936

mypyc/test-data/run-misc.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,3 +1037,27 @@ C = sys.platform == 'x' and (lambda x: y + x)
10371037
assert not A
10381038
assert not B
10391039
assert not C
1040+
1041+
[case testDoesntSegfaultWhenTopLevelFails]
1042+
# make the initial import fail
1043+
assert False
1044+
1045+
class C:
1046+
def __init__(self):
1047+
self.x = 1
1048+
self.y = 2
1049+
def test() -> None:
1050+
a = C()
1051+
[file driver.py]
1052+
# load native, cause PyInit to be run, create the module but don't finish initializing the globals
1053+
try:
1054+
import native
1055+
except:
1056+
pass
1057+
try:
1058+
# try accessing those globals that were never properly initialized
1059+
import native
1060+
native.test()
1061+
# should fail with AssertionError due to `assert False` in other function
1062+
except AssertionError:
1063+
pass

0 commit comments

Comments
 (0)