Skip to content

[mypyc] Allow dots in group names #7941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def get_mypy_config(mypy_options: List[str],
else:
mypyc_sources = all_sources

if compiler_options.separate:
mypyc_sources = [src for src in mypyc_sources
if src.path and not src.path.endswith('__init__.py')]

if not mypyc_sources:
return mypyc_sources, all_sources, options

Expand Down Expand Up @@ -126,7 +130,7 @@ def generate_c_extension_shim(
dir_name: the directory to place source code
group_name: the name of the group
"""
cname = '%s.c' % exported_name(full_module_name)
cname = '%s.c' % full_module_name.replace('.', os.sep)
cpath = os.path.join(dir_name, cname)

# We load the C extension shim template from a file.
Expand All @@ -146,7 +150,7 @@ def generate_c_extension_shim(
def group_name(modules: List[str]) -> str:
"""Produce a probably unique name for a group from a list of module names."""
if len(modules) == 1:
return exported_name(modules[0])
return modules[0]

h = hashlib.sha1()
h.update(','.join(modules).encode())
Expand Down Expand Up @@ -226,7 +230,7 @@ def build_using_shared_lib(sources: List[BuildSource],
extensions = [get_extension()(
shared_lib_name(group_name),
sources=cfiles,
include_dirs=[include_dir()],
include_dirs=[include_dir(), build_dir],
depends=deps,
extra_compile_args=extra_compile_args,
)]
Expand Down
4 changes: 2 additions & 2 deletions mypyc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def decorator_helper_name(func_name: str) -> str:
def shared_lib_name(group_name: str) -> str:
"""Given a group name, return the actual name of its extension module.

(This just adds a prefix.)
(This just adds a suffix to the final component.)
"""
return 'mypyc_{}'.format(group_name)
return '{}__mypyc'.format(group_name)
4 changes: 2 additions & 2 deletions mypyc/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, ClassIR,
FuncDecl, int_rprimitive, is_optional_type, optional_value_type, all_concrete_classes
)
from mypyc.namegen import NameGenerator
from mypyc.namegen import NameGenerator, exported_name
from mypyc.sametype import is_same_type


Expand Down Expand Up @@ -167,7 +167,7 @@ def get_module_group_prefix(self, module_name: str) -> str:
target_group_name = groups.get(module_name)
if target_group_name and target_group_name != self.context.group_name:
self.context.group_deps.add(target_group_name)
return 'exports_{}.'.format(target_group_name)
return 'exports_{}.'.format(exported_name(target_group_name))
else:
return ''

Expand Down
61 changes: 41 additions & 20 deletions mypyc/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ def pointerize(decl: str, name: str) -> str:
return decl.replace(name, '*{}'.format(name))


def group_dir(group_name: str) -> str:
"""Given a group name, return the relative directory path for it. """
return os.sep.join(group_name.split('.')[:-1])


class GroupGenerator:
def __init__(self,
literals: LiteralsMap,
Expand Down Expand Up @@ -477,7 +482,11 @@ def __init__(self,

@property
def group_suffix(self) -> str:
return '_' + self.group_name if self.group_name else ''
return '_' + exported_name(self.group_name) if self.group_name else ''

@property
def short_group_suffix(self) -> str:
return '_' + exported_name(self.group_name.split('.')[-1]) if self.group_name else ''

def generate_c_for_modules(self) -> List[Tuple[str, str]]:
file_contents = []
Expand All @@ -489,8 +498,8 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
if self.compiler_options.include_runtime_files:
base_emitter.emit_line('#include "CPy.c"')
base_emitter.emit_line('#include "getargs.c"')
base_emitter.emit_line('#include "__native{}.h"'.format(self.group_suffix))
base_emitter.emit_line('#include "__native_internal{}.h"'.format(self.group_suffix))
base_emitter.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
base_emitter.emit_line('#include "__native_internal{}.h"'.format(self.short_group_suffix))
emitter = base_emitter

for (_, literal), identifier in self.literals.items():
Expand All @@ -503,8 +512,9 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
for module_name, module in self.modules:
if multi_file:
emitter = Emitter(self.context)
emitter.emit_line('#include "__native{}.h"'.format(self.group_suffix))
emitter.emit_line('#include "__native_internal{}.h"'.format(self.group_suffix))
emitter.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
emitter.emit_line(
'#include "__native_internal{}.h"'.format(self.short_group_suffix))

self.declare_module(module_name, emitter)
self.declare_internal_globals(module_name, emitter)
Expand Down Expand Up @@ -544,7 +554,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
declarations.emit_line('#define MYPYC_NATIVE_INTERNAL{}_H'.format(self.group_suffix))
declarations.emit_line('#include <Python.h>')
declarations.emit_line('#include <CPy.h>')
declarations.emit_line('#include "__native{}.h"'.format(self.group_suffix))
declarations.emit_line('#include "__native{}.h"'.format(self.short_group_suffix))
declarations.emit_line()
declarations.emit_line('int CPyGlobalsInit(void);')
declarations.emit_line()
Expand All @@ -557,9 +567,13 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
generate_function_declaration(fn, declarations)

for lib in sorted(self.context.group_deps):
elib = exported_name(lib)
short_lib = exported_name(lib.split('.')[-1])
declarations.emit_lines(
'#include "__native_{}.h"'.format(lib),
'struct export_table_{} exports_{};'.format(lib, lib)
'#include <{}>'.format(
os.path.join(group_dir(lib), "__native_{}.h".format(short_lib))
),
'struct export_table_{} exports_{};'.format(elib, elib)
)

sorted_decls = self.toposort_declarations()
Expand Down Expand Up @@ -591,13 +605,15 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
ext_declarations.emit_line('#endif')
declarations.emit_line('#endif')

return file_contents + [('__native{}.c'.format(self.group_suffix),
''.join(emitter.fragments)),
('__native_internal{}.h'.format(self.group_suffix),
''.join(declarations.fragments)),
('__native{}.h'.format(self.group_suffix),
''.join(ext_declarations.fragments)),
]
output_dir = group_dir(self.group_name) if self.group_name else ''
return file_contents + [
(os.path.join(output_dir, '__native{}.c'.format(self.short_group_suffix)),
''.join(emitter.fragments)),
(os.path.join(output_dir, '__native_internal{}.h'.format(self.short_group_suffix)),
''.join(declarations.fragments)),
(os.path.join(output_dir, '__native{}.h'.format(self.short_group_suffix)),
''.join(ext_declarations.fragments)),
]

def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None:
"""Generate the declaration and definition of the group's export struct.
Expand Down Expand Up @@ -683,12 +699,14 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:

emitter.emit_line()
emitter.emit_lines(
'PyMODINIT_FUNC PyInit_{}(void)'.format(shared_lib_name(self.group_name)),
'PyMODINIT_FUNC PyInit_{}(void)'.format(
shared_lib_name(self.group_name).split('.')[-1]),
'{',
('static PyModuleDef def = {{ PyModuleDef_HEAD_INIT, "{}", NULL, -1, NULL, NULL }};'
.format(self.group_name)),
.format(shared_lib_name(self.group_name))),
'int res;',
'PyObject *capsule;',
'PyObject *tmp;',
'static PyObject *module;',
'if (module) {',
'Py_INCREF(module);',
Expand Down Expand Up @@ -733,14 +751,17 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:
)

for group in sorted(self.context.group_deps):
egroup = exported_name(group)
emitter.emit_lines(
'tmp = PyImport_ImportModule("{}"); if (!tmp) goto fail; Py_DECREF(tmp);'.format(
shared_lib_name(group)),
'struct export_table_{} *pexports_{} = PyCapsule_Import("{}.exports", 0);'.format(
group, group, shared_lib_name(group)),
'if (!pexports_{}) {{'.format(group),
egroup, egroup, shared_lib_name(group)),
'if (!pexports_{}) {{'.format(egroup),
'goto fail;',
'}',
'memcpy(&exports_{group}, pexports_{group}, sizeof(exports_{group}));'.format(
group=group),
group=egroup),
'',
)

Expand Down
11 changes: 6 additions & 5 deletions mypyc/test-data/run-multimodule.test
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
-- These test cases compile two modules at a time (native and other.py)

[case testMultiModuleBasic]
from other import g
[case testMultiModulePackage]
from p.other import g
def f(x: int) -> int:
from other import h
from p.other import h
return h(g(x + 1))
[file other.py]
[file p/__init__.py]
[file p/other.py]
def g(x: int) -> int:
return x + 2
def h(x: int) -> int:
return x + 1
[file driver.py]
import native
from native import f
from other import g
from p.other import g
assert f(3) == 7
assert g(2) == 4
try:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) ->
fn = os.path.relpath(fn, test_temp_dir)

if os.path.basename(fn).startswith('other') and fn.endswith('.py'):
name = os.path.basename(fn).split('.')[0]
name = fn.split('.')[0].replace(os.sep, '.')
module_names.append(name)
sources.append(build.BuildSource(fn, name, None))
to_delete.append(fn)
Expand Down