Skip to content

Commit 2e3c62b

Browse files
[mlir][tensor][NFC] Simplify SubsetInsertionOpInterface implementation (#69999)
`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation.
1 parent 3324776 commit 2e3c62b

File tree

1 file changed

+36
-82
lines changed

1 file changed

+36
-82
lines changed

mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Lines changed: 36 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
1717

1818
namespace {
1919

20-
/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
21-
/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
22-
/// equivalence of tensors.
2320
template <typename OpTy>
24-
bool isSubsetEquivalentToInsertSliceLikeOp(
25-
OpTy insertSliceOp, Value candidate,
26-
function_ref<bool(Value, Value)> equivalenceFn) {
27-
// Look for a matching tensor.extract_slice op.
28-
auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
29-
if (!extractSliceOp)
30-
return false;
31-
if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
32-
return false;
33-
return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
34-
isEqualConstantIntOrValue);
35-
}
36-
37-
template <typename OpTy>
38-
Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
39-
OpTy insertSliceOp) {
40-
auto extractOp = b.create<tensor::ExtractSliceOp>(
41-
loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
42-
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
43-
insertSliceOp.getMixedStrides());
44-
return extractOp.getResult();
45-
}
46-
47-
template <typename OpTy>
48-
SmallVector<Value>
49-
getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
50-
SmallVector<Value> neededValues;
51-
// Collect all values that are needed to construct the replacement op.
52-
neededValues.append(insertSliceOp.getOffsets().begin(),
53-
insertSliceOp.getOffsets().end());
54-
neededValues.append(insertSliceOp.getSizes().begin(),
55-
insertSliceOp.getSizes().end());
56-
neededValues.append(insertSliceOp.getStrides().begin(),
57-
insertSliceOp.getStrides().end());
58-
neededValues.push_back(insertSliceOp.getDest());
59-
return neededValues;
60-
}
61-
62-
struct InsertSliceOpInterface
63-
: public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
64-
tensor::InsertSliceOp> {
65-
OpOperand &getSourceOperand(Operation *op) const {
66-
return cast<tensor::InsertSliceOp>(op).getSourceMutable();
67-
}
68-
69-
bool
70-
isEquivalentSubset(Operation *op, Value candidate,
71-
function_ref<bool(Value, Value)> equivalenceFn) const {
72-
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
73-
return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
74-
equivalenceFn);
75-
}
76-
77-
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
78-
Location loc) const {
79-
return buildSubsetExtractionOfInsertSliceLikeOp(
80-
builder, loc, cast<tensor::InsertSliceOp>(op));
81-
}
82-
83-
SmallVector<Value>
84-
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
85-
return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
86-
cast<tensor::InsertSliceOp>(op));
87-
}
88-
};
89-
90-
struct ParallelInsertSliceOpInterface
21+
struct InsertSliceLikeOpInterface
9122
: public SubsetInsertionOpInterface::ExternalModel<
92-
ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
23+
InsertSliceLikeOpInterface<OpTy>, OpTy> {
9324
OpOperand &getSourceOperand(Operation *op) const {
94-
return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
25+
return cast<OpTy>(op).getSourceMutable();
9526
}
9627

9728
OpOperand &getDestinationOperand(Operation *op) const {
98-
return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
29+
return cast<OpTy>(op).getDestMutable();
9930
}
10031

32+
/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
33+
/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
34+
/// equivalence of tensors.
10135
bool
10236
isEquivalentSubset(Operation *op, Value candidate,
10337
function_ref<bool(Value, Value)> equivalenceFn) const {
104-
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
105-
return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
106-
equivalenceFn);
38+
auto insertSliceOp = cast<OpTy>(op);
39+
// Look for a matching tensor.extract_slice op.
40+
auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
41+
if (!extractSliceOp)
42+
return false;
43+
if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
44+
return false;
45+
return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
46+
isEqualConstantIntOrValue);
10747
}
10848

10949
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
11050
Location loc) const {
111-
return buildSubsetExtractionOfInsertSliceLikeOp(
112-
builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
51+
auto insertSliceOp = cast<OpTy>(op);
52+
auto extractOp = builder.create<tensor::ExtractSliceOp>(
53+
loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
54+
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
55+
insertSliceOp.getMixedStrides());
56+
return extractOp.getResult();
11357
}
11458

11559
SmallVector<Value>
11660
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
117-
return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
118-
cast<tensor::ParallelInsertSliceOp>(op));
61+
auto insertSliceOp = cast<OpTy>(op);
62+
SmallVector<Value> neededValues;
63+
// Collect all values that are needed to construct the replacement op.
64+
neededValues.append(insertSliceOp.getOffsets().begin(),
65+
insertSliceOp.getOffsets().end());
66+
neededValues.append(insertSliceOp.getSizes().begin(),
67+
insertSliceOp.getSizes().end());
68+
neededValues.append(insertSliceOp.getStrides().begin(),
69+
insertSliceOp.getStrides().end());
70+
neededValues.push_back(insertSliceOp.getDest());
71+
return neededValues;
11972
}
12073
};
12174

@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
12477
void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
12578
DialectRegistry &registry) {
12679
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
127-
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
128-
ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
80+
InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
12981
*ctx);
82+
ParallelInsertSliceOp::attachInterface<
83+
InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
13084
});
13185
}

0 commit comments

Comments
 (0)