@@ -17,105 +17,58 @@ using namespace mlir::tensor;
17
17
18
18
namespace {
19
19
20
- // / Return "true" if `insertSliceOp` inserts into a subset that is equivalent
21
- // / to the subset defined by `candidate`. `equivalenceFn` is used to determine
22
- // / equivalence of tensors.
23
20
template <typename OpTy>
24
- bool isSubsetEquivalentToInsertSliceLikeOp (
25
- OpTy insertSliceOp, Value candidate,
26
- function_ref<bool (Value, Value)> equivalenceFn) {
27
- // Look for a matching tensor.extract_slice op.
28
- auto extractSliceOp = candidate.getDefiningOp <tensor::ExtractSliceOp>();
29
- if (!extractSliceOp)
30
- return false ;
31
- if (!equivalenceFn (extractSliceOp.getSource (), insertSliceOp.getDest ()))
32
- return false ;
33
- return sameOffsetsSizesAndStrides (extractSliceOp, insertSliceOp,
34
- isEqualConstantIntOrValue);
35
- }
36
-
37
- template <typename OpTy>
38
- Value buildSubsetExtractionOfInsertSliceLikeOp (OpBuilder &b, Location loc,
39
- OpTy insertSliceOp) {
40
- auto extractOp = b.create <tensor::ExtractSliceOp>(
41
- loc, insertSliceOp.getSourceType (), insertSliceOp.getDest (),
42
- insertSliceOp.getMixedOffsets (), insertSliceOp.getMixedSizes (),
43
- insertSliceOp.getMixedStrides ());
44
- return extractOp.getResult ();
45
- }
46
-
47
- template <typename OpTy>
48
- SmallVector<Value>
49
- getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp (OpTy insertSliceOp) {
50
- SmallVector<Value> neededValues;
51
- // Collect all values that are needed to construct the replacement op.
52
- neededValues.append (insertSliceOp.getOffsets ().begin (),
53
- insertSliceOp.getOffsets ().end ());
54
- neededValues.append (insertSliceOp.getSizes ().begin (),
55
- insertSliceOp.getSizes ().end ());
56
- neededValues.append (insertSliceOp.getStrides ().begin (),
57
- insertSliceOp.getStrides ().end ());
58
- neededValues.push_back (insertSliceOp.getDest ());
59
- return neededValues;
60
- }
61
-
62
- struct InsertSliceOpInterface
63
- : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
64
- tensor::InsertSliceOp> {
65
- OpOperand &getSourceOperand (Operation *op) const {
66
- return cast<tensor::InsertSliceOp>(op).getSourceMutable ();
67
- }
68
-
69
- bool
70
- isEquivalentSubset (Operation *op, Value candidate,
71
- function_ref<bool (Value, Value)> equivalenceFn) const {
72
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
73
- return isSubsetEquivalentToInsertSliceLikeOp (insertSliceOp, candidate,
74
- equivalenceFn);
75
- }
76
-
77
- Value buildSubsetExtraction (Operation *op, OpBuilder &builder,
78
- Location loc) const {
79
- return buildSubsetExtractionOfInsertSliceLikeOp (
80
- builder, loc, cast<tensor::InsertSliceOp>(op));
81
- }
82
-
83
- SmallVector<Value>
84
- getValuesNeededToBuildSubsetExtraction (Operation *op) const {
85
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp (
86
- cast<tensor::InsertSliceOp>(op));
87
- }
88
- };
89
-
90
- struct ParallelInsertSliceOpInterface
21
+ struct InsertSliceLikeOpInterface
91
22
: public SubsetInsertionOpInterface::ExternalModel<
92
- ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp > {
23
+ InsertSliceLikeOpInterface<OpTy>, OpTy > {
93
24
OpOperand &getSourceOperand (Operation *op) const {
94
- return cast<tensor::ParallelInsertSliceOp >(op).getSourceMutable ();
25
+ return cast<OpTy >(op).getSourceMutable ();
95
26
}
96
27
97
28
OpOperand &getDestinationOperand (Operation *op) const {
98
- return cast<tensor::ParallelInsertSliceOp >(op).getDestMutable ();
29
+ return cast<OpTy >(op).getDestMutable ();
99
30
}
100
31
32
+ // / Return "true" if `insertSliceOp` inserts into a subset that is equivalent
33
+ // / to the subset defined by `candidate`. `equivalenceFn` is used to determine
34
+ // / equivalence of tensors.
101
35
bool
102
36
isEquivalentSubset (Operation *op, Value candidate,
103
37
function_ref<bool (Value, Value)> equivalenceFn) const {
104
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
105
- return isSubsetEquivalentToInsertSliceLikeOp (insertSliceOp, candidate,
106
- equivalenceFn);
38
+ auto insertSliceOp = cast<OpTy>(op);
39
+ // Look for a matching tensor.extract_slice op.
40
+ auto extractSliceOp = candidate.getDefiningOp <tensor::ExtractSliceOp>();
41
+ if (!extractSliceOp)
42
+ return false ;
43
+ if (!equivalenceFn (extractSliceOp.getSource (), insertSliceOp.getDest ()))
44
+ return false ;
45
+ return sameOffsetsSizesAndStrides (extractSliceOp, insertSliceOp,
46
+ isEqualConstantIntOrValue);
107
47
}
108
48
109
49
Value buildSubsetExtraction (Operation *op, OpBuilder &builder,
110
50
Location loc) const {
111
- return buildSubsetExtractionOfInsertSliceLikeOp (
112
- builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
51
+ auto insertSliceOp = cast<OpTy>(op);
52
+ auto extractOp = builder.create <tensor::ExtractSliceOp>(
53
+ loc, insertSliceOp.getSourceType (), insertSliceOp.getDest (),
54
+ insertSliceOp.getMixedOffsets (), insertSliceOp.getMixedSizes (),
55
+ insertSliceOp.getMixedStrides ());
56
+ return extractOp.getResult ();
113
57
}
114
58
115
59
SmallVector<Value>
116
60
getValuesNeededToBuildSubsetExtraction (Operation *op) const {
117
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp (
118
- cast<tensor::ParallelInsertSliceOp>(op));
61
+ auto insertSliceOp = cast<OpTy>(op);
62
+ SmallVector<Value> neededValues;
63
+ // Collect all values that are needed to construct the replacement op.
64
+ neededValues.append (insertSliceOp.getOffsets ().begin (),
65
+ insertSliceOp.getOffsets ().end ());
66
+ neededValues.append (insertSliceOp.getSizes ().begin (),
67
+ insertSliceOp.getSizes ().end ());
68
+ neededValues.append (insertSliceOp.getStrides ().begin (),
69
+ insertSliceOp.getStrides ().end ());
70
+ neededValues.push_back (insertSliceOp.getDest ());
71
+ return neededValues;
119
72
}
120
73
};
121
74
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
124
77
void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels (
125
78
DialectRegistry ®istry) {
126
79
registry.addExtension (+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
127
- InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
128
- ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
80
+ InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
129
81
*ctx);
82
+ ParallelInsertSliceOp::attachInterface<
83
+ InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
130
84
});
131
85
}
0 commit comments