Skip to content

[mlir][python] simplify extensions #69642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions mlir/python/mlir/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect

try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class AffineStoreOp(AffineStoreOp):
"""Specialization for the Affine store operation."""

def __init__(
self,
value: Union[Operation, OpView, Value],
memref: Union[Operation, OpView, Value],
map: AffineMap = None,
*,
map_operands=None,
loc=None,
ip=None,
):
"""Creates an affine store operation.

- `value`: the value to store into the memref.
- `memref`: the buffer to store into.
- `map`: the affine map that maps the map_operands to the index of the
memref.
- `map_operands`: the list of arguments to substitute the dimensions,
then symbols in the affine map, in increasing order.
"""
map = map if map is not None else []
map_operands = map_operands if map_operands is not None else []
indicies = [_get_op_result_or_value(op) for op in map_operands]
_ods_successors = None
super().__init__(
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
)
36 changes: 0 additions & 36 deletions mlir/python/mlir/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._bufferization_ops_gen import *
from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *

try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context, _cext as _ods_cext

from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e


@_ods_cext.register_operation(_Dialect, replace=True)
class AllocTensorOp(AllocTensorOp):
"""Extends the bufferization.alloc_tensor op."""

def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None,
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
super().__init__(
tensor_type,
dynamic_sizes,
copy=copy,
size_hint=size_hint,
loc=loc,
ip=ip,
)
3 changes: 0 additions & 3 deletions mlir/python/mlir/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""

def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)

@property
def type(self):
return self.results[0].type
Expand Down
38 changes: 0 additions & 38 deletions mlir/python/mlir/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._memref_ops_gen import *
from ._memref_ops_gen import _Dialect

try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class LoadOp(LoadOp):
"""Specialization for the MemRef load operation."""

def __init__(
self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
"""Creates a memref load operation.

Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)
69 changes: 0 additions & 69 deletions mlir/python/mlir/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,6 @@
)


@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
"""Specialization for PDL apply native constraint op class."""

def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
"""Specialization for PDL apply native rewrite op class."""

def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
Expand All @@ -75,21 +38,6 @@ def __init__(
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class EraseOp(EraseOp):
"""Specialization for PDL erase op class."""

def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
Expand Down Expand Up @@ -216,23 +164,6 @@ def __init__(
super().__init__(result, parent, index, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ResultsOp(ResultsOp):
"""Specialization for PDL results op class."""

def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""
Expand Down
33 changes: 8 additions & 25 deletions mlir/python/mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
from typing import Optional, Sequence, Union


_ForOp = ForOp


@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(_ForOp):
class ForOp(ForOp):
"""Specialization for the SCF for op class."""

def __init__(
Expand All @@ -50,17 +47,8 @@ def __init__(
iter_args = _get_op_results_or_values(iter_args)

results = [arg.type for arg in iter_args]
super(_ForOp, self).__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
]
+ list(iter_args),
loc=loc,
ip=ip,
)
super().__init__(
results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
)
self.regions[0].blocks.append(self.operands[0].type, *results)

Expand All @@ -83,28 +71,23 @@ def inner_iter_args(self):
return self.body.arguments[1:]


_IfOp = IfOp


@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(_IfOp):
class IfOp(IfOp):
"""Specialization for the SCF if op class."""

def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.

- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
if results_ is None:
results_ = []
operands = []
operands.append(cond)
results = []
results.extend(results_)
super(_IfOp, self).__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
)
super().__init__(results, cond)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/python/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def affine_store_test(arg0):
a1 = arith.ConstantOp(f32, 2.1)

# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


return mem

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/python/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def testFunctionCalls():
qux = func.FuncOp("qux", ([], [F32Type.get()]))
qux.sym_visibility = StringAttr.get("private")

con = func.ConstantOp(qux.type, qux.sym_name.value)
assert con.type == qux.type

with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
func.CallOp(foo, [])
func.CallOp([IndexType.get()], "bar", [])
Expand All @@ -94,6 +97,7 @@ def testFunctionCalls():
# CHECK: func private @foo()
# CHECK: func private @bar() -> index
# CHECK: func private @qux() -> f32
# CHECK: %f = func.constant @qux : () -> f32
# CHECK: func @caller() {
# CHECK: call @foo() : () -> ()
# CHECK: %0 = call @bar() : () -> index
Expand Down