Skip to content

[mlir][python] simplify extensions #69642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2023

Conversation

makslevental
Copy link
Contributor

#68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests.

@@ -37,7 +37,7 @@ def affine_store_test(arg0):
a1 = arith.ConstantOp(f32, 2.1)

# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Oct 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

#68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests.


Full diff: https://github.com/llvm/llvm-project/pull/69642.diff

8 Files Affected:

  • (modified) mlir/python/mlir/dialects/affine.py (-45)
  • (modified) mlir/python/mlir/dialects/bufferization.py (-36)
  • (modified) mlir/python/mlir/dialects/func.py (-3)
  • (modified) mlir/python/mlir/dialects/memref.py (-38)
  • (modified) mlir/python/mlir/dialects/pdl.py (-69)
  • (modified) mlir/python/mlir/dialects/scf.py (+8-25)
  • (modified) mlir/test/python/dialects/affine.py (+1-1)
  • (modified) mlir/test/python/dialects/func.py (+4)
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 1eaccfa73a85cbf..80d3873e19a05cb 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,48 +3,3 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect
-
-try:
-    from ..ir import *
-    from ._ods_common import (
-        get_op_result_or_value as _get_op_result_or_value,
-        get_op_results_or_values as _get_op_results_or_values,
-        _cext as _ods_cext,
-    )
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AffineStoreOp(AffineStoreOp):
-    """Specialization for the Affine store operation."""
-
-    def __init__(
-        self,
-        value: Union[Operation, OpView, Value],
-        memref: Union[Operation, OpView, Value],
-        map: AffineMap = None,
-        *,
-        map_operands=None,
-        loc=None,
-        ip=None,
-    ):
-        """Creates an affine store operation.
-
-        - `value`: the value to store into the memref.
-        - `memref`: the buffer to store into.
-        - `map`: the affine map that maps the map_operands to the index of the
-          memref.
-        - `map_operands`: the list of arguments to substitute the dimensions,
-          then symbols in the affine map, in increasing order.
-        """
-        map = map if map is not None else []
-        map_operands = map_operands if map_operands is not None else []
-        indicies = [_get_op_result_or_value(op) for op in map_operands]
-        _ods_successors = None
-        super().__init__(
-            value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
-        )
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 0ce5448ace4b14c..759b6aa24a9ff73 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,40 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._bufferization_ops_gen import *
-from ._bufferization_ops_gen import _Dialect
 from ._bufferization_enum_gen import *
