Skip to content

Commit 047e7ff

Browse files
[mlir][tensor] TrackingListener: Find replacement ops through cast-like InsertSliceOps
Certain InsertSliceOps, that do not use elements from the destination, are treated like casts when looking for replacement ops. Such InsertSliceOps are typically rank expansions. Tensors with dynamic shape are not supported at the moment. Also adds test cases for the TrackingListener. Differential Revision: https://reviews.llvm.org/D151422
1 parent 19b9c74 commit 047e7ff

File tree

7 files changed

+206
-0
lines changed

7 files changed

+206
-0
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ void prepareValueMappings(
5959
/// the operands to `block`'s terminator.
6060
void forwardTerminatorOperands(Block *block, transform::TransformState &state,
6161
transform::TransformResults &results);
62+
63+
/// Make a dummy transform state for testing purposes. This MUST NOT be used
64+
/// outside of test cases.
65+
TransformState makeTransformStateForTesting(Region *region,
66+
Operation *payloadRoot);
6267
} // namespace detail
6368

6469
/// Options controlling the application of transform operations by the
@@ -162,6 +167,9 @@ class TransformState {
162167
const RaggedArray<MappedValue> &,
163168
const TransformOptions &);
164169

170+
friend TransformState
171+
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
172+
165173
public:
166174
/// Returns the op at which the transformation state is rooted. This is
167175
/// typically helpful for transformations that apply globally.

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ using namespace tensor;
2323
// TrackingListener
2424
//===----------------------------------------------------------------------===//
2525

26+
/// A tensor.insert_slice is a cast-like operation if it the source tensor and
27+
/// the destination tensor have the same number of elements. I.e., the result
28+
/// tensor data equals the source tensor data, maybe rank-extended to a
29+
/// different shape.
30+
static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
31+
// TODO: Support dynamically shaped tensors. Utilize ValueBoundsOpInterface
32+
// to check if source and destination have the same shape.
33+
if (!op.getSourceType().hasStaticShape() ||
34+
!op.getDestType().hasStaticShape())
35+
return false;
36+
return op.getSourceType().getNumElements() ==
37+
op.getDestType().getNumElements();
38+
}
39+
2640
Operation *
2741
tensor::TrackingListener::findReplacementOp(Operation *op,
2842
ValueRange newValues) const {
@@ -48,6 +62,10 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
4862
[&](ExpandShapeOp op) { values.push_back(op.getSrc()); })
4963
.Case<ReshapeOp>(
5064
[&](ReshapeOp op) { values.push_back(op.getSource()); })
65+
.Case<InsertSliceOp>([&](InsertSliceOp op) {
66+
if (isCastLikeInsertSliceOp(op))
67+
values.push_back(op.getSource());
68+
})
5169
.Default([](Operation *op) {});
5270
} while (!values.empty());
5371

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,12 @@ void transform::detail::forwardTerminatorOperands(
12381238
}
12391239
}
12401240

1241+
transform::TransformState
1242+
transform::detail::makeTransformStateForTesting(Region *region,
1243+
Operation *payloadRoot) {
1244+
return TransformState(region, payloadRoot);
1245+
}
1246+
12411247
//===----------------------------------------------------------------------===//
12421248
// Utilities for PossibleTopLevelTransformOpTrait.
12431249
//===----------------------------------------------------------------------===//
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: mlir-opt -test-tensor-transform-patterns=test-tracking-listener \
2+
// RUN: -split-input-file -verify-diagnostics %s
3+
4+
func.func @replace_op_with_op_of_same_type() {
5+
%0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
6+
// expected-remark @below {{replacement found}}
7+
%1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
8+
return
9+
}
10+
11+
// -----
12+
13+
func.func @replace_op_with_op_of_different_type() {
14+
// expected-error @below {{listener could not find replacement op}}
15+
%0 = tensor.empty() {replaced} : tensor<5xf32>
16+
%1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
17+
return
18+
}
19+
20+
// -----
21+
22+
func.func @multi_result_replacement() {
23+
%0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
24+
// expected-remark @below {{replacement found}}
25+
%1:2 = "test.foo"() {replacement_0 = 0, replacement_1 = 1}
26+
: () -> (tensor<5xf32>, tensor<6xf32>)
27+
return
28+
}
29+
30+
// -----
31+
32+
func.func @multi_result_replacement_with_multiple_ops() {
33+
// expected-error @below {{listener could not find replacement op}}
34+
%0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
35+
%1:2 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>, tensor<6xf32>)
36+
%2:2 = "test.foo"() {replacement_1 = 1} : () -> (tensor<5xf32>, tensor<6xf32>)
37+
return
38+
}
39+
40+
// -----
41+
42+
func.func @replacement_wrapped_in_cast() {
43+
%0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
44+
// expected-remark @below {{replacement found}}
45+
%1 = "test.foo"() : () -> (tensor<?xf32>)
46+
%2 = tensor.cast %1 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32>
47+
return
48+
}
49+
50+
// -----
51+
52+
func.func @replacement_wrapped_in_chain_of_casts() {
53+
%0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
54+
// expected-remark @below {{replacement found}}
55+
%1 = "test.foo"() : () -> (tensor<?xf32>)
56+
%2 = tensor.cast %1 : tensor<?xf32> to tensor<5xf32>
57+
%3 = tensor.cast %2 : tensor<5xf32> to tensor<?xf32>
58+
%4 = tensor.cast %3 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32>
59+
return
60+
}
61+
62+
// -----
63+
64+
func.func @cast_like_insert_slice(%t: tensor<1x5xf32>) {
65+
%0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
66+
// expected-remark @below {{replacement found}}
67+
%1 = "test.foo"() : () -> (tensor<5xf32>)
68+
%2 = tensor.insert_slice %1 into %t[0, 0][1, 5][1, 1] {replacement_0 = 0}
69+
: tensor<5xf32> into tensor<1x5xf32>
70+
return
71+
}
72+
73+
// -----
74+
75+
func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) {
76+
// expected-error @below {{listener could not find replacement op}}
77+
%0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
78+
%1 = "test.foo"() : () -> (tensor<5xf32>)
79+
// This is not a cast-like insert_slice op because elements from %t are
80+
// contained in %2.
81+
%2 = tensor.insert_slice %1 into %t[0][5][1] {replacement_0 = 0}
82+
: tensor<5xf32> into tensor<7xf32>
83+
return
84+
}

