Skip to content

Commit 33468a5

Browse files
[mlir][Tensor] Add support for insert_slice in FoldTensorSubsetOps
Differential Revision: https://reviews.llvm.org/D148334
1 parent c8144ee commit 33468a5

File tree

8 files changed

+321
-69
lines changed

8 files changed

+321
-69
lines changed

mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
1010
#define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
1111

12+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1213
#include "mlir/IR/OpDefinition.h"
1314
#include "mlir/Interfaces/ViewLikeInterface.h"
1415

@@ -22,7 +23,8 @@ class RewriterBase;
2223
/// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
2324
/// - Combined sizes = consumer_sizes
2425
/// - Combined strides = producer_strides * consumer_strides
25-
// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
26+
// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
27+
// deprecate.
2628
LogicalResult
2729
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
2830
ArrayRef<OpFoldResult> producerOffsets,
@@ -38,7 +40,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
3840

3941
/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
4042
/// when combining a `producer` slice op **into** a `consumer` slice op.
41-
// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
43+
// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
44+
// deprecate.
4245
LogicalResult
4346
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
4447
OffsetSizeAndStrideOpInterface producer,
@@ -48,8 +51,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
4851
SmallVector<OpFoldResult> &combinedSizes,
4952
SmallVector<OpFoldResult> &combinedStrides);
5053

51-
/// Given the 'indicesVals' of a load/store operation operating on an op with
52-
/// offsets and strides, return the combined indices.
54+
/// Given the 'consumerIndices' of a load/store operation operating on an op
55+
/// with offsets and strides, return the combined indices.
5356
///
5457
/// For example, using `memref.load` and `memref.subview` as an illustration:
5558
///
@@ -64,13 +67,37 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
6467
///
6568
/// ```
6669
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
67-
/// memref<12x42xf32>
70+
/// memref<12x42xf32>
6871
/// ```
69-
void resolveSourceIndicesOffsetsAndStrides(
70-
RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
71-
ArrayRef<OpFoldResult> mixedStrides,
72-
const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
73-
SmallVectorImpl<Value> &sourceIndices);
72+
void resolveIndicesIntoOpWithOffsetsAndStrides(
73+
RewriterBase &rewriter, Location loc,
74+
ArrayRef<OpFoldResult> mixedSourceOffsets,
75+
ArrayRef<OpFoldResult> mixedSourceStrides,
76+
const llvm::SmallBitVector &rankReducedDims,
77+
ArrayRef<OpFoldResult> consumerIndices,
78+
SmallVectorImpl<Value> &resolvedIndices);
79+
80+
inline void resolveIndicesIntoOpWithOffsetsAndStrides(
81+
RewriterBase &rewriter, Location loc,
82+
ArrayRef<OpFoldResult> mixedSourceOffsets,
83+
ArrayRef<OpFoldResult> mixedSourceStrides,
84+
const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices,
85+
SmallVectorImpl<Value> &resolvedIndices) {
86+
return resolveIndicesIntoOpWithOffsetsAndStrides(
87+
rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims,
88+
getAsOpFoldResult(consumerIndices), resolvedIndices);
89+
}
90+
91+
/// Given `sourceSizes`, `destSizes` and information about which dimensions are
92+
/// dropped by the source: `rankReducedSourceDims`, compute the resolved sizes
93+
/// that correspond to dest_op(source_op).
94+
/// In practice, this amounts to filtering by `rankReducedSourceDims` and taking
95+
/// from `sourceSizes` if a dimension is dropped, otherwise taking from
96+
/// `destSizes`.
97+
void resolveSizesIntoOpWithSizes(
98+
ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
99+
const llvm::SmallBitVector &rankReducedSourceDims,
100+
SmallVectorImpl<OpFoldResult> &resolvedSizes);
74101

75102
} // namespace mlir
76103

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,13 +1963,13 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19631963
}];
19641964

