Skip to content

Commit 9453f5f

Browse files
authored
[mypyc] Allow dots in group names (#7941)
Single-module groups will be given a dotted group name equal to the module name by default. A dot in a group name will put the group C extension at that path. Generated C could will also be placed in a directory corresponding to the full dotted name (still under the build directory). The direct motivation here is that bazel wants artifacts to be under the appropriate directory, but it is cleaner in general I think.
1 parent 1b798e4 commit 9453f5f

File tree

6 files changed

+59
-33
lines changed

6 files changed

+59
-33
lines changed

mypyc/build.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def get_mypy_config(mypy_options: List[str],
9292
else:
9393
mypyc_sources = all_sources
9494

95+
if compiler_options.separate:
96+
mypyc_sources = [src for src in mypyc_sources
97+
if src.path and not src.path.endswith('__init__.py')]
98+
9599
if not mypyc_sources:
96100
return mypyc_sources, all_sources, options
97101

@@ -126,7 +130,7 @@ def generate_c_extension_shim(
126130
dir_name: the directory to place source code
127131
group_name: the name of the group
128132
"""
129-
cname = '%s.c' % exported_name(full_module_name)
133+
cname = '%s.c' % full_module_name.replace('.', os.sep)
130134
cpath = os.path.join(dir_name, cname)
131135

132136
# We load the C extension shim template from a file.
@@ -146,7 +150,7 @@ def generate_c_extension_shim(
146150
def group_name(modules: List[str]) -> str:
147151
"""Produce a probably unique name for a group from a list of module names."""
148152
if len(modules) == 1:
149-
return exported_name(modules[0])
153+
return modules[0]
150154

151155
h = hashlib.sha1()
152156
h.update(','.join(modules).encode())
@@ -226,7 +230,7 @@ def build_using_shared_lib(sources: List[BuildSource],
226230
extensions = [get_extension()(
227231
shared_lib_name(group_name),
228232
sources=cfiles,
229-
include_dirs=[include_dir()],
233+
include_dirs=[include_dir(), build_dir],
230234
depends=deps,
231235
extra_compile_args=extra_compile_args,
232236
)]

mypyc/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ def decorator_helper_name(func_name: str) -> str:
3636
def shared_lib_name(group_name: str) -> str:
3737
"""Given a group name, return the actual name of its extension module.
3838
39-
(This just adds a prefix.)
39+
(This just adds a suffix to the final component.)
4040
"""
41-
return 'mypyc_{}'.format(group_name)
41+
return '{}__mypyc'.format(group_name)

mypyc/emit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, ClassIR,
1616
FuncDecl, int_rprimitive, is_optional_type, optional_value_type, all_concrete_classes
1717
)
18-
from mypyc.namegen import NameGenerator
18+
from mypyc.namegen import NameGenerator, exported_name
1919
from mypyc.sametype import is_same_type
2020

2121

@@ -167,7 +167,7 @@ def get_module_group_prefix(self, module_name: str) -> str:
167167
target_group_name = groups.get(module_name)
168168
if target_group_name and target_group_name != self.context.group_name:
169169
self.context.group_deps.add(target_group_name)
170-
return 'exports_{}.'.format(target_group_name)
170+
return 'exports_{}.'.format(exported_name(target_group_name))
171171
else:
172172
return ''
173173

mypyc/emitmodule.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,11 @@ def pointerize(decl: str, name: str) -> str:
437437
return decl.replace(name, '*{}'.format(name))
438438

439439

440+
def group_dir(group_name: str) -> str:
441+
"""Given a group name, return the relative directory path for it. """
442+
return os.sep.join(group_name.split('.')[:-1])
443+
444+
440445
class GroupGenerator:
441446
def __init__(self,
442447
literals: LiteralsMap,
@@ -477,7 +482,11 @@ def __init__(self,
477482

478483
@property
479484
def group_suffix(self) -> str:
480-
return '_' + self.group_name if self.group_name else ''
485+
return '_' + exported_name(self.group_name) if self.group_name else ''
486+
487+
@property
488+
def short_group_suffix(self) -> str:
489+
return '_' + exported_name(self.group_name.split('.')[-1]) if self.group_name else ''
481490

482491
def generate_c_for_modules(self) -> List[Tuple[str, str]]:
483492
file_contents = []
@@ -489,8 +498,8 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
489498
if self.compiler_options.include_runtime_files:
490499
base_emitter.emit_line('#include "CPy.c"')
491500
base_emitter.emit_line('#include "getargs.c"')
492-
base_emitter.emit_line('#include "__native{}.h"'.format(self.group_suffix))
493-
base_emitter.emit_line('#include "__native_internal{}.h"'.format(self.group_suffix))
501+
base_emitter.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
502+
base_emitter.emit_line('#include "__native_internal{}.h"'.format(self.short_group_suffix))
494503
emitter = base_emitter
495504

496505
for (_, literal), identifier in self.literals.items():
@@ -503,8 +512,9 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
503512
for module_name, module in self.modules:
504513
if multi_file:
505514
emitter = Emitter(self.context)
506-
emitter.emit_line('#include "__native{}.h"'.format(self.group_suffix))
507-
emitter.emit_line('#include "__native_internal{}.h"'.format(self.group_suffix))
515+
emitter.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
516+
emitter.emit_line(
517+
'#include "__native_internal{}.h"'.format(self.short_group_suffix))
508518

509519
self.declare_module(module_name, emitter)
510520
self.declare_internal_globals(module_name, emitter)
@@ -544,7 +554,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
544554
declarations.emit_line('#define MYPYC_NATIVE_INTERNAL{}_H'.format(self.group_suffix))
545555
declarations.emit_line('#include <Python.h>')
546556
declarations.emit_line('#include <CPy.h>')
547-
declarations.emit_line('#include "__native{}.h"'.format(self.group_suffix))
557+
declarations.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
548558
declarations.emit_line()
549559
declarations.emit_line('int CPyGlobalsInit(void);')
550560
declarations.emit_line()
@@ -557,9 +567,13 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
557567
generate_function_declaration(fn, declarations)
558568

559569
for lib in sorted(self.context.group_deps):
570+
elib = exported_name(lib)
571+
short_lib = exported_name(lib.split('.')[-1])
560572
declarations.emit_lines(
561-
'#include "__native_{}.h"'.format(lib),
562-
'struct export_table_{} exports_{};'.format(lib, lib)
573+
'#include <{}>'.format(
574+
os.path.join(group_dir(lib), "__native_{}.h".format(short_lib))
575+
),
576+
'struct export_table_{} exports_{};'.format(elib, elib)
563577
)
564578

565579
sorted_decls = self.toposort_declarations()
@@ -591,13 +605,15 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
591605
ext_declarations.emit_line('#endif')
592606
declarations.emit_line('#endif')
593607

594-
return file_contents + [('__native{}.c'.format(self.group_suffix),
595-
''.join(emitter.fragments)),
596-
('__native_internal{}.h'.format(self.group_suffix),
597-
''.join(declarations.fragments)),
598-
('__native{}.h'.format(self.group_suffix),
599-
''.join(ext_declarations.fragments)),
600-
]
608+
output_dir = group_dir(self.group_name) if self.group_name else ''
609+
return file_contents + [
610+
(os.path.join(output_dir, '__native{}.c'.format(self.short_group_suffix)),
611+
''.join(emitter.fragments)),
612+
(os.path.join(output_dir, '__native_internal{}.h'.format(self.short_group_suffix)),
613+
''.join(declarations.fragments)),
614+
(os.path.join(output_dir, '__native{}.h'.format(self.short_group_suffix)),
615+
''.join(ext_declarations.fragments)),
616+
]
601617

602618
def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None:
603619
"""Generate the declaration and definition of the group's export struct.
@@ -683,12 +699,14 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:
683699

684700
emitter.emit_line()
685701
emitter.emit_lines(
686-
'PyMODINIT_FUNC PyInit_{}(void)'.format(shared_lib_name(self.group_name)),
702+
'PyMODINIT_FUNC PyInit_{}(void)'.format(
703+
shared_lib_name(self.group_name).split('.')[-1]),
687704
'{',
688705
('static PyModuleDef def = {{ PyModuleDef_HEAD_INIT, "{}", NULL, -1, NULL, NULL }};'
689-
.format(self.group_name)),
706+
.format(shared_lib_name(self.group_name))),
690707
'int res;',
691708
'PyObject *capsule;',
709+
'PyObject *tmp;',
692710
'static PyObject *module;',
693711
'if (module) {',
694712
'Py_INCREF(module);',
@@ -733,14 +751,17 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:
733751
)
734752

735753
for group in sorted(self.context.group_deps):
754+
egroup = exported_name(group)
736755
emitter.emit_lines(
756+
'tmp = PyImport_ImportModule("{}"); if (!tmp) goto fail; Py_DECREF(tmp);'.format(
757+
shared_lib_name(group)),
737758
'struct export_table_{} *pexports_{} = PyCapsule_Import("{}.exports", 0);'.format(
738-
group, group, shared_lib_name(group)),
739-
'if (!pexports_{}) {{'.format(group),
759+
egroup, egroup, shared_lib_name(group)),
760+
'if (!pexports_{}) {{'.format(egroup),
740761
'goto fail;',
741762
'}',
742763
'memcpy(&exports_{group}, pexports_{group}, sizeof(exports_{group}));'.format(
743-
group=group),
764+
group=egroup),
744765
'',
745766
)
746767

mypyc/test-data/run-multimodule.test

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
-- These test cases compile two modules at a time (native and other.py)
22

3-
[case testMultiModuleBasic]
4-
from other import g
3+
[case testMultiModulePackage]
4+
from p.other import g
55
def f(x: int) -> int:
6-
from other import h
6+
from p.other import h
77
return h(g(x + 1))
8-
[file other.py]
8+
[file p/__init__.py]
9+
[file p/other.py]
910
def g(x: int) -> int:
1011
return x + 2
1112
def h(x: int) -> int:
1213
return x + 1
1314
[file driver.py]
1415
import native
1516
from native import f
16-
from other import g
17+
from p.other import g
1718
assert f(3) == 7
1819
assert g(2) == 4
1920
try:

mypyc/test/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) ->
184184
fn = os.path.relpath(fn, test_temp_dir)
185185

186186
if os.path.basename(fn).startswith('other') and fn.endswith('.py'):
187-
name = os.path.basename(fn).split('.')[0]
187+
name = fn.split('.')[0].replace(os.sep, '.')
188188
module_names.append(name)
189189
sources.append(build.BuildSource(fn, name, None))
190190
to_delete.append(fn)

0 commit comments

Comments
 (0)