mlir/test/lib/Dialect/Tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ add_mlir_library(MLIRTensorTestPasses
1010
MLIRPass
1111
MLIRSCFDialect
1212
MLIRTensorDialect
13+
MLIRTensorTransformOps
1314
MLIRTensorTransforms
15+
MLIRTransformDialect
1416
MLIRTransforms
1517
)

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
1718
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
1819
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
20+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1921
#include "mlir/Pass/Pass.h"
2022
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2123

@@ -85,6 +87,11 @@ struct TestTensorTransforms
8587
*this, "test-simplify-pack-patterns",
8688
llvm::cl::desc("Test patterns to simplify tensor.pack"),
8789
llvm::cl::init(false)};
90+
91+
Option<bool> testTrackingListener{
92+
*this, "test-tracking-listener",
93+
llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
94+
llvm::cl::init(false)};
8895
};
8996
} // namespace
9097

@@ -276,6 +283,82 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
276283
return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
277284
}
278285

286+
namespace {
287+
class DummyTrackingListener : public tensor::TrackingListener {
288+
public:
289+
using tensor::TrackingListener::TrackingListener;
290+
291+
// Expose `findReplacementOp` as a public function, so that it can be tested.
292+
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
293+
return findReplacementOp(op, newValues);
294+
}
295+
};
296+
} // namespace
297+
298+
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
299+
// Find replaced op.
300+
Operation *replaced = nullptr;
301+
WalkResult status = rootOp->walk([&](Operation *op) {
302+
if (op->hasAttr("replaced")) {
303+
if (replaced) {
304+
op->emitError("only one 'replaced' op is allowed per test case");
305+
replaced->emitRemark("other 'replaced' op");
306+
return WalkResult::interrupt();
307+
}
308+
replaced = op;
309+
}
310+
return WalkResult::advance();
311+
});
312+
if (status.wasInterrupted())
313+
return failure();
314+
if (!replaced) {
315+
replaced->emitError("could not find 'replaced' op");
316+
return failure();
317+
}
318+
319+
// Find replacements.
320+
SmallVector<Value> replacements(replaced->getNumResults(), Value());
321+
status = rootOp->walk([&](Operation *op) {
322+
for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
323+
if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
324+
std::to_string(i))) {
325+
if (replacements[i]) {
326+
op->emitError("only one 'replacement_" + std::to_string(i) +
327+
"' is allowed per test case");
328+
replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
329+
std::to_string(i) + "'");
330+
return WalkResult::interrupt();
331+
}
332+
replacements[i] = op->getResult(attr.getInt());
333+
}
334+
}
335+
return WalkResult::advance();
336+
});
337+
if (status.wasInterrupted())
338+
return failure();
339+
340+
if (!llvm::all_of(replacements,
341+
[](Value v) { return static_cast<bool>(v); })) {
342+
replaced->emitError("insufficient replacement values");
343+
return failure();
344+
}
345+
346+
// Find the replacement op (if any) and emit a remark/error.
347+
transform::TransformState transformState =
348+
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
349+
/*payloadRoot=*/nullptr);
350+
DummyTrackingListener listener(transformState,
351+
transform::TransformOpInterface());
352+
Operation *replacement = listener.getReplacementOp(replaced, replacements);
353+
if (!replacement) {
354+
replaced->emitError("listener could not find replacement op");
355+
return failure();
356+
}
357+
358+
replacement->emitRemark("replacement found");
359+
return success();
360+
}
361+
279362
void TestTensorTransforms::runOnOperation() {
280363
Operation *rootOp = getOperation();
281364
if (testSimplifyPackPatterns)
@@ -295,6 +378,9 @@ void TestTensorTransforms::runOnOperation() {
295378
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
296379
return signalPassFailure();
297380
}
381+
if (testTrackingListener)
382+
if (failed(testTrackingListenerReplacements(rootOp)))
383+
return signalPassFailure();
298384
}
299385

300386
namespace mlir {

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,9 @@ cc_library(
855855
"//mlir:Pass",
856856
"//mlir:SCFDialect",
857857
"//mlir:TensorDialect",
858+
"//mlir:TensorTransformOps",
858859
"//mlir:TensorTransforms",
860+
"//mlir:TransformDialect",
859861
"//mlir:Transforms",
860862
],
861863
)

0 commit comments

Comments
 (0)