-
Notifications
You must be signed in to change notification settings - Fork 14.4k
A few tweaks to the MLIR .pyi files #110488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The exact commands to repro python3 -m pyupgrade --py310-plus --keep-percent-format mlir/python/mlir/_mlir_libs/**/*.pyi python3 -m ruff check --select=F401 --fix mlir/python/mlir/_mlir_libs/**/*.pyi
@llvm/pr-subscribers-mlir Author: Sergei Lebedev (superbobry) ChangesPatch is 52.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110488.diff 7 Files Affected:
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 93b978c75540f4..42694747e5f24f 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -1,9 +1,8 @@
-from typing import List
globals: "_Globals"
class _Globals:
- dialect_search_modules: List[str]
+ dialect_search_modules: list[str]
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
index 8ec944d191c6ff..d12c6839deabaf 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
@@ -2,7 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional
from mlir.ir import Type, Context
@@ -26,7 +25,7 @@ class AttributeType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(context: Optional[Context] = None) -> AttributeType: ...
+ def get(context: Context | None = None) -> AttributeType: ...
class OperationType(Type):
@@ -34,7 +33,7 @@ class OperationType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(context: Optional[Context] = None) -> OperationType: ...
+ def get(context: Context | None = None) -> OperationType: ...
class RangeType(Type):
@@ -53,7 +52,7 @@ class TypeType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(context: Optional[Context] = None) -> TypeType: ...
+ def get(context: Context | None = None) -> TypeType: ...
class ValueType(Type):
@@ -61,4 +60,4 @@ class ValueType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(context: Optional[Context] = None) -> ValueType: ...
+ def get(context: Context | None = None) -> ValueType: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
index c9c66d52b8c250..a10bc693ba6001 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
@@ -2,7 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import List
from mlir.ir import Type
@@ -94,15 +93,15 @@ class UniformQuantizedPerAxisType(QuantizedType):
@classmethod
def get(cls, flags: int, storage_type: Type, expressed_type: Type,
- scales: List[float], zero_points: List[int], quantized_dimension: int,
+ scales: list[float], zero_points: list[int], quantized_dimension: int,
storage_type_min: int, storage_type_max: int):
...
@property
- def scales(self) -> List[float]: ...
+ def scales(self) -> list[float]: ...
@property
- def zero_points(self) -> List[float]: ...
+ def zero_points(self) -> list[float]: ...
@property
def quantized_dimension(self) -> int: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
index 2a29541734a821..a3f1b09102379f 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
@@ -2,7 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional
from mlir.ir import Type, Context
@@ -12,7 +11,7 @@ class AnyOpType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(context: Optional[Context] = None) -> AnyOpType: ...
+ def get(context: Context | None = None) -> AnyOpType: ...
class OperationType(Type):
@@ -20,7 +19,7 @@ class OperationType(Type):
def isinstance(type: Type) -> bool: ...
@staticmethod
- def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ...
+ def get(operation_name: str, context: Context | None = None) -> OperationType: ...
@property
def operation_name(self) -> str: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 4d5b4cef9d8aa8..41ed84e0467254 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -44,22 +44,9 @@ from __future__ import annotations
import abc
import collections
+from collections.abc import Callable, Sequence
import io
-from typing import (
- Any,
- Callable,
- ClassVar,
- Dict,
- List,
- Optional,
- Sequence,
- Tuple,
- Type as _Type,
- TypeVar,
- Union,
-)
-
-from typing import overload
+from typing import Any, ClassVar, TypeVar, overload
__all__ = [
"AffineAddExpr",
@@ -210,14 +197,14 @@ class _OperationBase:
def get_asm(
self,
binary: bool = False,
- large_elements_limit: Optional[int] = None,
+ large_elements_limit: int | None = None,
enable_debug_info: bool = False,
pretty_debug_info: bool = False,
print_generic_op_form: bool = False,
use_local_scope: bool = False,
assume_verified: bool = False,
skip_regions: bool = False,
- ) -> Union[io.BytesIO, io.StringIO]:
+ ) -> io.BytesIO | io.StringIO:
"""
Gets the assembly form of the operation with all options available.
@@ -242,7 +229,7 @@ class _OperationBase:
def print(
self,
state: AsmState,
- file: Optional[Any] = None,
+ file: Any | None = None,
binary: bool = False,
) -> None:
"""
@@ -256,13 +243,13 @@ class _OperationBase:
@overload
def print(
self,
- large_elements_limit: Optional[int] = None,
+ large_elements_limit: int | None = None,
enable_debug_info: bool = False,
pretty_debug_info: bool = False,
print_generic_op_form: bool = False,
use_local_scope: bool = False,
assume_verified: bool = False,
- file: Optional[Any] = None,
+ file: Any | None = None,
binary: bool = False,
skip_regions: bool = False,
) -> None:
@@ -296,7 +283,7 @@ class _OperationBase:
"""
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
"""
- def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None:
+ def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
"""
Write the bytecode form of the operation to a file like object.
@@ -325,7 +312,7 @@ class _OperationBase:
@property
def operands(self) -> OpOperandList: ...
@property
- def parent(self) -> Optional[_OperationBase]: ...
+ def parent(self) -> _OperationBase | None: ...
@property
def regions(self) -> RegionSequence: ...
@property
@@ -380,13 +367,13 @@ class AffineExpr:
"""
@staticmethod
def get_constant(
- value: int, context: Optional[Context] = None
+ value: int, context: Context | None = None
) -> AffineConstantExpr:
"""
Gets a constant affine expression with the given value.
"""
@staticmethod
- def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr:
+ def get_dim(position: int, context: Context | None = None) -> AffineDimExpr:
"""
Gets an affine expression of a dimension at the given position.
"""
@@ -446,7 +433,7 @@ class AffineExpr:
"""
@staticmethod
def get_symbol(
- position: int, context: Optional[Context] = None
+ position: int, context: Context | None = None
) -> AffineSymbolExpr:
"""
Gets an affine expression of a symbol at the given position.
@@ -489,7 +476,7 @@ class AffineExpr:
class Attribute:
@staticmethod
- def parse(asm: str | bytes, context: Optional[Context] = None) -> Attribute:
+ def parse(asm: str | bytes, context: Context | None = None) -> Attribute:
"""
Parses an attribute from an assembly form. Raises an MLIRError on failure.
"""
@@ -530,7 +517,7 @@ class Attribute:
class Type:
@staticmethod
- def parse(asm: str | bytes, context: Optional[Context] = None) -> Type:
+ def parse(asm: str | bytes, context: Context | None = None) -> Type:
"""
Parses the assembly form of a type.
@@ -640,7 +627,7 @@ class AffineCeilDivExpr(AffineBinaryExpr):
class AffineConstantExpr(AffineExpr):
@staticmethod
- def get(value: int, context: Optional[Context] = None) -> AffineConstantExpr: ...
+ def get(value: int, context: Context | None = None) -> AffineConstantExpr: ...
@staticmethod
def isinstance(other: AffineExpr) -> bool: ...
def __init__(self, expr: AffineExpr) -> None: ...
@@ -649,7 +636,7 @@ class AffineConstantExpr(AffineExpr):
class AffineDimExpr(AffineExpr):
@staticmethod
- def get(position: int, context: Optional[Context] = None) -> AffineDimExpr: ...
+ def get(position: int, context: Context | None = None) -> AffineDimExpr: ...
@staticmethod
def isinstance(other: AffineExpr) -> bool: ...
def __init__(self, expr: AffineExpr) -> None: ...
@@ -657,7 +644,7 @@ class AffineDimExpr(AffineExpr):
def position(self) -> int: ...
class AffineExprList:
- def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ...
+ def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ...
class AffineFloorDivExpr(AffineBinaryExpr):
@staticmethod
@@ -669,43 +656,43 @@ class AffineFloorDivExpr(AffineBinaryExpr):
class AffineMap:
@staticmethod
def compress_unused_symbols(
- arg0: List, arg1: Optional[Context]
- ) -> List[AffineMap]: ...
+ arg0: list, arg1: Context | None
+ ) -> list[AffineMap]: ...
@staticmethod
def get(
dim_count: int,
symbol_count: int,
- exprs: List,
- context: Optional[Context] = None,
+ exprs: list,
+ context: Context | None = None,
) -> AffineMap:
"""
Gets a map with the given expressions as results.
"""
@staticmethod
- def get_constant(value: int, context: Optional[Context] = None) -> AffineMap:
+ def get_constant(value: int, context: Context | None = None) -> AffineMap:
"""
Gets an affine map with a single constant result
"""
@staticmethod
- def get_empty(context: Optional[Context] = None) -> AffineMap:
+ def get_empty(context: Context | None = None) -> AffineMap:
"""
Gets an empty affine map.
"""
@staticmethod
- def get_identity(n_dims: int, context: Optional[Context] = None) -> AffineMap:
+ def get_identity(n_dims: int, context: Context | None = None) -> AffineMap:
"""
Gets an identity map with the given number of dimensions.
"""
@staticmethod
def get_minor_identity(
- n_dims: int, n_results: int, context: Optional[Context] = None
+ n_dims: int, n_results: int, context: Context | None = None
) -> AffineMap:
"""
Gets a minor identity map with the given number of dimensions and results.
"""
@staticmethod
def get_permutation(
- permutation: List[int], context: Optional[Context] = None
+ permutation: list[int], context: Context | None = None
) -> AffineMap:
"""
Gets an affine map that permutes its inputs.
@@ -722,7 +709,7 @@ class AffineMap:
"""
def get_major_submap(self, n_results: int) -> AffineMap: ...
def get_minor_submap(self, n_results: int) -> AffineMap: ...
- def get_submap(self, result_positions: List[int]) -> AffineMap: ...
+ def get_submap(self, result_positions: list[int]) -> AffineMap: ...
def replace(
self,
expr: AffineExpr,
@@ -748,7 +735,7 @@ class AffineMap:
@property
def n_symbols(self) -> int: ...
@property
- def results(self) -> "AffineMapExprList": ...
+ def results(self) -> AffineMapExprList: ...
class AffineMapAttr(Attribute):
static_typeid: ClassVar[TypeID]
@@ -781,7 +768,7 @@ class AffineMulExpr(AffineBinaryExpr):
class AffineSymbolExpr(AffineExpr):
@staticmethod
- def get(position: int, context: Optional[Context] = None) -> AffineSymbolExpr: ...
+ def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ...
@staticmethod
def isinstance(other: AffineExpr) -> bool: ...
def __init__(self, expr: AffineExpr) -> None: ...
@@ -791,13 +778,13 @@ class AffineSymbolExpr(AffineExpr):
class ArrayAttr(Attribute):
static_typeid: ClassVar[TypeID]
@staticmethod
- def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr:
+ def get(attributes: list, context: Context | None = None) -> ArrayAttr:
"""
Gets a uniqued Array attribute
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
- def __add__(self, arg0: List) -> ArrayAttr: ...
+ def __add__(self, arg0: list) -> ArrayAttr: ...
def __getitem__(self, arg0: int) -> Attribute: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
def __iter__(
@@ -835,7 +822,7 @@ class AttrBuilder:
class BF16Type(Type):
static_typeid: ClassVar[TypeID]
@staticmethod
- def get(context: Optional[Context] = None) -> BF16Type:
+ def get(context: Context | None = None) -> BF16Type:
"""
Create a bf16 type.
"""
@@ -849,8 +836,8 @@ class Block:
@staticmethod
def create_at_start(
parent: Region,
- arg_types: List[Type],
- arg_locs: Optional[Sequence] = None,
+ arg_types: list[Type],
+ arg_locs: Sequence | None = None,
) -> Block:
"""
Creates and returns a new Block at the beginning of the given region (with given argument types and locations).
@@ -876,11 +863,11 @@ class Block:
"""
Append this block to a region, transferring ownership if necessary
"""
- def create_after(self, *args, arg_locs: Optional[Sequence] = None) -> Block:
+ def create_after(self, *args, arg_locs: Sequence | None = None) -> Block:
"""
Creates and returns a new Block after this block (with given argument types and locations).
"""
- def create_before(self, *args, arg_locs: Optional[Sequence] = None) -> Block:
+ def create_before(self, *args, arg_locs: Sequence | None = None) -> Block:
"""
Creates and returns a new Block before this block (with given argument types and locations).
"""
@@ -924,9 +911,9 @@ class BlockArgumentList:
@overload
def __getitem__(self, arg0: slice) -> BlockArgumentList: ...
def __len__(self) -> int: ...
- def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ...
+ def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ...
@property
- def types(self) -> List[Type]: ...
+ def types(self) -> list[Type]: ...
class BlockIterator:
def __iter__(self) -> BlockIterator: ...
@@ -936,7 +923,7 @@ class BlockList:
def __getitem__(self, arg0: int) -> Block: ...
def __iter__(self) -> BlockIterator: ...
def __len__(self) -> int: ...
- def append(self, *args, arg_locs: Optional[Sequence] = None) -> Block:
+ def append(self, *args, arg_locs: Sequence | None = None) -> Block:
"""
Appends a new block, with argument types as positional args.
@@ -946,7 +933,7 @@ class BlockList:
class BoolAttr(Attribute):
@staticmethod
- def get(value: bool, context: Optional[Context] = None) -> BoolAttr:
+ def get(value: bool, context: Context | None = None) -> BoolAttr:
"""
Gets an uniqued bool attribute
"""
@@ -1000,7 +987,7 @@ class Context:
def _get_context_again(self) -> Context: ...
def _get_live_module_count(self) -> int: ...
def _get_live_operation_count(self) -> int: ...
- def _get_live_operation_objects(self) -> List[Operation]: ...
+ def _get_live_operation_objects(self) -> list[Operation]: ...
def append_dialect_registry(self, registry: DialectRegistry) -> None: ...
def attach_diagnostic_handler(
self, callback: Callable[[Diagnostic], bool]
@@ -1031,14 +1018,14 @@ class Context:
class DenseBoolArrayAttr(Attribute):
@staticmethod
def get(
- values: Sequence[bool], context: Optional[Context] = None
+ values: Sequence[bool], context: Context | None = None
) -> DenseBoolArrayAttr:
"""
Gets a uniqued dense array attribute
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
- def __add__(self, arg0: List) -> DenseBoolArrayAttr: ...
+ def __add__(self, arg0: list) -> DenseBoolArrayAttr: ...
def __getitem__(self, arg0: int) -> bool: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
def __iter__(
@@ -1061,9 +1048,9 @@ class DenseElementsAttr(Attribute):
def get(
array: Buffer,
signless: bool = True,
- type: Optional[Type] = None,
- shape: Optional[List[int]] = None,
- context: Optional[Context] = None,
+ type: Type | None = None,
+ shape: list[int] | None = None,
+ context: Context | None = None,
) -> DenseElementsAttr:
"""
Gets a DenseElementsAttr from a Python buffer or array.
@@ -1128,14 +1115,14 @@ class DenseElementsAttr(Attribute):
class DenseF32ArrayAttr(Attribute):
@staticmethod
def get(
- values: Sequence[float], context: Optional[Context] = None
+ values: Sequence[float], context: Context | None = None
) -> DenseF32ArrayAttr:
"""
Gets a uniqued dense array attribute
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
- def __add__(self, arg0: List) -> DenseF32ArrayAttr: ...
+ def __add__(self, arg0: list) -> DenseF32ArrayAttr: ...
def __getitem__(self, arg0: int) -> float: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
def __iter__(
@@ -1156,14 +1143,14 @@ class DenseF32ArrayIterator:
class DenseF64ArrayAttr(Attribute):
@staticmethod
def get(
- values: Sequence[float], context: Optional[Context] = None
+ values: Sequence[float], context: Context | None = None
) -> DenseF64ArrayAttr:
"""
Gets a uniqued dense array attribute
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
- def __add__(self, arg0: List) -> DenseF64ArrayAttr: ...
+ def __add__(self, arg0: list) -> DenseF64ArrayAttr: ...
def __getitem__(self, arg0: int) -> float: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
def __iter__(
@@ -1186,9 +1173,9 @@ class DenseFPElementsAttr(DenseElementsAttr):
def get(
array: Buffer,
signless: bool = True,
- type: Optional[Type] = None,
- shape: Optional[List[int]] = None,
- context: Optional[Context] = None,
+ type: Type | None = None,
+ shape: list[int] | None = None,
+ context: Context | None = None,
) -> DenseFPElementsAttr: ...
@staticmethod
def isinstance(other: Attribute) -> bool: ...
@@ -1203,13 +1190,13 @@ class DenseFPElementsAttr(DenseElementsAttr):
class DenseI16ArrayAttr(Attribute):
@staticmethod
- def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI16ArrayAttr:
+ def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr:
"""
Gets a uniqued dense array attribute
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
- def __add__(self, arg0: List) -> DenseI16ArrayAttr: ...
+ def __add__(self, arg0: list) -> DenseI16ArrayAttr: ...
def __getitem__(self, arg0: int) -> int: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
...
[truncated]
|
@@ -26,15 +25,15 @@ class AttributeType(Type): | |||
def isinstance(type: Type) -> bool: ... | |||
|
|||
@staticmethod | |||
def get(context: Optional[Context] = None) -> AttributeType: ... | |||
def get(context: Context | None = None) -> AttributeType: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this syntax requires 3.10? 3.11? ie would break 3.8, 3.9.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will break if the type checker enforces a specific Python version on type stubs. Not sure if that happens in practice.
Note also that this syntax is already used in ir.pyi
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason why MLIR wants to target 3.8+ for its Python bindings? 3.8 is EOL in a months and 3.9 is only receiving security updates before EOLing next year.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Traditionally, llvm debates such things. I can't remember the specifics of the last time, but generally we've been tracking the eol schedule. Dropping 3.8 on those grounds SGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also support type stubs only being supported for active versions so long as it doesn't impede use of the library on older versions. I think it is completely reasonable for development efficiency features to not target discontinued versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type stubs are not used at all at runtime, so as long as type checkers are happy with the new syntax, it should be fine to use it, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will break if the type checker enforces a specific Python version on type stubs.
That's not correct - it will raise a TypeError
in just vanilla python:
$ conda create -n throwaway python=3.8
$ conda activate throwaway
(throwaway) $ python -c "def fun(a: int | float): pass"
Traceback (most recent call last):
File "<string>", line 1, in <module>
TypeError: unsupported operand type(s) for |: 'type' and 'type'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have a strong opinion on moving the support window forward - I'm just verifying.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah sorry I didn't realize this was in a pyi file. In which case I think you're correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks (sorry again for the initially cursory reading)
Thanks @makslevental, can you merge the PR please? I don't have write access. |
No description provided.