-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] update type stubs #75099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesSo I definitely didn't mean to get sucked into this (I just wanted to add Patch is 108.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75099.diff 1 Files Affected:
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 2609117dd220be..577222ce79a9ea 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1,15 +1,63 @@
# Originally imported via:
-# stubgen {...} -m mlir._mlir_libs._mlir.ir
+# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir
+# but with the following diff (in order to remove pipes from types,
+# which we won't support until bumping minimum python to 3.10)
+#
+# --------------------- diff begins ------------------------------------
+#
+# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py
+# index 1f755aa..4924927 100644
+# --- a/pybind11_stubgen/printer.py
+# +++ b/pybind11_stubgen/printer.py
+# @@ -283,14 +283,6 @@ class Printer:
+# return split[0] + "..."
+#
+# def print_type(self, type_: ResolvedType) -> str:
+# - if (
+# - str(type_.name) == "typing.Optional"
+# - and type_.parameters is not None
+# - and len(type_.parameters) == 1
+# - ):
+# - return f"{self.print_annotation(type_.parameters[0])} | None"
+# - if str(type_.name) == "typing.Union" and type_.parameters is not None:
+# - return " | ".join(self.print_annotation(p) for p in type_.parameters)
+# if type_.parameters:
+# param_str = (
+# "["
+#
+# --------------------- diff ends ------------------------------------
+#
# Local modifications:
-# * Rewrite references to 'mlir.ir.' to local types
-# * Add __all__ with the following incantation:
-# egrep '^class ' ir.pyi | awk -F ' |:|\\(' '{print " \"" $2 "\","}'
+# * Rewrite references to '' to local types.
+# * Drop `typing.` everywhere (top-level import instead).
+# * List -> List, dict -> Dict, Tuple -> Tuple.
+# * copy-paste Buffer type from
+# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top.
+# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`.
+# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute.
# * Local edits to signatures and types that MyPy did not auto detect (or
# detected incorrectly).
+# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand.
+# * Fill in `Any`s where possible.
+# * black formatting.
+from __future__ import annotations
+
+import abc
+import collections
+import io
from typing import (
- Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple,
- Type as _Type, TypeVar
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type as _Type,
+ TypeVar,
+ Union,
)
from typing import overload
@@ -30,6 +78,8 @@ __all__ = [
"AffineSymbolExpr",
"ArrayAttr",
"ArrayAttributeIterator",
+ "AsmState",
+ "AttrBuilder",
"Attribute",
"BF16Type",
"Block",
@@ -40,29 +90,44 @@ __all__ = [
"BoolAttr",
"ComplexType",
"Context",
+ "DenseBoolArrayAttr",
+ "DenseBoolArrayIterator",
"DenseElementsAttr",
+ "DenseF32ArrayAttr",
+ "DenseF32ArrayIterator",
+ "DenseF64ArrayAttr",
+ "DenseF64ArrayIterator",
"DenseFPElementsAttr",
+ "DenseI16ArrayAttr",
+ "DenseI16ArrayIterator",
+ "DenseI32ArrayAttr",
+ "DenseI32ArrayIterator",
+ "DenseI64ArrayAttr",
+ "DenseI64ArrayIterator",
+ "DenseI8ArrayAttr",
+ "DenseI8ArrayIterator",
"DenseIntElementsAttr",
"DenseResourceElementsAttr",
- "Dialect",
- "DialectDescriptor",
- "Dialects",
"Diagnostic",
"DiagnosticHandler",
"DiagnosticInfo",
"DiagnosticSeverity",
+ "Dialect",
+ "DialectDescriptor",
+ "DialectRegistry",
+ "Dialects",
"DictAttr",
- "Float8E4M3FNType",
- "Float8E5M2Type",
- "Float8E4M3FNUZType",
- "Float8E4M3B11FNUZType",
- "Float8E5M2FNUZType",
"F16Type",
- "FloatTF32Type",
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
+ "Float8E4M3B11FNUZType",
+ "Float8E4M3FNType",
+ "Float8E4M3FNUZType",
+ "Float8E5M2FNUZType",
+ "Float8E5M2Type",
"FloatAttr",
+ "FloatTF32Type",
"FunctionType",
"IndexType",
"InferShapedTypeOpInterface",
@@ -76,15 +141,18 @@ __all__ = [
"Location",
"MemRefType",
"Module",
- "MLIRError",
"NamedAttribute",
"NoneType",
- "OpaqueType",
"OpAttributeMap",
+ "OpOperand",
+ "OpOperandIterator",
"OpOperandList",
"OpResult",
"OpResultList",
+ "OpSuccessors",
"OpView",
+ "OpaqueAttr",
+ "OpaqueType",
"Operation",
"OperationIterator",
"OperationList",
@@ -94,11 +162,14 @@ __all__ = [
"RegionSequence",
"ShapedType",
"ShapedTypeComponents",
+ "StridedLayoutAttr",
"StringAttr",
+ "SymbolRefAttr",
"SymbolTable",
"TupleType",
"Type",
"TypeAttr",
+ "TypeID",
"UnitAttr",
"UnrankedMemRefType",
"UnrankedTensorType",
@@ -108,222 +179,561 @@ __all__ = [
"_OperationBase",
]
-# Base classes: declared first to simplify declarations below.
+if hasattr(collections.abc, "Buffer"):
+ Buffer = collections.abc.Buffer
+else:
+ class Buffer(abc.ABC):
+ pass
+
class _OperationBase:
- def detach_from_parent(self) -> OpView: ...
- def get_asm(self, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> object: ...
- def move_after(self, other: _OperationBase) -> None: ...
- def move_before(self, other: _OperationBase) -> None: ...
- def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ...
- def verify(self) -> bool: ...
@overload
def __eq__(self, arg0: _OperationBase) -> bool: ...
@overload
- def __eq__(self, arg0: object) -> bool: ...
+ def __eq__(self, arg0: _OperationBase) -> bool: ...
def __hash__(self) -> int: ...
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the operation.
+ """
+ def clone(self, ip: InsertionPoint = None) -> OpView: ...
+ def detach_from_parent(self) -> OpView:
+ """
+ Detaches the operation from its parent block.
+ """
+ def erase(self) -> None: ...
+ def get_asm(
+ self,
+ binary: bool = False,
+ large_elements_limit: Optional[int] = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ ) -> Union[io.BytesIO, io.StringIO]:
+ """
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the 'binary'
+ argument.
+ """
+ def move_after(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately after the other operation in its parent block.
+ """
+ def move_before(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately before the other operation in its parent block.
+ """
+ @overload
+ def print(
+ self,
+ state: AsmState,
+ file: Optional[Any] = None,
+ binary: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ state: AsmState capturing the operation numbering and flags.
+ """
+ @overload
+ def print(
+ self,
+ large_elements_limit: Optional[int] = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ file: Optional[Any] = None,
+ binary: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable).
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_Scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ """
+ def verify(self) -> bool:
+ """
+ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
+ """
+ def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None:
+ """
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: The version of bytecode to emit.
+ Returns:
+ The bytecode writer status.
+ """
@property
def _CAPIPtr(self) -> object: ...
@property
def attributes(self) -> OpAttributeMap: ...
@property
- def context(self) -> Context: ...
+ def context(self) -> Context:
+ """
+ Context that owns the Operation
+ """
@property
- def location(self) -> Location: ...
+ def location(self) -> Location:
+ """
+ Returns the source location the operation was defined or derived from.
+ """
@property
def name(self) -> str: ...
@property
def operands(self) -> OpOperandList: ...
@property
- @property
def parent(self) -> Optional[_OperationBase]: ...
+ @property
def regions(self) -> RegionSequence: ...
@property
- def result(self) -> OpResult: ...
+ def result(self) -> OpResult:
+ """
+ Shortcut to get an op result if it has only one (throws an error otherwise).
+ """
@property
- def results(self) -> OpResultList: ...
+ def results(self) -> OpResultList:
+ """
+ Returns the List of Operation results.
+ """
_TOperation = TypeVar("_TOperation", bound=_OperationBase)
-# TODO: Auto-generated. Audit and fix.
class AffineExpr:
- def __init__(self, *args, **kwargs) -> None: ...
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of an expression and a constant.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing an expression by a constant.
+ """
+ @staticmethod
+ def get_constant(
+ value: int, context: Optional[Context] = None
+ ) -> AffineConstantExpr:
+ """
+ Gets a constant affine expression with the given value.
+ """
+ @staticmethod
+ def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr:
+ """
+ Gets an affine expression of a dimension at the given position.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing an expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets an affine expression containing the modulo of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets a semi-affine expression containing the modulo of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr:
+ """
+ Gets an affine expression containing the module of dividingan expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of an expression and a constant.
+ """
+ @staticmethod
+ def get_symbol(
+ position: int, context: Optional[Context] = None
+ ) -> AffineSymbolExpr:
+ """
+ Gets an affine expression of a symbol at the given position.
+ """
def _CAPICreate(self) -> AffineExpr: ...
- def compose(self, arg0) -> AffineExpr: ...
- def dump(self) -> None: ...
- def get_add(self, *args, **kwargs) -> Any: ...
- def get_ceil_div(self, *args, **kwargs) -> Any: ...
- def get_constant(self, *args, **kwargs) -> Any: ...
- def get_dim(self, *args, **kwargs) -> Any: ...
- def get_floor_div(self, *args, **kwargs) -> Any: ...
- def get_mod(self, *args, **kwargs) -> Any: ...
- def get_mul(self, *args, **kwargs) -> Any: ...
- def get_symbol(self, *args, **kwargs) -> Any: ...
- def __add__(self, other) -> Any: ...
+ @overload
+ def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __add__(self, arg0: int) -> AffineAddExpr: ...
@overload
def __eq__(self, arg0: AffineExpr) -> bool: ...
@overload
- def __eq__(self, arg0: object) -> bool: ...
+ def __eq__(self, arg0: Any) -> bool: ...
def __hash__(self) -> int: ...
- def __mod__(self, other) -> Any: ...
- def __mul__(self, other) -> Any: ...
- def __radd__(self, other) -> Any: ...
- def __rmod__(self, other) -> Any: ...
- def __rmul__(self, other) -> Any: ...
- def __rsub__(self, other) -> Any: ...
- def __sub__(self, other) -> Any: ...
+ @overload
+ def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ...
+ @overload
+ def __mod__(self, arg0: int) -> AffineModExpr: ...
+ @overload
+ def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ...
+ @overload
+ def __mul__(self, arg0: int) -> AffineMulExpr: ...
+ def __radd__(self, arg0: int) -> AffineAddExpr: ...
+ def __repr__(self) -> str: ...
+ def __rmod__(self, arg0: int) -> AffineModExpr: ...
+ def __rmul__(self, arg0: int) -> AffineMulExpr: ...
+ def __rsub__(self, arg0: int) -> AffineAddExpr: ...
+ def __str__(self) -> str: ...
+ @overload
+ def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __sub__(self, arg0: int) -> AffineAddExpr: ...
+ def compose(self, arg0: AffineMap) -> AffineExpr: ...
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
@property
def _CAPIPtr(self) -> object: ...
@property
def context(self) -> Context: ...
class Attribute:
- def __init__(self, cast_from_type: Attribute) -> None: ...
- def _CAPICreate(self) -> Attribute: ...
- def dump(self) -> None: ...
- def get_named(self, *args, **kwargs) -> Any: ...
@staticmethod
- def parse(asm: str, context: Optional[Context] = None) -> Any: ...
+ def parse(asm: str, context: Optional[Context] = None) -> Attribute:
+ """
+ Parses an attribute from an assembly form. Raises an MLIRError on failure.
+ """
+ def _CAPICreate(self) -> Attribute: ...
@overload
def __eq__(self, arg0: Attribute) -> bool: ...
@overload
def __eq__(self, arg0: object) -> bool: ...
def __hash__(self) -> int: ...
+ def __init__(self, cast_from_type: Attribute) -> None:
+ """
+ Casts the passed attribute to the generic Attribute
+ """
+ def __repr__(self) -> str: ...
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the Attribute.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def get_named(self, arg0: str) -> NamedAttribute:
+ """
+ Binds a name to the attribute
+ """
+ def maybe_downcast(self) -> Any: ...
@property
def _CAPIPtr(self) -> object: ...
@property
- def context(self) -> Context: ...
+ def context(self) -> Context:
+ """
+ Context that owns the Attribute
+ """
@property
def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
class Type:
- def __init__(self, cast_from_type: Type) -> None: ...
- def _CAPICreate(self) -> Type: ...
- def dump(self) -> None: ...
@staticmethod
- def parse(asm: str, context: Optional[Context] = None) -> Type: ...
+ def parse(asm: str, context: Optional[Context] = None) -> Type:
+ """
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an MLI...
[truncated]
|
77609d4
to
bb25ef6
Compare
bb25ef6
to
910a421
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will take it. Big qol improvement.
The function signatures seem to have changed after llvm/llvm-project#75099
So I definitely didn't mean to get sucked into this (I just wanted to add
StridedLayoutAttr
) but here we are: I regeneratedir.pyi
usingpybind11-stubgen
instead of mypy'sstubgen
. It did a pretty good job! It added quite lot, including ~20 types we weren't generating for before. Of course I had to patch things up and then double check, so this isn't very reproducible (in its current form) but neither is the current/previous iteration.