Skip to content

[mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. #65726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
Expand Down Expand Up @@ -101,7 +100,7 @@ def _dispatch_mixed_values(
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(_get_op_result_or_value(size))
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)

return (dynamic_values, packed_values, static_values)
Expand Down Expand Up @@ -204,9 +203,7 @@ class DecomposeOp:
"""Specialization for DecomposeOp class."""

def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this commit, but since we are in cleanup mode, can we replace all this pdl with transform.any_op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, that answers a question of mine ("whether and why we needed the PDL types"). OK, will do.



class FuseIntoContainingOp:
Expand Down Expand Up @@ -277,9 +274,7 @@ class GeneralizeOp:
"""Specialization for GeneralizeOp class."""

def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)


class InterchangeOp:
Expand All @@ -296,7 +291,7 @@ def __init__(
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
target,
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
Expand Down Expand Up @@ -415,7 +410,7 @@ def match_op_names(
loc=None,
ip=None,
):
...
...

@overload
@classmethod
Expand All @@ -428,7 +423,7 @@ def match_op_names(
loc=None,
ip=None,
):
...
...

@classmethod
def match_op_names(
Expand All @@ -441,20 +436,20 @@ def match_op_names(
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_names
names = names_or_none
result_type = result_type_or_target
target = target_or_names
names = names_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names

if isinstance(names, str):
names = [names]
names = [names]

return cls(
result_type,
_get_op_result_or_value(target),
target,
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
Expand All @@ -479,7 +474,7 @@ def __init__(
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
target,
dimension=dimension,
target_size=target_size,
divisor=divisor,
Expand Down Expand Up @@ -530,9 +525,7 @@ class ScalarizeOp:

def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)


class SplitOp:
Expand All @@ -552,9 +545,7 @@ def __init__(
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = _get_op_result_or_value(split_point)

target = _get_op_result_or_value(target)
dynamic_split_point = split_point

super().__init__(
target.type,
Expand Down Expand Up @@ -626,8 +617,6 @@ def __init__(
)
target = target_or_none

target = _get_op_result_or_value(target)

super().__init__(
target.type,
loop_types,
Expand Down Expand Up @@ -750,7 +739,7 @@ def __init__(
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,
Expand Down