Skip to content

Commit 31aa8ea

Browse files
[mlir][Linalg][Transform] Avoid FunctionalStyleTransformOpTrait where unnecesseary to improve usability
Differential Revision: https://reviews.llvm.org/D146305
1 parent e0f8f1f commit 31aa8ea

File tree

16 files changed

+302
-254
lines changed

16 files changed

+302
-254
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ include "mlir/IR/OpBase.td"
1717

1818
def MapNestedForallToThreads :
1919
Op<Transform_Dialect, "gpu.map_nested_forall_to_threads",
20-
[FunctionalStyleTransformOpTrait,
21-
MemoryEffectsOpInterface,
20+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
2221
TransformEachOpTrait,
2322
TransformOpInterface]> {
2423
let description = [{
@@ -72,9 +71,7 @@ def MapNestedForallToThreads :
7271
scf.forall operations with mappings other than gpu.thread are
7372
ignored.
7473

75-
The returned handle points to the same LaunchOp operand, consuming it and
76-
producing a new SSA value to satisfy chaining and linearity of the IR
77-
properties.
74+
This operation returns nothing.
7875

7976
#### Example:
8077

@@ -111,18 +108,19 @@ def MapNestedForallToThreads :
111108
```
112109
}];
113110

114-
let arguments = (ins PDL_Operation:$target,
111+
let arguments = (ins TransformHandleTypeInterface:$target,
115112
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$block_dims,
116113
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims,
117114
DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute);
118-
let results = (outs PDL_Operation:$result);
115+
let results = (outs);
119116

120117
let assemblyFormat = [{
121118
$target
122119
`block_dims` `=` $block_dims
123120
(`warp_dims` `=` $warp_dims^)?
124121
(`sync_after_distribute` `=` $sync_after_distribute^)?
125122
attr-dict
123+
`:` functional-type(operands, results)
126124
}];
127125
let extraClassDeclaration = [{
128126
::mlir::DiagnosedSilenceableFailure applyToOne(

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,11 +1651,13 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
16511651
//===----------------------------------------------------------------------===//
16521652

16531653
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
1654-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1655-
TransformEachOpTrait, TransformOpInterface]> {
1654+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1655+
TransformEachOpTrait,
1656+
TransformOpInterface]> {
16561657
let description = [{
16571658
Indicates that the given `target` op all the ops it contains should be
16581659
vectorized with the configuration specified by the attributes of this op.
1660+
16591661
This vectorization only handles structured ops that operate on shaped types
16601662
and does not vectorize loops or straight-line. Internally, it applies a
16611663
set of rewrite patterns, some of which enable vectorization and some of
@@ -1685,24 +1687,22 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
16851687

16861688
This operation produces `definiteFailure` if vectorization fails for any
16871689
reason.
1688-
The operation always returns the handle to the target op that is expected
1689-
to be isolated from above.
1690+
This operation returns nothing.
16901691
}];
16911692

1692-
let arguments = (ins PDL_Operation:$target,
1693+
let arguments = (ins TransformHandleTypeInterface:$target,
16931694
UnitAttr:$vectorize_padding,
16941695
UnitAttr:$vectorize_nd_extract,
16951696
UnitAttr:$disable_multi_reduction_to_contract_patterns,
16961697
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
1697-
let results = (outs PDL_Operation:$transformed);
1698+
let results = (outs);
16981699

1699-
let assemblyFormat = "$target attr-dict";
1700+
let assemblyFormat = [{
1701+
$target
1702+
attr-dict
1703+
`:` functional-type(operands, results)
1704+
}];
17001705

1701-
let builders = [
1702-
OpBuilder<(ins "Value":$target,
1703-
CArg<"bool", "false">:$vectorizePadding,
1704-
CArg<"bool", "false">:$vectorizeNDExtract)>,
1705-
];
17061706
let extraClassDeclaration = [{
17071707
::mlir::DiagnosedSilenceableFailure applyToOne(
17081708
::mlir::Operation *target,
@@ -1711,6 +1711,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
17111711
}];
17121712
}
17131713

1714+
//===----------------------------------------------------------------------===//
1715+
// MaskedVectorizeOp
1716+
//===----------------------------------------------------------------------===//
1717+
17141718
def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
17151719
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
17161720
TransformOpInterface]> {
@@ -1765,8 +1769,9 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
17651769

17661770
def HoistRedundantVectorTransfersOp :
17671771
Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
1768-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1769-
TransformEachOpTrait, TransformOpInterface]> {
1772+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1773+
TransformEachOpTrait,
1774+
TransformOpInterface]> {
17701775
let description = [{
17711776
Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
17721777
enclosing scf::ForOp iteratively, if the following conditions are true:
@@ -1782,18 +1787,17 @@ def HoistRedundantVectorTransfersOp :
17821787

17831788
#### Return modes:
17841789

1785-
The operation always succeeds and returns a handle to the transformed
1786-
function op.
1790+
The operation always succeeds and returns nothing.
17871791
}];
17881792

17891793
let arguments = (ins TransformHandleTypeInterface:$target);
1790-
let results = (outs TransformHandleTypeInterface:$transformed);
1791-
1792-
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
1794+
let results = (outs);
1795+
let assemblyFormat = [{
1796+
$target
1797+
attr-dict
1798+
`:` functional-type(operands, results)
1799+
}];
17931800

1794-
let builders = [
1795-
OpBuilder<(ins "Value":$target)>,
1796-
];
17971801
let extraClassDeclaration = [{
17981802
::mlir::DiagnosedSilenceableFailure applyToOne(
17991803
::mlir::func::FuncOp target,
@@ -1884,8 +1888,9 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
18841888

18851889
def HoistRedundantTensorSubsetsOp :
18861890
Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
1887-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1888-
TransformEachOpTrait, TransformOpInterface]> {
1891+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1892+
TransformEachOpTrait,
1893+
TransformOpInterface]> {
18891894
let description = [{
18901895
Hoists supported tensor subset extract/insert operation pairs out of
18911896
immediately enclosing loop iteratively, if the following conditions
@@ -1905,18 +1910,18 @@ def HoistRedundantTensorSubsetsOp :
19051910

19061911
#### Return modes:
19071912

1908-
The operation always succeeds and returns a handle to the transformed
1909-
function op.
1913+
The operation always succeeds and returns nothing.
19101914
}];
19111915

19121916
let arguments = (ins TransformHandleTypeInterface:$target);
1913-
let results = (outs TransformHandleTypeInterface:$transformed);
1917+
let results = (outs);
19141918

1915-
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
1919+
let assemblyFormat = [{
1920+
$target
1921+
attr-dict
1922+
`:` functional-type(operands, results)
1923+
}];
19161924

1917-
let builders = [
1918-
OpBuilder<(ins "Value":$target)>,
1919-
];
19201925
let extraClassDeclaration = [{
19211926
::mlir::DiagnosedSilenceableFailure applyToOne(
19221927
::mlir::Operation *target,

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/IR/OpBase.td"
1818

1919
def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
20-
[DeclareOpInterfaceMethods<TransformOpInterface>,
21-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
20+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
21+
TransformEachOpTrait,
22+
TransformOpInterface]> {
2223
let description = [{
2324
Indicates that the vector operations nested under the isolated from above op
2425
`target` should be lowered to finer-grained vector primitives.
@@ -27,10 +28,14 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
2728

2829
This is usally a late step that is run after bufferization as part of the
2930
process of lowering to e.g. LLVM or NVVM.
31+
32+
#### Return modes:
33+
34+
The operation returns nothing.
3035
}];
3136

3237
// TODO: evolve this to proper enums.
33-
let arguments = (ins PDL_Operation:$target,
38+
let arguments = (ins TransformHandleTypeInterface:$target,
3439
DefaultValuedAttr<VectorContractLoweringAttr,
3540
"vector::VectorContractLowering::OuterProduct">:$contraction_lowering,
3641
DefaultValuedAttr<VectorMultiReductionLoweringAttr,
@@ -43,7 +48,7 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
4348
DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering,
4449
DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers
4550
);
46-
let results = (outs PDL_Operation:$results);
51+
let results = (outs);
4752

4853
let builders = [
4954
OpBuilder<(ins "Type":$resultType, "Value":$target,
@@ -66,6 +71,14 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
6671
| `transpose_lowering` `=` $transpose_lowering
6772
)
6873
attr-dict
74+
`:` functional-type(operands, results)
75+
}];
76+
77+
let extraClassDeclaration = [{
78+
::mlir::DiagnosedSilenceableFailure applyToOne(
79+
::mlir::Operation *target,
80+
::mlir::transform::ApplyToEachResultList &results,
81+
::mlir::transform::TransformState &state);
6982
}];
7083
}
7184

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,12 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
848848
return DiagnosedSilenceableFailure::success();
849849
}
850850

851+
void transform::MapNestedForallToThreads::getEffects(
852+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
853+
onlyReadsHandle(getTarget(), effects);
854+
modifiesPayload(effects);
855+
}
856+
851857
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
852858
Operation *target, ApplyToEachResultList &results, TransformState &state) {
853859
LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
@@ -880,7 +886,6 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
880886
mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
881887
getWarpDims(), getSyncAfterDistribute());
882888

883-
results.push_back(gpuLaunch.getOperation());
884889
return diag;
885890
}
886891

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,7 +1786,7 @@ LogicalResult transform::PadOp::verify() {
17861786
}
17871787

17881788
//===---------------------------------------------------------------------===//
1789-
// HoistPadOp
1789+
// PadOp
17901790
//===---------------------------------------------------------------------===//
17911791

17921792
DiagnosedSilenceableFailure
@@ -2977,21 +2977,6 @@ void transform::TileToScfForOp::getEffects(
29772977
// VectorizeOp
29782978
//===----------------------------------------------------------------------===//
29792979

2980-
void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
2981-
Value target, bool vectorizePadding,
2982-
bool vectorizeExtract) {
2983-
result.addOperands(target);
2984-
if (vectorizePadding) {
2985-
result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
2986-
builder.getUnitAttr());
2987-
}
2988-
if (vectorizeExtract) {
2989-
result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
2990-
builder.getUnitAttr());
2991-
}
2992-
result.addTypes(pdl::OperationType::get(builder.getContext()));
2993-
}
2994-
29952980
namespace {
29962981
/// This is an helper only to call vectorize via a pattern inside of
29972982
/// VectorizeOp::applyToOne.
@@ -3050,10 +3035,15 @@ transform::VectorizeOp::applyToOne(Operation *target,
30503035
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
30513036
return emitDefaultDefiniteFailure(target);
30523037

3053-
results.push_back(target);
30543038
return DiagnosedSilenceableFailure::success();
30553039
}
30563040

3041+
void transform::VectorizeOp::getEffects(
3042+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3043+
transform::onlyReadsHandle(getTarget(), effects);
3044+
transform::modifiesPayload(effects);
3045+
}
3046+
30573047
//===----------------------------------------------------------------------===//
30583048
// MaskedVectorizeOp
30593049
//===----------------------------------------------------------------------===//
@@ -3133,22 +3123,6 @@ SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
31333123
return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
31343124
}
31353125

3136-
//===----------------------------------------------------------------------===//
3137-
// HoistRedundantVectorTransfersOp
3138-
//===----------------------------------------------------------------------===//
3139-
3140-
DiagnosedSilenceableFailure
3141-
transform::HoistRedundantVectorTransfersOp::applyToOne(
3142-
func::FuncOp target, transform::ApplyToEachResultList &results,
3143-
transform::TransformState &state) {
3144-
// WARNING: This hoisting does not model parallelism and is generally
3145-
// incorrect when used on distributed loops with memref semantics!
3146-
// TODO: obsolete and should be retired.
3147-
linalg::hoistRedundantVectorTransfers(target);
3148-
results.push_back(target);
3149-
return DiagnosedSilenceableFailure::success();
3150-
}
3151-
31523126
//===----------------------------------------------------------------------===//
31533127
// ConvertConv2DToImg2ColOp.
31543128
//===----------------------------------------------------------------------===//
@@ -3193,9 +3167,7 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
31933167
IRRewriter rewriter(target->getContext());
31943168
auto forOp = dyn_cast<scf::ForOp>(target);
31953169
if (forOp) {
3196-
scf::ForOp newForOp =
3197-
linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
3198-
results.push_back(newForOp);
3170+
linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
31993171
return DiagnosedSilenceableFailure::success();
32003172
}
32013173

@@ -3204,10 +3176,36 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
32043176
target->walk([&](scf::ForOp forOp) {
32053177
hoistRedundantSubsetExtractInsert(rewriter, forOp);
32063178
});
3207-
results.push_back(target);
32083179
return DiagnosedSilenceableFailure::success();
32093180
}
32103181

3182+
void transform::HoistRedundantTensorSubsetsOp::getEffects(
3183+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3184+
transform::onlyReadsHandle(getTarget(), effects);
3185+
transform::modifiesPayload(effects);
3186+
}
3187+
3188+
//===----------------------------------------------------------------------===//
3189+
// HoistRedundantVectorTransfersOp
3190+
//===----------------------------------------------------------------------===//
3191+
3192+
DiagnosedSilenceableFailure
3193+
transform::HoistRedundantVectorTransfersOp::applyToOne(
3194+
func::FuncOp target, transform::ApplyToEachResultList &results,
3195+
transform::TransformState &state) {
3196+
// WARNING: This hoisting does not model parallelism and is generally
3197+
// incorrect when used on distributed loops with memref semantics!
3198+
// TODO: obsolete and should be retired.
3199+
linalg::hoistRedundantVectorTransfers(target);
3200+
return DiagnosedSilenceableFailure::success();
3201+
}
3202+
3203+
void transform::HoistRedundantVectorTransfersOp::getEffects(
3204+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3205+
transform::onlyReadsHandle(getTarget(), effects);
3206+
transform::modifiesPayload(effects);
3207+
}
3208+
32113209
//===----------------------------------------------------------------------===//
32123210
// Transform op registration
32133211
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)