19651965
let builders = [
1966-
// Build a SubViewOp with mixed static and dynamic entries and custom
1967-
// result type. If the type passed is nullptr, it is inferred.
1966+
// Build a SubViewOp with mixed static and dynamic entries and inferred
1967+
// result type.
19681968
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
19691969
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
19701970
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1971-
// Build a SubViewOp with mixed static and dynamic entries and inferred
1972-
// result type.
1971+
// Build a SubViewOp with mixed static and dynamic entries and custom
1972+
// result type. If the type passed is nullptr, it is inferred.
19731973
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
19741974
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
19751975
"ArrayRef<OpFoldResult>":$strides,

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -823,17 +823,18 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
823823
}];
824824

825825
let builders = [
826-
// Build a InsertSliceOp with mixed static and dynamic entries.
826+
// Build an InsertSliceOp with mixed static and dynamic entries and inferred
827+
// result type.
827828
OpBuilder<(ins "Value":$source, "Value":$dest,
828829
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
829830
"ArrayRef<OpFoldResult>":$strides,
830831
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
831-
// Build a InsertSliceOp with dynamic entries.
832+
// Build an InsertSliceOp with dynamic entries and inferred result type.
832833
OpBuilder<(ins "Value":$source, "Value":$dest,
833834
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
834835
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
835836
// Build an InsertSliceOp with mixed static and dynamic entries packed in
836-
// a Range vector.
837+
// a Range vector and inferred result type.
837838
OpBuilder<(ins "Value":$source, "Value":$dest,
838839
"ArrayRef<Range>":$ranges,
839840
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
@@ -1450,6 +1451,10 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
14501451
/// Return the OpResult of the enclosing ForallOp that is
14511452
/// corresponding to this ParallelInsertSliceOp.
14521453
OpResult getTiedOpResult();
1454+
1455+
/// Return the dimensions of the dest that are omitted to insert a source
1456+
/// when the result is rank-extended.
1457+
llvm::SmallBitVector getDroppedDims();
14531458
}];
14541459

14551460
let builders = [

mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,32 +77,49 @@ LogicalResult mlir::mergeOffsetsSizesAndStrides(
7777
combinedOffsets, combinedSizes, combinedStrides);
7878
}
7979

80-
void mlir::resolveSourceIndicesOffsetsAndStrides(
81-
RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
82-
ArrayRef<OpFoldResult> mixedStrides,
83-
const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
84-
SmallVectorImpl<Value> &sourceIndices) {
80+
void mlir::resolveIndicesIntoOpWithOffsetsAndStrides(
81+
RewriterBase &rewriter, Location loc,
82+
ArrayRef<OpFoldResult> mixedSourceOffsets,
83+
ArrayRef<OpFoldResult> mixedSourceStrides,
84+
const llvm::SmallBitVector &rankReducedDims,
85+
ArrayRef<OpFoldResult> consumerIndices,
86+
SmallVectorImpl<Value> &resolvedIndices) {
8587
OpFoldResult zero = rewriter.getIndexAttr(0);
8688

8789
// For each dimension that is rank-reduced, add a zero to the indices.
8890
int64_t indicesDim = 0;
8991
SmallVector<OpFoldResult> indices;
90-
for (auto dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
92+
for (auto dim : llvm::seq<int64_t>(0, mixedSourceOffsets.size())) {
9193
OpFoldResult ofr =
92-
(rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++];
94+
(rankReducedDims.test(dim)) ? zero : consumerIndices[indicesDim++];
9395
indices.push_back(ofr);
9496
}
9597

96-
sourceIndices.resize(indices.size());
97-
sourceIndices.clear();
98+
resolvedIndices.resize(indices.size());
99+
resolvedIndices.clear();
98100
for (auto [offset, index, stride] :
99-
llvm::zip_equal(mixedOffsets, indices, mixedStrides)) {
101+
llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
100102
AffineExpr off, idx, str;
101103
bindSymbols(rewriter.getContext(), off, idx, str);
102104
OpFoldResult ofr = makeComposedFoldedAffineApply(
103105
rewriter, loc, AffineMap::get(0, 3, off + idx * str),
104106
{offset, index, stride});
105-
sourceIndices.push_back(
107+
resolvedIndices.push_back(
106108
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
107109
}
108110
}
111+
112+
void mlir::resolveSizesIntoOpWithSizes(
113+
ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
114+
const llvm::SmallBitVector &rankReducedSourceDims,
115+
SmallVectorImpl<OpFoldResult> &resolvedSizes) {
116+
int64_t dim = 0;
117+
int64_t srcRank = sourceSizes.size();
118+
for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
119+
if (rankReducedSourceDims[srcDim]) {
120+
resolvedSizes.push_back(sourceSizes[srcDim]);
121+
continue;
122+
}
123+
resolvedSizes.push_back(destSizes[dim++]);
124+
}
125+
}

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -248,48 +248,38 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
248248

249249
LogicalResult matchAndRewrite(memref::SubViewOp subView,
250250
PatternRewriter &rewriter) const override {
251-
Location loc = subView.getLoc();
252251
auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
253252
if (!srcSubView)
254253
return failure();
255-
int64_t srcRank = srcSubView.getSourceType().getRank();
256-
257-
// TODO: Only stride 1 is supported.
258-
for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()})
259-
if (!llvm::all_of(
260-
s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }))
261-
return failure();
262-
263-
// Get original offsets and sizes.
264-
SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
265-
SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
266-
SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
267-
SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
268-
269-
// Compute new offsets and sizes.
270-
llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
271-
SmallVector<OpFoldResult> newOffsets, newSizes;
272-
int64_t dim = 0;
273-
for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
274-
if (srcReducedDims[srcDim]) {
275-
// Dim is reduced in srcSubView.
276-
assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
277-
newOffsets.push_back(srcOffsets[srcDim]);
278-
newSizes.push_back(srcSizes[srcDim]);
279-
continue;
280-
}
281-
AffineExpr sym0, sym1;
282-
bindSymbols(subView.getContext(), sym0, sym1);
283-
newOffsets.push_back(makeComposedFoldedAffineApply(
284-
rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]}));
285-
newSizes.push_back(sizes[dim]);
286-
++dim;
254+
255+
// TODO: relax unit stride assumption.
256+
if (!subView.hasUnitStride()) {
257+
return rewriter.notifyMatchFailure(subView, "requires unit strides");
258+
}
259+
if (!srcSubView.hasUnitStride()) {
260+
return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
287261
}
288262

263+
// Resolve sizes according to dropped dims.
264+
SmallVector<OpFoldResult> resolvedSizes;
265+
llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
266+
resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
267+
subView.getMixedSizes(), srcDroppedDims,
268+
resolvedSizes);
269+
270+
// Resolve offsets according to source offsets and strides.
271+
SmallVector<Value> resolvedOffsets;
272+
resolveIndicesIntoOpWithOffsetsAndStrides(
273+
rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
274+
srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
275+
resolvedOffsets);
276+
289277
// Replace original op.
290278
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
291-
subView, subView.getType(), srcSubView.getSource(), newOffsets,
292-
newSizes, srcSubView.getMixedStrides());
279+
subView, subView.getType(), srcSubView.getSource(),
280+
getAsOpFoldResult(resolvedOffsets), resolvedSizes,
281+
srcSubView.getMixedStrides());
282+
293283
return success();
294284
}
295285
};
@@ -372,7 +362,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
372362
indices.assign(expandedIndices.begin(), expandedIndices.end());
373363
}
374364
SmallVector<Value> sourceIndices;
375-
resolveSourceIndicesOffsetsAndStrides(
365+
resolveIndicesIntoOpWithOffsetsAndStrides(
376366
rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
377367
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
378368
sourceIndices);
@@ -492,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
492482
indices.assign(expandedIndices.begin(), expandedIndices.end());
493483
}
494484
SmallVector<Value> sourceIndices;
495-
resolveSourceIndicesOffsetsAndStrides(
485+
resolveIndicesIntoOpWithOffsetsAndStrides(
496486
rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
497487
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
498488
sourceIndices);

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,6 +3086,10 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
30863086
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
30873087
}
30883088

3089+
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3090+
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3091+
}
3092+
30893093
//===----------------------------------------------------------------------===//
30903094
// ScatterOp
30913095
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)