Skip to content

[mypyc] Stop abusing module names for static namespacing #7628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
REG_PREFIX = 'cpy_r_' # type: Final # Registers
STATIC_PREFIX = 'CPyStatic_' # type: Final # Static variables (for literals etc.)
TYPE_PREFIX = 'CPyType_' # type: Final # Type object struct
MODULE_PREFIX = 'CPyModule_' # type: Final # Cached modules
ATTR_PREFIX = '_' # type: Final # Attributes

ENV_ATTR_NAME = '__mypyc_env__' # type: Final
Expand Down
7 changes: 5 additions & 2 deletions mypyc/emitfunc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Code generation for native function bodies."""


from mypyc.common import REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX
from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.emit import Emitter
from mypyc.ops import (
FuncIR, OpVisitor, Goto, Branch, Return, Assign, LoadInt, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
BasicBlock, Value, RType, RTuple, MethodCall, PrimitiveOp,
EmitterInterface, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE,
EmitterInterface, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE,
RaiseStandardError, FuncDecl, ClassIR,
FUNC_STATICMETHOD, FUNC_CLASSMETHOD,
)
Expand Down Expand Up @@ -255,6 +257,7 @@ def visit_set_attr(self, op: SetAttr) -> None:
PREFIX_MAP = {
NAMESPACE_STATIC: STATIC_PREFIX,
NAMESPACE_TYPE: TYPE_PREFIX,
NAMESPACE_MODULE: MODULE_PREFIX,
} # type: Final

def visit_load_static(self, op: LoadStatic) -> None:
Expand Down
20 changes: 11 additions & 9 deletions mypyc/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mypy.options import Options

from mypyc import genops
from mypyc.common import PREFIX, TOP_LEVEL_NAME, INT_PREFIX
from mypyc.common import PREFIX, TOP_LEVEL_NAME, INT_PREFIX, MODULE_PREFIX
from mypyc.emit import EmitterContext, Emitter, HeaderDeclaration
from mypyc.emitfunc import generate_native_function, native_function_header
from mypyc.emitclass import generate_class_type_decl, generate_class
Expand Down Expand Up @@ -145,7 +145,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
self.declare_internal_globals(module_name, emitter)
self.declare_imports(module.imports, emitter)
# Finals must be last (types can depend on declared above)
self.define_finals(module.final_names, emitter)
self.define_finals(module_name, module.final_names, emitter)

for cl in module.classes:
if cl.is_ext_class:
Expand Down Expand Up @@ -192,7 +192,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]:
declarations.emit_lines(*declaration.decl)

for module_name, module in self.modules:
self.declare_finals(module.final_names, declarations)
self.declare_finals(module_name, module.final_names, declarations)
for cl in module.classes:
generate_class_type_decl(cl, emitter, declarations)
for fn in module.functions:
Expand Down Expand Up @@ -456,7 +456,7 @@ def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None:
self.declare_global('PyObject *', static_name)

def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str:
return emitter.static_name('module_internal', module_name)
return emitter.static_name(module_name + '_internal', None, prefix=MODULE_PREFIX)

def declare_module(self, module_name: str, emitter: Emitter) -> None:
# We declare two globals for each module:
Expand All @@ -465,22 +465,24 @@ def declare_module(self, module_name: str, emitter: Emitter) -> None:
# by other modules to refer to it.
internal_static_name = self.module_internal_static_name(module_name, emitter)
self.declare_global('CPyModule *', internal_static_name, initializer='NULL')
static_name = emitter.static_name('module', module_name)
static_name = emitter.static_name(module_name, None, prefix=MODULE_PREFIX)
self.declare_global('CPyModule *', static_name)
self.simple_inits.append((static_name, 'Py_None'))

def declare_imports(self, imps: Iterable[str], emitter: Emitter) -> None:
for imp in imps:
self.declare_module(imp, emitter)

def declare_finals(self, final_names: Iterable[Tuple[str, RType]], emitter: Emitter) -> None:
def declare_finals(
self, module: str, final_names: Iterable[Tuple[str, RType]], emitter: Emitter) -> None:
for name, typ in final_names:
static_name = emitter.static_name(name, 'final')
static_name = emitter.static_name(name, module)
emitter.emit_line('extern {}{};'.format(emitter.ctype_spaced(typ), static_name))

def define_finals(self, final_names: Iterable[Tuple[str, RType]], emitter: Emitter) -> None:
def define_finals(
self, module: str, final_names: Iterable[Tuple[str, RType]], emitter: Emitter) -> None:
for name, typ in final_names:
static_name = emitter.static_name(name, 'final')
static_name = emitter.static_name(name, module)
# Here we rely on the fact that undefined value and error value are always the same
if isinstance(typ, RTuple):
# We need to inline because initializer must be static
Expand Down
28 changes: 16 additions & 12 deletions mypyc/genops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def f(x: int) -> int:
from mypy.visitor import ExpressionVisitor, StatementVisitor
from mypy.checkexpr import map_actuals_to_formals
from mypy.state import strict_optional_set
from mypy.util import split_target

from mypyc.common import (
ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME, TEMP_ATTR_NAME, LAMBDA_NAME,
Expand All @@ -66,7 +67,8 @@ def f(x: int) -> int:
exc_rtuple,
PrimitiveOp, ControlOp, OpDescription, RegisterOp,
is_object_rprimitive, LiteralsMap, FuncSignature, VTableAttr, VTableMethod, VTableEntries,
NAMESPACE_TYPE, RaiseStandardError, LoadErrorValue, NO_TRACEBACK_LINE_NO, FuncDecl,
NAMESPACE_TYPE, NAMESPACE_MODULE,
RaiseStandardError, LoadErrorValue, NO_TRACEBACK_LINE_NO, FuncDecl,
FUNC_NORMAL, FUNC_STATICMETHOD, FUNC_CLASSMETHOD,
RUnion, is_optional_type, optional_value_type, all_concrete_classes
)
Expand Down Expand Up @@ -1387,7 +1389,7 @@ def cache_class_attrs(self, attrs_to_cache: List[Lvalue], cdef: ClassDef) -> Non
for lval in attrs_to_cache:
assert isinstance(lval, NameExpr)
rval = self.py_get_attr(typ, lval.name, cdef.line)
self.init_final_static(lval, rval, cdef.fullname)
self.init_final_static(lval, rval, cdef.name)

def visit_class_def(self, cdef: ClassDef) -> None:
ir = self.mapper.type_to_ir[cdef.info]
Expand Down Expand Up @@ -1451,7 +1453,7 @@ def visit_class_def(self, cdef: ClassDef) -> None:
self.primitive_op(
py_setattr_op, [typ, self.load_static_unicode(lvalue.name), value], stmt.line)
if self.non_function_scope() and stmt.is_final_def:
self.init_final_static(lvalue, value, cdef.fullname)
self.init_final_static(lvalue, value, cdef.name)
elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
# Docstring. Ignore
pass
Expand Down Expand Up @@ -1525,13 +1527,13 @@ def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None

needs_import, out = BasicBlock(), BasicBlock()
first_load = self.add(LoadStatic(object_rprimitive, 'module', id))
first_load = self.load_module(id)
comparison = self.binary_op(first_load, self.none_object(), 'is not', line)
self.add_bool_branch(comparison, out, needs_import)

self.activate_block(needs_import)
value = self.primitive_op(import_op, [self.load_static_unicode(id)], line)
self.add(InitStatic(value, 'module', id))
self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE))
self.goto_and_activate(out)

def visit_import(self, node: Import) -> None:
Expand Down Expand Up @@ -1577,7 +1579,7 @@ def visit_import_from(self, node: ImportFrom) -> None:
id = importlib.util.resolve_name('.' * node.relative + node.id, module_package)

self.gen_import(id, node.line)
module = self.add(LoadStatic(object_rprimitive, 'module', id))
module = self.load_module(id)

# Copy everything into our module's dict.
# Note that we miscompile import from inside of functions here,
Expand Down Expand Up @@ -1730,7 +1732,7 @@ def calculate_arg_defaults(self,
env.lookup(arg.variable).type, arg.line)
if not fn_info.is_nested:
name = fitem.fullname() + '.' + arg.variable.name()
self.add(InitStatic(value, name, 'final'))
self.add(InitStatic(value, name, self.module_name))
else:
assert func_reg is not None
self.add(SetAttr(func_reg, arg.variable.name(), value, arg.line))
Expand Down Expand Up @@ -1758,7 +1760,7 @@ def get_default() -> Value:
elif not self.fn_info.is_nested:
name = fitem.fullname() + '.' + arg.variable.name()
self.final_names.append((name, target.type))
return self.add(LoadStatic(target.type, name, 'final'))
return self.add(LoadStatic(target.type, name, self.module_name))
else:
name = arg.variable.name()
self.fn_info.callable_class.ir.attributes[name] = target.type
Expand Down Expand Up @@ -2042,19 +2044,21 @@ def init_final_static(self, lvalue: Lvalue, rvalue_reg: Value,
assert isinstance(lvalue.node, Var)
if lvalue.node.final_value is None:
if class_name is None:
name = lvalue.fullname
name = lvalue.name
else:
name = '{}.{}'.format(class_name, lvalue.name)
assert name is not None, "Full name not set for variable"
self.final_names.append((name, rvalue_reg.type))
self.add(InitStatic(rvalue_reg, name, 'final'))
self.add(InitStatic(rvalue_reg, name, self.module_name))

def load_final_static(self, fullname: str, typ: RType, line: int,
error_name: Optional[str] = None) -> Value:
if error_name is None:
error_name = fullname
ok_block, error_block = BasicBlock(), BasicBlock()
value = self.add(LoadStatic(typ, fullname, 'final', line=line))
split_name = split_target(self.graph, fullname)
assert split_name is not None
value = self.add(LoadStatic(typ, split_name[1], split_name[0], line=line))
self.add(Branch(value, error_block, ok_block, Branch.IS_ERROR, rare=True))
self.activate_block(error_block)
self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR,
Expand Down Expand Up @@ -5237,7 +5241,7 @@ def load_static_unicode(self, value: str) -> Value:
return self.add(LoadStatic(str_rprimitive, static_symbol, ann=value))

def load_module(self, name: str) -> Value:
return self.add(LoadStatic(object_rprimitive, 'module', name))
return self.add(LoadStatic(object_rprimitive, name, namespace=NAMESPACE_MODULE))

def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value:
module, _, name = fullname.rpartition('.')
Expand Down
1 change: 1 addition & 0 deletions mypyc/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,7 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:

NAMESPACE_STATIC = 'static' # type: Final # Default name space for statics, variables
NAMESPACE_TYPE = 'type' # type: Final # Static namespace for pointers to native type objects
NAMESPACE_MODULE = 'module' # type: Final # Namespace for modules


class LoadStatic(RegisterOp):
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/analysis.test
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ L1:
goto L10
L2:
r1 = error_catch
r2 = builtins.module :: static
r2 = builtins :: module
r3 = unicode_1 :: static ('Exception')
r4 = getattr r2, r3
if is_error(r4) goto L8 (error at lol:4) else goto L3
Expand Down
10 changes: 5 additions & 5 deletions mypyc/test-data/exceptions.test
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def g():
r12, r13 :: None
L0:
L1:
r0 = builtins.module :: static
r0 = builtins :: module
r1 = unicode_1 :: static ('object')
r2 = getattr r0, r1
if is_error(r2) goto L3 (error at g:3) else goto L2
Expand All @@ -205,7 +205,7 @@ L2:
L3:
r4 = error_catch
r5 = unicode_2 :: static ('weeee')
r6 = builtins.module :: static
r6 = builtins :: module
r7 = unicode_3 :: static ('print')
r8 = getattr r6, r7
if is_error(r8) goto L7 (error at g:5) else goto L4
Expand Down Expand Up @@ -266,7 +266,7 @@ def a():
r20 :: str
L0:
L1:
r0 = builtins.module :: static
r0 = builtins :: module
r1 = unicode_1 :: static ('print')
r2 = getattr r0, r1
if is_error(r2) goto L6 (error at a:3) else goto L2
Expand All @@ -293,7 +293,7 @@ L6:
r7 = r11
L7:
r12 = unicode_3 :: static ('goodbye!')
r13 = builtins.module :: static
r13 = builtins :: module
r14 = unicode_1 :: static ('print')
r15 = getattr r13, r14
if is_error(r15) goto L15 (error at a:6) else goto L8
Expand Down Expand Up @@ -500,7 +500,7 @@ L2:
r3 = !r2
if r3 goto L12 else goto L1 :: bool
L3:
r4 = builtins.module :: static
r4 = builtins :: module
r5 = unicode_3 :: static ('print')
r6 = getattr r4, r5
if is_error(r6) goto L13 (error at f:7) else goto L4
Expand Down
Loading