Skip to content

Commit bf897d5

Browse files
authored
[mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (#72142)
This patch extends TransferReadDropUnitDimsPattern to support dropping unit dims from partially-static memrefs, for example: %v = vector.transfer_read %base[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8> is rewritten as: %dim0 = memref.dim %base, %c0 : memref<?x1xi8, strided<[?, ?], offset: ?>> %subview = memref.subview %base[0, 0] [%dim0, 1] [1, 1] : memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, #map1> %v = vector.transfer_read %subview[%c0], %pad {in_bounds = [true]} : memref<?xi8, #map1>, vector<[16]xi8> Scalable vectors are now also supported; previously, the scalable dims were being dropped when creating the rank-reduced vector type. The xfer op can also have a mask of type 'vector.create_mask', which gets rewritten as long as the mask of the unit dim is a constant of 1.
1 parent cdf6693 commit bf897d5

File tree

2 files changed

+211
-28
lines changed

2 files changed

+211
-28
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 99 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,37 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
260260
opToErase.push_back(read.getOperation());
261261
}
262262

263+
/// Returns a copy of `shape` without unit dims.
264+
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
265+
SmallVector<int64_t> reducedShape;
266+
llvm::copy_if(shape, std::back_inserter(reducedShape),
267+
[](int64_t dimSize) { return dimSize != 1; });
268+
return reducedShape;
269+
}
270+
271+
/// Converts OpFoldResults to int64_t shape without unit dims.
272+
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
273+
SmallVector<int64_t> reducedShape;
274+
for (const auto size : mixedSizes) {
275+
if (llvm::dyn_cast_if_present<Value>(size)) {
276+
reducedShape.push_back(ShapedType::kDynamic);
277+
continue;
278+
}
279+
280+
auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
281+
if (value == 1)
282+
continue;
283+
reducedShape.push_back(value.getSExtValue());
284+
}
285+
return reducedShape;
286+
}
287+
263288
/// Drops unit dimensions from the input MemRefType.
264-
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
265-
ArrayRef<int64_t> sizes,
266-
ArrayRef<int64_t> strides) {
267-
SmallVector<int64_t> targetShape = llvm::to_vector(
268-
llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
289+
static MemRefType dropUnitDims(MemRefType inputType,
290+
ArrayRef<OpFoldResult> offsets,
291+
ArrayRef<OpFoldResult> sizes,
292+
ArrayRef<OpFoldResult> strides) {
293+
auto targetShape = getReducedShape(sizes);
269294
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
270295
targetShape, inputType, offsets, sizes, strides);
271296
return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
@@ -277,30 +302,63 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
277302
mlir::Location loc,
278303
Value input) {
279304
MemRefType inputType = cast<MemRefType>(input.getType());
280-
assert(inputType.hasStaticShape());
281-
SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
282-
SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
283-
ArrayRef<int64_t> subViewSizes = inputType.getShape();
284-
MemRefType resultType =
285-
dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
305+
SmallVector<OpFoldResult> offsets(inputType.getRank(),
306+
rewriter.getIndexAttr(0));
307+
SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
308+
SmallVector<OpFoldResult> strides(inputType.getRank(),
309+
rewriter.getIndexAttr(1));
310+
MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
311+
286312
if (canonicalizeStridedLayout(resultType) ==
287313
canonicalizeStridedLayout(inputType))
288314
return input;
289-
return rewriter.create<memref::SubViewOp>(
290-
loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
315+
return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
316+
sizes, strides);
291317
}
292318

293319
/// Returns the number of dims that aren't unit dims.
294320
static int getReducedRank(ArrayRef<int64_t> shape) {
295321
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
296322
}
297323

298-
/// Returns a copy of `shape` without unit dims.
299-
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
300-
SmallVector<int64_t> reducedShape;
301-
llvm::copy_if(shape, std::back_inserter(reducedShape),
302-
[](int64_t dimSize) { return dimSize != 1; });
303-
return reducedShape;
324+
/// Trims non-scalable one dimensions from `oldType` and returns the result
325+
/// type.
326+
static VectorType trimNonScalableUnitDims(VectorType oldType) {
327+
SmallVector<int64_t> newShape;
328+
SmallVector<bool> newScalableDims;
329+
for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
330+
if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
331+
continue;
332+
newShape.push_back(dimSize);
333+
newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
334+
}
335+
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
336+
}
337+
338+
// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
339+
static FailureOr<Value>
340+
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
341+
vector::CreateMaskOp op) {
342+
auto type = op.getType();
343+
auto reducedType = trimNonScalableUnitDims(type);
344+
if (reducedType.getRank() == type.getRank())
345+
return failure();
346+
347+
SmallVector<Value> reducedOperands;
348+
for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
349+
type.getShape(), type.getScalableDims(), op.getOperands())) {
350+
if (dim == 1 && !dimIsScalable) {
351+
// If the mask for the unit dim is not a constant of 1, do nothing.
352+
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
353+
if (!constant || (constant.value() != 1))
354+
return failure();
355+
continue;
356+
}
357+
reducedOperands.push_back(operand);
358+
}
359+
return rewriter
360+
.create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
361+
.getResult();
304362
}
305363

306364
namespace {
@@ -320,9 +378,7 @@ class TransferReadDropUnitDimsPattern
320378
Value source = transferReadOp.getSource();
321379
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
322380
// TODO: support tensor types.
323-
if (!sourceType || !sourceType.hasStaticShape())
324-
return failure();
325-
if (sourceType.getNumElements() != vectorType.getNumElements())
381+
if (!sourceType)
326382
return failure();
327383
// TODO: generalize this pattern, relax the requirements here.
328384
if (transferReadOp.hasOutOfBoundsDim())
@@ -335,23 +391,38 @@ class TransferReadDropUnitDimsPattern
335391
return failure();
336392
// Check if the reduced vector shape matches the reduced source shape.
337393
// Otherwise, this case is not supported yet.
338-
int vectorReducedRank = getReducedRank(vectorType.getShape());
339-
if (reducedRank != vectorReducedRank)
394+
auto reducedVectorType = trimNonScalableUnitDims(vectorType);
395+
if (reducedRank != reducedVectorType.getRank())
340396
return failure();
341397
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
342398
return getConstantIntValue(v) != static_cast<int64_t>(0);
343399
}))
344400
return failure();
401+
402+
Value maskOp = transferReadOp.getMask();
403+
if (maskOp) {
404+
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
405+
if (!createMaskOp)
406+
return rewriter.notifyMatchFailure(
407+
transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
408+
"currently supported");
409+
FailureOr<Value> rankReducedCreateMask =
410+
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
411+
if (failed(rankReducedCreateMask))
412+
return failure();
413+
maskOp = *rankReducedCreateMask;
414+
}
415+
345416
Value reducedShapeSource =
346417
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
347418
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
348419
SmallVector<Value> zeros(reducedRank, c0);
349420
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
350-
auto reducedVectorType = VectorType::get(
351-
getReducedShape(vectorType.getShape()), vectorType.getElementType());
352-
421+
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
353422
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
354-
loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
423+
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
424+
transferReadOp.getPadding(), maskOp,
425+
rewriter.getBoolArrayAttr(inBounds));
355426
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
356427
loc, vectorType, newTransferReadOp);
357428
rewriter.replaceOp(transferReadOp, shapeCast);

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,118 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
8282
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
8383
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
8484

85+
func.func @transfer_read_dynamic_rank_reducing(
86+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
87+
%c0 = arith.constant 0 : index
88+
%pad = arith.constant 0 : i8
89+
%v = vector.transfer_read %arg[%c0, %c0], %pad {in_bounds = [true, true]} :
90+
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
91+
return %v : vector<[16]x1xi8>
92+
}
93+
// CHECK-LABEL: func @transfer_read_dynamic_rank_reducing
94+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
95+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
96+
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
97+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
98+
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
99+
100+
func.func @masked_transfer_read_dynamic_rank_reducing_1(
101+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
102+
%mask_dim0 : index) -> vector<[16]x1xi8> {
103+
%c0 = arith.constant 0 : index
104+
%c1 = arith.constant 1 : index
105+
%pad = arith.constant 0 : i8
106+
%mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
107+
%v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
108+
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
109+
return %v : vector<[16]x1xi8>
110+
}
111+
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
112+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
113+
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
114+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
115+
// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
116+
// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
117+
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
118+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
119+
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
120+
121+
func.func @masked_transfer_read_dynamic_rank_reducing_2(
122+
%arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
123+
%mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
124+
%c0 = arith.constant 0 : index
125+
%c1 = arith.constant 1 : index
126+
%c2 = arith.constant 2 : index
127+
%pad = arith.constant 0 : i8
128+
%mask = vector.create_mask %c1, %mask_dim1, %c2, %c1, %mask_dim4, %c1 : vector<1x[1]x3x1x[16]x1xi1>
129+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
130+
memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
131+
return %v : vector<1x[1]x3x1x[16]x1xi8>
132+
}
133+
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
134+
// CHECK-SAME: %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
135+
// CHECK-SAME: %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
136+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
137+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
138+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
139+
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
140+
// CHECK-DAG: %[[PAD:.+]] = arith.constant 0 : i8
141+
// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM1]], %[[C2]], %[[MASK_DIM4]] : vector<[1]x3x[16]xi1>
142+
// CHECK: %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
143+
// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
144+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
145+
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
146+
147+
/// Only masks operands of vector.create_mask are currently supported.
148+
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
149+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
150+
%mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
151+
%c0 = arith.constant 0 : index
152+
%pad = arith.constant 0 : i8
153+
%v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
154+
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
155+
return %v : vector<[16]x1xi8>
156+
}
157+
// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_1
158+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
159+
// CHECK-NOT: vector.create_mask
160+
// CHECK-NOT: memref.subview
161+
// CHECK: vector.transfer_read %[[ARG]]
162+
163+
/// Unit dim mask must be constant of 1.
164+
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_2(
165+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
166+
%mask_dim0 : index, %mask_dim1 : index) -> vector<[16]x1xi8> {
167+
%c0 = arith.constant 0 : index
168+
%c1 = arith.constant 1 : index
169+
%pad = arith.constant 0 : i8
170+
%mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[16]x1xi1>
171+
%v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
172+
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
173+
return %v : vector<[16]x1xi8>
174+
}
175+
// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_2
176+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
177+
// CHECK-NOT: memref.subview
178+
// CHECK: vector.transfer_read {{.*}} vector<[16]x1xi8>
179+
180+
/// Unit dim must be non-scalable.
181+
func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
182+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
183+
%mask_dim0 : index) -> vector<[16]x[1]xi8> {
184+
%c0 = arith.constant 0 : index
185+
%c1 = arith.constant 1 : index
186+
%pad = arith.constant 0 : i8
187+
%mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x[1]xi1>
188+
%v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
189+
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x[1]xi8>
190+
return %v : vector<[16]x[1]xi8>
191+
}
192+
// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim
193+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
194+
// CHECK-NOT: memref.subview
195+
// CHECK: vector.transfer_read {{.*}} vector<[16]x[1]xi8>
196+
85197
module attributes {transform.with_named_sequence} {
86198
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
87199
transform.apply_patterns to %func_op {

0 commit comments

Comments
 (0)