-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] simplify extensions #69642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
makslevental
commented
Oct 19, 2023
@@ -37,7 +37,7 @@ def affine_store_test(arg0): | |||
a1 = arith.ConstantOp(f32, 2.1) | |||
|
|||
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32> | |||
affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0]) | |||
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @kaitingwang
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changes#68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests. Full diff: https://github.com/llvm/llvm-project/pull/69642.diff 8 Files Affected:
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 1eaccfa73a85cbf..80d3873e19a05cb 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,48 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- _cext as _ods_cext,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AffineStoreOp(AffineStoreOp):
- """Specialization for the Affine store operation."""
-
- def __init__(
- self,
- value: Union[Operation, OpView, Value],
- memref: Union[Operation, OpView, Value],
- map: AffineMap = None,
- *,
- map_operands=None,
- loc=None,
- ip=None,
- ):
- """Creates an affine store operation.
-
- - `value`: the value to store into the memref.
- - `memref`: the buffer to store into.
- - `map`: the affine map that maps the map_operands to the index of the
- memref.
- - `map_operands`: the list of arguments to substitute the dimensions,
- then symbols in the affine map, in increasing order.
- """
- map = map if map is not None else []
- map_operands = map_operands if map_operands is not None else []
- indicies = [_get_op_result_or_value(op) for op in map_operands]
- _ods_successors = None
- super().__init__(
- value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
- )
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 0ce5448ace4b14c..759b6aa24a9ff73 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,40 +3,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
-from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
-
-try:
- from typing import Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context, _cext as _ods_cext
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AllocTensorOp(AllocTensorOp):
- """Extends the bufferization.alloc_tensor op."""
-
- def __init__(
- self,
- tensor_type: Type,
- dynamic_sizes: Sequence[Value],
- copy: Value,
- size_hint: Value,
- escape: BoolAttr,
- *,
- loc=None,
- ip=None,
- ):
- """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- super().__init__(
- tensor_type,
- dynamic_sizes,
- copy=copy,
- size_hint=size_hint,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 9c6c4c9092c7a88..6599f67b7078777 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -26,9 +26,6 @@
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
- def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
-
@property
def type(self):
return self.results[0].type
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 111ad2178703d28..3afb6a70cb9e0db 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,41 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
-from ._memref_ops_gen import _Dialect
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- _cext as _ods_cext,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class LoadOp(LoadOp):
- """Specialization for the MemRef load operation."""
-
- def __init__(
- self,
- memref: Union[Operation, OpView, Value],
- indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates a memref load operation.
-
- Args:
- memref: the buffer to load from.
- indices: the list of subscripts, may be empty for zero-dimensional
- buffers.
- loc: user-visible location of the operation.
- ip: insertion point.
- """
- indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
- super().__init__(memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index a8d9c56f4233d9e..90d7d706238e649 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -21,43 +21,6 @@
)
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
- """Specialization for PDL apply native constraint op class."""
-
- def __init__(
- self,
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(name, args, loc=loc, ip=ip)
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
- """Specialization for PDL apply native rewrite op class."""
-
- def __init__(
- self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ def __init__(
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
-@_ods_cext.register_operation(_Dialect, replace=True)
-class EraseOp(EraseOp):
- """Specialization for PDL erase op class."""
-
- def __init__(
- self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- operation = _get_value(operation)
- super().__init__(operation, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ def __init__(
super().__init__(result, parent, index, loc=loc, ip=ip)
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ResultsOp(ResultsOp):
- """Specialization for PDL results op class."""
-
- def __init__(
- self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 43ad9f4e2d65f51..71c80cab76dfb86 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -20,11 +20,8 @@
from typing import Optional, Sequence, Union
-_ForOp = ForOp
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
-class ForOp(_ForOp):
+class ForOp(ForOp):
"""Specialization for the SCF for op class."""
def __init__(
@@ -50,17 +47,8 @@ def __init__(
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
- super(_ForOp, self).__init__(
- self.build_generic(
- regions=1,
- results=results,
- operands=[
- _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
- ]
- + list(iter_args),
- loc=loc,
- ip=ip,
- )
+ super().__init__(
+ results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
)
self.regions[0].blocks.append(self.operands[0].type, *results)
@@ -83,28 +71,23 @@ def inner_iter_args(self):
return self.body.arguments[1:]
-_IfOp = IfOp
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
-class IfOp(_IfOp):
+class IfOp(IfOp):
"""Specialization for the SCF if op class."""
- def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+ def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
+ if results_ is None:
+ results_ = []
operands = []
operands.append(cond)
results = []
results.extend(results_)
- super(_IfOp, self).__init__(
- self.build_generic(
- regions=2, results=results, operands=operands, loc=loc, ip=ip
- )
- )
+ super().__init__(results, cond)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index d2e664d4653420f..c5ec85457493b42 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -37,7 +37,7 @@ def affine_store_test(arg0):
a1 = arith.ConstantOp(f32, 2.1)
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
- affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
+ affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
return mem
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 161a12d78776a0e..a2014c64d2fa53b 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -84,6 +84,9 @@ def testFunctionCalls():
qux = func.FuncOp("qux", ([], [F32Type.get()]))
qux.sym_visibility = StringAttr.get("private")
+ con = func.ConstantOp(qux.type, qux.sym_name.value)
+ assert con.type == qux.type
+
with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
func.CallOp(foo, [])
func.CallOp([IndexType.get()], "bar", [])
@@ -94,6 +97,7 @@ def testFunctionCalls():
# CHECK: func private @foo()
# CHECK: func private @bar() -> index
# CHECK: func private @qux() -> f32
+# CHECK: %f = func.constant @qux : () -> f32
# CHECK: func @caller() {
# CHECK: call @foo() : () -> ()
# CHECK: %0 = call @bar() : () -> index
|
ftynse
approved these changes
Oct 19, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
#68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests.