Skip to content

Commit a8645a3

Browse files
[mlir][Linalg] Post submit addressed comments missed in f0cdc5bcd3f25192f12bfaff072ce02497b59c3c
Differential Revision: https://reviews.llvm.org/D133936
1 parent 24c10ab commit a8645a3

File tree

5 files changed

+43
-36
lines changed

5 files changed

+43
-36
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,17 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
767767
Note that this transformation is invalidating the handles to any payload IR
768768
operation that is contained inside the vectorization target.
769769

770-
`disable_multi_reduction_to_contract_patterns` and
771-
`disable_transfer_permutation_map_lowering_patterns` limits the power of
772-
vectorization. They are currently intended for testing purposes.
770+
This transformation supports the following attributes:
771+
- `vectorize_padding`: a UnitAttr to activate the vectorization of
772+
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
773+
loops.
774+
- `disable_multi_reduction_to_contract_patterns`: a UnitAttr to deactivate
775+
the rewrite of `vector.multi_reduction` to `vector.contract`. This is
776+
intended to be used in tests only.
777+
- `disable_transfer_permutation_map_lowering_patterns`: a UnitAttr to
778+
deactivate the rewrite of `vector.transfer` with permutation maps into
779+
explicit `vector.transpose` operations. This is intended to be used in
780+
tests only but may be promotoed to a first class attribute in the future.
773781

774782
#### Return modes:
775783

@@ -780,13 +788,12 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
780788
}];
781789

782790
let arguments = (ins PDL_Operation:$target,
783-
DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
784-
DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,
785-
DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
791+
UnitAttr:$vectorize_padding,
792+
UnitAttr:$disable_multi_reduction_to_contract_patterns,
793+
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
786794
let results = (outs PDL_Operation:$transformed);
787795

788796
let assemblyFormat = "$target attr-dict";
789-
790797
let extraClassDeclaration = [{
791798
::mlir::DiagnosedSilenceableFailure applyToOne(
792799
::mlir::Operation *target,

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def __init__(self,
287287
ip=None):
288288
pdl_operation_type = pdl.OperationType.get()
289289
if isinstance(vectorize_padding, bool):
290-
vectorize_padding = BoolAttr.get(vectorize_padding)
290+
vectorize_padding = UnitAttr.get()
291291
super().__init__(
292292
pdl_operation_type,
293293
_get_op_result_or_value(target),

mlir/test/Dialect/Linalg/transform-op-vectorize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ transform.with_pdl_patterns {
130130
^bb1(%arg1: !pdl.operation):
131131
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
132132
%1 = get_closest_isolated_parent %0
133-
%2 = transform.structured.vectorize %1 {vectorize_padding = true}
133+
%2 = transform.structured.vectorize %1 {vectorize_padding}
134134
}
135135
}
136136

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ transform.with_pdl_patterns {
1818
^bb1(%arg1: !pdl.operation):
1919
%0 = transform.structured.match ops{["linalg.dot"]} in %arg1
2020
%1 = get_closest_isolated_parent %0
21-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
21+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns }
2222
}
2323
}
2424

@@ -40,7 +40,7 @@ transform.with_pdl_patterns {
4040
^bb1(%arg1: !pdl.operation):
4141
%0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
4242
%1 = get_closest_isolated_parent %0
43-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
43+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns }
4444
}
4545
}
4646

@@ -61,7 +61,7 @@ transform.with_pdl_patterns {
6161
^bb1(%arg1: !pdl.operation):
6262
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
6363
%1 = get_closest_isolated_parent %0
64-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
64+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns }
6565
}
6666
}
6767

@@ -83,7 +83,7 @@ transform.with_pdl_patterns {
8383
^bb1(%arg1: !pdl.operation):
8484
%0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1
8585
%1 = get_closest_isolated_parent %0
86-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
86+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns }
8787
}
8888
}
8989

@@ -126,7 +126,7 @@ transform.with_pdl_patterns {
126126
^bb1(%arg1: !pdl.operation):
127127
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
128128
%1 = get_closest_isolated_parent %0
129-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
129+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
130130
}
131131
}
132132

@@ -169,7 +169,7 @@ transform.with_pdl_patterns {
169169
^bb1(%arg1: !pdl.operation):
170170
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
171171
%1 = get_closest_isolated_parent %0
172-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
172+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
173173
}
174174
}
175175

@@ -199,7 +199,7 @@ transform.with_pdl_patterns {
199199
^bb1(%arg1: !pdl.operation):
200200
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
201201
%1 = get_closest_isolated_parent %0
202-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
202+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
203203
}
204204
}
205205

@@ -242,7 +242,7 @@ transform.with_pdl_patterns {
242242
^bb1(%arg1: !pdl.operation):
243243
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
244244
%1 = get_closest_isolated_parent %0
245-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
245+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
246246
}
247247
}
248248

@@ -265,7 +265,7 @@ transform.with_pdl_patterns {
265265
^bb1(%arg1: !pdl.operation):
266266
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
267267
%1 = get_closest_isolated_parent %0
268-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
268+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns }
269269
}
270270
}
271271

@@ -553,7 +553,7 @@ transform.with_pdl_patterns {
553553
^bb1(%arg1: !pdl.operation):
554554
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
555555
%1 = get_closest_isolated_parent %0
556-
%2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
556+
%2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns }
557557
}
558558
}
559559

@@ -647,7 +647,7 @@ transform.with_pdl_patterns {
647647
^bb1(%arg1: !pdl.operation):
648648
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
649649
%1 = get_closest_isolated_parent %0
650-
%2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns = true }
650+
%2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns }
651651
}
652652
}
653653

@@ -694,7 +694,7 @@ transform.with_pdl_patterns {
694694
^bb1(%arg1: !pdl.operation):
695695
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
696696
%1 = get_closest_isolated_parent %0
697-
%2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
697+
%2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns }
698698
}
699699
}
700700

@@ -740,7 +740,7 @@ transform.with_pdl_patterns {
740740
^bb1(%arg1: !pdl.operation):
741741
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
742742
%1 = get_closest_isolated_parent %0
743-
%2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
743+
%2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns }
744744
}
745745
}
746746

@@ -775,7 +775,7 @@ transform.with_pdl_patterns {
775775
^bb1(%arg1: !pdl.operation):
776776
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
777777
%1 = get_closest_isolated_parent %0
778-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
778+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
779779
}
780780
}
781781

@@ -807,7 +807,7 @@ transform.with_pdl_patterns {
807807
^bb1(%arg1: !pdl.operation):
808808
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
809809
%1 = get_closest_isolated_parent %0
810-
%2 = transform.structured.vectorize %1 { vectorize_padding = true }
810+
%2 = transform.structured.vectorize %1 { vectorize_padding }
811811
}
812812
}
813813

@@ -839,7 +839,7 @@ transform.with_pdl_patterns {
839839
^bb1(%arg1: !pdl.operation):
840840
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
841841
%1 = get_closest_isolated_parent %0
842-
%2 = transform.structured.vectorize %1 { vectorize_padding = true }
842+
%2 = transform.structured.vectorize %1 { vectorize_padding }
843843
}
844844
}
845845

@@ -879,7 +879,7 @@ transform.with_pdl_patterns {
879879
^bb1(%arg1: !pdl.operation):
880880
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
881881
%1 = get_closest_isolated_parent %0
882-
%2 = transform.structured.vectorize %1 { vectorize_padding = true }
882+
%2 = transform.structured.vectorize %1 { vectorize_padding }
883883
}
884884
}
885885

@@ -913,7 +913,7 @@ transform.with_pdl_patterns {
913913
^bb1(%arg1: !pdl.operation):
914914
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
915915
%1 = get_closest_isolated_parent %0
916-
%2 = transform.structured.vectorize %1 { vectorize_padding = true }
916+
%2 = transform.structured.vectorize %1 { vectorize_padding }
917917
}
918918
}
919919

@@ -949,7 +949,7 @@ transform.with_pdl_patterns {
949949
^bb1(%arg1: !pdl.operation):
950950
%3 = transform.structured.match ops{["tensor.pad"]} in %arg1
951951
%4 = get_closest_isolated_parent %3
952-
%5 = transform.structured.vectorize %4 { vectorize_padding = true }
952+
%5 = transform.structured.vectorize %4 { vectorize_padding }
953953
}
954954
}
955955

@@ -989,7 +989,7 @@ transform.with_pdl_patterns {
989989
^bb1(%arg1: !pdl.operation):
990990
%3 = transform.structured.match ops{["tensor.pad"]} in %arg1
991991
%4 = get_closest_isolated_parent %3
992-
%5 = transform.structured.vectorize %4 { vectorize_padding = true }
992+
%5 = transform.structured.vectorize %4 { vectorize_padding }
993993
}
994994
}
995995

@@ -1026,7 +1026,7 @@ transform.with_pdl_patterns {
10261026
^bb1(%arg1: !pdl.operation):
10271027
%3 = transform.structured.match ops{["tensor.pad"]} in %arg1
10281028
%4 = get_closest_isolated_parent %3
1029-
%5 = transform.structured.vectorize %4 { vectorize_padding = true }
1029+
%5 = transform.structured.vectorize %4 { vectorize_padding }
10301030
}
10311031
}
10321032

@@ -1097,7 +1097,7 @@ transform.with_pdl_patterns {
10971097
^bb1(%arg1: !pdl.operation):
10981098
%3 = transform.structured.match ops{["tensor.pad"]} in %arg1
10991099
%4 = get_closest_isolated_parent %3
1100-
%5 = transform.structured.vectorize %4 { vectorize_padding = true }
1100+
%5 = transform.structured.vectorize %4 { vectorize_padding }
11011101
}
11021102
}
11031103

@@ -1183,7 +1183,7 @@ transform.with_pdl_patterns {
11831183
^bb1(%arg1: !pdl.operation):
11841184
%3 = transform.structured.match ops{["linalg.generic"]} in %arg1
11851185
%4 = get_closest_isolated_parent %3
1186-
%5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
1186+
%5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
11871187
}
11881188
}
11891189

@@ -1216,7 +1216,7 @@ transform.with_pdl_patterns {
12161216
^bb1(%arg1: !pdl.operation):
12171217
%3 = transform.structured.match ops{["linalg.generic"]} in %arg1
12181218
%4 = get_closest_isolated_parent %3
1219-
%5 = transform.structured.vectorize %4 { vectorize_padding = true }
1219+
%5 = transform.structured.vectorize %4 { vectorize_padding }
12201220
}
12211221
}
12221222

@@ -1586,6 +1586,6 @@ transform.with_pdl_patterns {
15861586
^bb1(%arg1: !pdl.operation):
15871587
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
15881588
%1 = get_closest_isolated_parent %0
1589-
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
1589+
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
15901590
}
1591-
}
1591+
}

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,4 @@ def testVectorize():
178178
# CHECK-LABEL: TEST: testVectorize
179179
# CHECK: transform.sequence
180180
# CHECK: = transform.structured.vectorize
181-
# CHECK: vectorize_padding = true
181+
# CHECK: {vectorize_padding}

0 commit comments

Comments
 (0)