Skip to content

Commit b9fe461

Browse files
[mlir][transform] LISH: Add transform op (#70630)
Add a transform op for loop-invariant subset hoisting. Delete the old transform op from the Linalg dialect.
1 parent f0535c7 commit b9fe461

File tree

19 files changed

+352
-89
lines changed

19 files changed

+352
-89
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2247,56 +2247,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
22472247
}];
22482248
}
22492249

2250-
//===----------------------------------------------------------------------===//
2251-
// HoistRedundantTensorSubsetsOp
2252-
//===----------------------------------------------------------------------===//
2253-
2254-
def HoistRedundantTensorSubsetsOp :
2255-
Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
2256-
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
2257-
TransformEachOpTrait,
2258-
TransformOpInterface,
2259-
ReportTrackingListenerFailuresOpTrait]> {
2260-
let description = [{
2261-
Hoists supported tensor subset extract/insert operation pairs out of
2262-
immediately enclosing loop iteratively, if the following conditions
2263-
are true:
2264-
1. The 2 ops access the same tensor subset.
2265-
2. All operands are invariant under the enclosing loop.
2266-
2267-
The supported subset extract/insert operation pairs currently comprise:
2268-
- tensor.extract_slice / tensor.insert_slice
2269-
- vector.transfer_read / vector.transfer_write on tensors
2270-
2271-
Only scf.for loops are currently supported.
2272-
2273-
When applied to:
2274-
1. an scf.for loop, hoist out of this loop only.
2275-
2. a non-loop op, apply hoisting to all the contained loop ops.
2276-
2277-
#### Return modes:
2278-
2279-
The operation always succeeds and returns nothing.
2280-
}];
2281-
2282-
let arguments = (ins TransformHandleTypeInterface:$target);
2283-
let results = (outs);
2284-
2285-
let assemblyFormat = [{
2286-
$target
2287-
attr-dict
2288-
`:` functional-type(operands, results)
2289-
}];
2290-
2291-
let extraClassDeclaration = [{
2292-
::mlir::DiagnosedSilenceableFailure applyToOne(
2293-
::mlir::transform::TransformRewriter &rewriter,
2294-
::mlir::Operation *target,
2295-
::mlir::transform::ApplyToEachResultList &results,
2296-
::mlir::transform::TransformState &state);
2297-
}];
2298-
}
2299-
23002250
//===----------------------------------------------------------------------===//
23012251
// InsertSliceToCopyOp
23022252
//===----------------------------------------------------------------------===//
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(IR)
2+
add_subdirectory(LoopExtension)
23
add_subdirectory(PDLExtension)
34
add_subdirectory(Transforms)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td)
2+
mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls)
3+
mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
5+
6+
add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- LoopExtension.h - Loop extension for Transform dialect ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
namespace mlir {
10+
class DialectRegistry;
11+
12+
namespace transform {
13+
/// Registers the loop extension of the Transform dialect in the given registry.
14+
void registerLoopExtension(DialectRegistry &dialectRegistry);
15+
} // namespace transform
16+
} // namespace mlir
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- LoopExtensionOps.h - Loop ext. for Transform dialect -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
10+
#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/IR/OpImplementation.h"
17+
#include "mlir/Interfaces/LoopLikeInterface.h"
18+
#include "mlir/Interfaces/SideEffectInterfaces.h"
19+
20+
#define GET_OP_CLASSES
21+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc"
22+
23+
#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
10+
#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
15+
16+
def HoistLoopInvariantSubsetsOp
17+
: TransformDialectOp<"loop.hoist_loop_invariant_subsets",
18+
[TransformOpInterface, TransformEachOpTrait,
19+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
20+
ReportTrackingListenerFailuresOpTrait]> {
21+
let summary = "Hoist loop invariant subset ops";
22+
let description = [{
23+
This transform hoists loop-invariant subset ops out of the targeted
24+
loop-like op. It looks for matching subset extraction/insertion op pairs and
25+
hoists them. The loop body operates on a newly introduced region iter_arg.
26+
27+
Subset ops are hoisted only from the targeted op. If subset ops should be
28+
hoisted from an entire loop nest, this transformation must be applied to
29+
each loop-like op of the loop nest, starting with the innermost loop and
30+
ending with the outermost loop.
31+
32+
Example:
33+
```
34+
%r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
35+
%0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
36+
%1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
37+
%2 = tensor.insert_slice %1 into %t[0][5][1]
38+
: tensor<5xf32> into tensor<?xf32>
39+
scf.yield %2 : tensor<?xf32>
40+
}
41+
```
42+
Is transformed to:
43+
```
44+
%0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
45+
%new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>) {
46+
%1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
47+
scf.yield %t, %2 : tensor<?xf32>, tensor<5xf32>
48+
}
49+
%r = tensor.insert_slice %new_loop#1 into %new_loop#0
50+
: tensor<5xf32> into tensor<?xf32>
51+
```
52+
53+
Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
54+
if there were a second overlapping extraction in the above example, no ops
55+
could be hoisted safely.
56+
57+
This transform reads the target handle and modifies the payload. This
58+
transform does not invalidate any handles, but loop-like ops are replaced
59+
with new loop-like ops when a subset op is hoisted. The transform rewriter
60+
updates all handles accordingly.
61+
}];
62+
63+
let arguments = (ins TransformHandleTypeInterface:$target);
64+
let results = (outs);
65+
let assemblyFormat = "$target attr-dict `:` type($target)";
66+
67+
let extraClassDeclaration = [{
68+
::mlir::DiagnosedSilenceableFailure applyToOne(
69+
::mlir::transform::TransformRewriter &rewriter,
70+
::mlir::LoopLikeOpInterface loopLikeOp,
71+
::mlir::transform::ApplyToEachResultList &results,
72+
::mlir::transform::TransformState &state);
73+
}];
74+
}
75+
76+
#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS

mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
1+
//===- PDLExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
3535
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
3636
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
37+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
3738
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
3839
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
3940
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -74,6 +75,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7475
scf::registerTransformDialectExtension(registry);
7576
sparse_tensor::registerTransformDialectExtension(registry);
7677
tensor::registerTransformDialectExtension(registry);
78+
transform::registerLoopExtension(registry);
7779
transform::registerPDLExtension(registry);
7880
vector::registerTransformDialectExtension(registry);
7981

mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace mlir {
1818
class LoopLikeOpInterface;
1919
class Operation;
2020
class Region;
21+
class RewriterBase;
2122
class Value;
2223

2324
/// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
108109
/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
109110
/// : tensor<5xf32> into tensor<?xf32>
110111
/// ```
111-
LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
112+
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
113+
LoopLikeOpInterface loopLike);
112114

113115
} // end namespace mlir
114116

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3163,35 +3163,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
31633163
return DiagnosedSilenceableFailure::success();
31643164
}
31653165

3166-
//===----------------------------------------------------------------------===//
3167-
// HoistRedundantTensorSubsetsOp
3168-
//===----------------------------------------------------------------------===//
3169-
3170-
DiagnosedSilenceableFailure
3171-
transform::HoistRedundantTensorSubsetsOp::applyToOne(
3172-
transform::TransformRewriter &rewriter, Operation *target,
3173-
transform::ApplyToEachResultList &results,
3174-
transform::TransformState &state) {
3175-
auto forOp = dyn_cast<scf::ForOp>(target);
3176-
if (forOp) {
3177-
linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
3178-
return DiagnosedSilenceableFailure::success();
3179-
}
3180-
3181-
// TODO: walking in some reverse / inside-out order would be more efficient
3182-
// and would capture more cases.
3183-
target->walk([&](scf::ForOp forOp) {
3184-
hoistRedundantSubsetExtractInsert(rewriter, forOp);
3185-
});
3186-
return DiagnosedSilenceableFailure::success();
3187-
}
3188-
3189-
void transform::HoistRedundantTensorSubsetsOp::getEffects(
3190-
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3191-
transform::onlyReadsHandle(getTarget(), effects);
3192-
transform::modifiesPayload(effects);
3193-
}
3194-
31953166
//===----------------------------------------------------------------------===//
31963167
// InsertSliceToCopyOp
31973168
//===----------------------------------------------------------------------===//
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(IR)
2+
add_subdirectory(LoopExtension)
23
add_subdirectory(PDLExtension)
34
add_subdirectory(Transforms)
45
add_subdirectory(Utils)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_mlir_dialect_library(MLIRTransformLoopExtension
2+
LoopExtension.cpp
3+
LoopExtensionOps.cpp
4+
5+
DEPENDS
6+
MLIRTransformDialectLoopExtensionOpsIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRIR
10+
MLIRLoopLikeInterface
11+
MLIRTransformDialect
12+
MLIRTransforms
13+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- LoopExtension.cpp - Loop extension for the Transform dialect -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
10+
11+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
12+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
/// Loop extension of the Transform dialect. This provides "core" transform
19+
/// operations for loop-like ops.
20+
class LoopExtension
21+
: public transform::TransformDialectExtension<LoopExtension> {
22+
public:
23+
void init() {
24+
registerTransformOps<
25+
#define GET_OP_LIST
26+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
27+
>();
28+
}
29+
};
30+
} // namespace
31+
32+
void mlir::transform::registerLoopExtension(DialectRegistry &dialectRegistry) {
33+
dialectRegistry.addExtensions<LoopExtension>();
34+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- LoopExtensionOps.cpp - Loop extension for the Transform dialect ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
10+
11+
#include "mlir/IR/OpImplementation.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14+
15+
using namespace mlir;
16+
17+
#define GET_OP_CLASSES
18+
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
19+
20+
//===----------------------------------------------------------------------===//
21+
// HoistLoopInvariantSubsetsOp
22+
//===----------------------------------------------------------------------===//
23+
24+
DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
25+
transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp,
26+
transform::ApplyToEachResultList &results,
27+
transform::TransformState &state) {
28+
hoistLoopInvariantSubsets(rewriter, loopLikeOp);
29+
return DiagnosedSilenceableFailure::success();
30+
}
31+
32+
void transform::HoistLoopInvariantSubsetsOp::getEffects(
33+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
34+
transform::onlyReadsHandle(getTarget(), effects);
35+
transform::modifiesPayload(effects);
36+
}

mlir/lib/Transforms/LoopInvariantCodeMotion.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Transforms/Passes.h"
1414

15+
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/Interfaces/LoopLikeInterface.h"
1617
#include "mlir/Interfaces/SideEffectInterfaces.h"
1718
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
4748
}
4849

4950
void LoopInvariantSubsetHoisting::runOnOperation() {
51+
IRRewriter rewriter(getOperation()->getContext());
5052
// Walk through all loops in a function in innermost-loop-first order. This
5153
// way, we first hoist from the inner loop, and place the ops in the outer
5254
// loop, which in turn can be further hoisted from.
5355
getOperation()->walk([&](LoopLikeOpInterface loopLike) {
54-
(void)hoistLoopInvariantSubsets(loopLike);
56+
(void)hoistLoopInvariantSubsets(rewriter, loopLike);
5557
});
5658
}
5759

0 commit comments

Comments
 (0)