-
-try:
-    from typing import Sequence, Union
-    from ..ir import *
-    from ._ods_common import get_default_loc_context, _cext as _ods_cext
-
-    from typing import Any, List, Union
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AllocTensorOp(AllocTensorOp):
-    """Extends the bufferization.alloc_tensor op."""
-
-    def __init__(
-        self,
-        tensor_type: Type,
-        dynamic_sizes: Sequence[Value],
-        copy: Value,
-        size_hint: Value,
-        escape: BoolAttr,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-        super().__init__(
-            tensor_type,
-            dynamic_sizes,
-            copy=copy,
-            size_hint=size_hint,
-            loc=loc,
-            ip=ip,
-        )
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 9c6c4c9092c7a88..6599f67b7078777 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -26,9 +26,6 @@
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
-    def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-        super().__init__(result, value, loc=loc, ip=ip)
-
     @property
     def type(self):
         return self.results[0].type
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 111ad2178703d28..3afb6a70cb9e0db 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,41 +3,3 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._memref_ops_gen import *
-from ._memref_ops_gen import _Dialect
-
-try:
-    from ..ir import *
-    from ._ods_common import (
-        get_op_result_or_value as _get_op_result_or_value,
-        get_op_results_or_values as _get_op_results_or_values,
-        _cext as _ods_cext,
-    )
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class LoadOp(LoadOp):
-    """Specialization for the MemRef load operation."""
-
-    def __init__(
-        self,
-        memref: Union[Operation, OpView, Value],
-        indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        """Creates a memref load operation.
-
-        Args:
-          memref: the buffer to load from.
-          indices: the list of subscripts, may be empty for zero-dimensional
-            buffers.
-          loc: user-visible location of the operation.
-          ip: insertion point.
-        """
-        indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
-        super().__init__(memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index a8d9c56f4233d9e..90d7d706238e649 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -21,43 +21,6 @@
 )
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
-    """Specialization for PDL apply native constraint op class."""
-
-    def __init__(
-        self,
-        name: Union[str, StringAttr],
-        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        if args is None:
-            args = []
-        args = _get_values(args)
-        super().__init__(name, args, loc=loc, ip=ip)
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
-    """Specialization for PDL apply native rewrite op class."""
-
-    def __init__(
-        self,
-        results: Sequence[Type],
-        name: Union[str, StringAttr],
-        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        if args is None:
-            args = []
-        args = _get_values(args)
-        super().__init__(results, name, args, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class AttributeOp(AttributeOp):
     """Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ def __init__(
         super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class EraseOp(EraseOp):
-    """Specialization for PDL erase op class."""
-
-    def __init__(
-        self,
-        operation: Optional[Union[OpView, Operation, Value]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        operation = _get_value(operation)
-        super().__init__(operation, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class OperandOp(OperandOp):
     """Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ def __init__(
         super().__init__(result, parent, index, loc=loc, ip=ip)
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ResultsOp(ResultsOp):
-    """Specialization for PDL results op class."""
-
-    def __init__(
-        self,
-        result: Type,
-        parent: Union[OpView, Operation, Value],
-        index: Optional[Union[IntegerAttr, int]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        parent = _get_value(parent)
-        super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class RewriteOp(RewriteOp):
     """Specialization for PDL rewrite op class."""
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 43ad9f4e2d65f51..71c80cab76dfb86 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -20,11 +20,8 @@
 from typing import Optional, Sequence, Union
 
 
-_ForOp = ForOp
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
-class ForOp(_ForOp):
+class ForOp(ForOp):
     """Specialization for the SCF for op class."""
 
     def __init__(
@@ -50,17 +47,8 @@ def __init__(
         iter_args = _get_op_results_or_values(iter_args)
 
         results = [arg.type for arg in iter_args]
-        super(_ForOp, self).__init__(
-            self.build_generic(
-                regions=1,
-                results=results,
-                operands=[
-                    _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
-                ]
-                + list(iter_args),
-                loc=loc,
-                ip=ip,
-            )
+        super().__init__(
+            results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
         )
         self.regions[0].blocks.append(self.operands[0].type, *results)
 
@@ -83,28 +71,23 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
-_IfOp = IfOp
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
-class IfOp(_IfOp):
+class IfOp(IfOp):
     """Specialization for the SCF if op class."""
 
-    def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
         """Creates an SCF `if` operation.
 
         - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
         - `hasElse` determines whether the if operation has the else branch.
         """
+        if results_ is None:
+            results_ = []
         operands = []
         operands.append(cond)
         results = []
         results.extend(results_)
-        super(_IfOp, self).__init__(
-            self.build_generic(
-                regions=2, results=results, operands=operands, loc=loc, ip=ip
-            )
-        )
+        super().__init__(results, cond)
         self.regions[0].blocks.append(*[])
         if hasElse:
             self.regions[1].blocks.append(*[])
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index d2e664d4653420f..c5ec85457493b42 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -37,7 +37,7 @@ def affine_store_test(arg0):
                 a1 = arith.ConstantOp(f32, 2.1)
 
                 # CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
-                affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
+                affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
 
                 return mem
 
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 161a12d78776a0e..a2014c64d2fa53b 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -84,6 +84,9 @@ def testFunctionCalls():
     qux = func.FuncOp("qux", ([], [F32Type.get()]))
     qux.sym_visibility = StringAttr.get("private")
 
+    con = func.ConstantOp(qux.type, qux.sym_name.value)
+    assert con.type == qux.type
+
     with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
         func.CallOp(foo, [])
         func.CallOp([IndexType.get()], "bar", [])
@@ -94,6 +97,7 @@ def testFunctionCalls():
 # CHECK: func private @foo()
 # CHECK: func private @bar() -> index
 # CHECK: func private @qux() -> f32
+# CHECK: %f = func.constant @qux : () -> f32
 # CHECK: func @caller() {
 # CHECK:   call @foo() : () -> ()
 # CHECK:   %0 = call @bar() : () -> index

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants