Skip to content

Commit a470df3

Browse files
[mlir][linalg][transform][python] Extend mix-in for Vectorize
Extends the existing mix-in for VectorizeOp with support for the missing unit attributes. Also fixes the unintuitive implementation where `structured.VectorizeOp(target=target, vectorize_padding=False)` still resulted in the creation of the UnitAttr `vectorize_padding`. Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D158726
1 parent fff1830 commit a470df3

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,16 +783,20 @@ def __init__(
783783
self,
784784
target: Union[Operation, Value],
785785
*,
786-
vectorize_padding: Union[bool, BoolAttr] = False,
786+
disable_multi_reduction_to_contract_patterns: bool = False,
787+
disable_transfer_permutation_map_lowering_patterns: bool = False,
788+
vectorize_nd_extract: bool = False,
789+
vectorize_padding: bool = False,
787790
loc=None,
788791
ip=None,
789792
):
790793
pdl_operation_type = pdl.OperationType.get()
791-
if isinstance(vectorize_padding, bool):
792-
vectorize_padding = UnitAttr.get()
793794
super().__init__(
794795
pdl_operation_type,
795796
_get_op_result_or_value(target),
797+
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
798+
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
799+
vectorize_nd_extract=vectorize_nd_extract,
796800
vectorize_padding=vectorize_padding,
797801
loc=loc,
798802
ip=ip,

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,17 +560,49 @@ def testTileToForallMapping():
560560

561561

562562
@run
563-
def testVectorize():
563+
def testVectorizeAllAttrs():
564564
sequence = transform.SequenceOp(
565565
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
566566
)
567567
with InsertionPoint(sequence.body):
568-
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
568+
structured.VectorizeOp(
569+
sequence.bodyTarget,
570+
disable_multi_reduction_to_contract_patterns=True,
571+
disable_transfer_permutation_map_lowering_patterns=True,
572+
vectorize_nd_extract=True,
573+
vectorize_padding=True,
574+
)
575+
transform.YieldOp()
576+
# CHECK-LABEL: TEST: testVectorizeAllAttrs
577+
# CHECK: transform.sequence
578+
# CHECK: = transform.structured.vectorize
579+
# CHECK-SAME: disable_multi_reduction_to_contract_patterns
580+
# CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
581+
# CHECK-SAME: vectorize_nd_extract
582+
# CHECK-SAME: vectorize_padding
583+
584+
585+
@run
586+
def testVectorizeNoAttrs():
587+
sequence = transform.SequenceOp(
588+
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
589+
)
590+
with InsertionPoint(sequence.body):
591+
structured.VectorizeOp(
592+
sequence.bodyTarget,
593+
disable_multi_reduction_to_contract_patterns=False,
594+
disable_transfer_permutation_map_lowering_patterns=False,
595+
vectorize_nd_extract=False,
596+
vectorize_padding=False,
597+
)
569598
transform.YieldOp()
570-
# CHECK-LABEL: TEST: testVectorize
599+
# CHECK-LABEL: TEST: testVectorizeNoAttrs
571600
# CHECK: transform.sequence
572601
# CHECK: = transform.structured.vectorize
573-
# CHECK: {vectorize_padding}
602+
# CHECK-NOT: disable_multi_reduction_to_contract_patterns
603+
# CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
604+
# CHECK-NOT: vectorize_nd_extract
605+
# CHECK-NOT: vectorize_padding
574606

575607

576608
@run

0 commit comments

Comments
 (0)