Skip to content

Commit fdf84cb

Browse files
authored
[mlir][vector] Fix unit dim dropping pattern for masked writes (#74038)
This applies the same fix as #72142 to vector.transfer_write: when dropping unit dimensions from a masked transfer_write, the pattern now rank-reduces the mask (via vector.create_mask) instead of silently dropping it, which previously produced an unmasked write.
1 parent b92693a commit fdf84cb

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
260260
opToErase.push_back(read.getOperation());
261261
}
262262

263-
/// Returns a copy of `shape` without unit dims.
264-
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
265-
SmallVector<int64_t> reducedShape;
266-
llvm::copy_if(shape, std::back_inserter(reducedShape),
267-
[](int64_t dimSize) { return dimSize != 1; });
268-
return reducedShape;
269-
}
270-
271263
/// Converts OpFoldResults to int64_t shape without unit dims.
272264
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
273265
SmallVector<int64_t> reducedShape;
@@ -340,7 +332,7 @@ static FailureOr<Value>
340332
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
341333
vector::CreateMaskOp op) {
342334
auto type = op.getType();
343-
auto reducedType = trimNonScalableUnitDims(type);
335+
VectorType reducedType = trimNonScalableUnitDims(type);
344336
if (reducedType.getRank() == type.getRank())
345337
return failure();
346338

@@ -391,7 +383,7 @@ class TransferReadDropUnitDimsPattern
391383
return failure();
392384
// Check if the reduced vector shape matches the reduced source shape.
393385
// Otherwise, this case is not supported yet.
394-
auto reducedVectorType = trimNonScalableUnitDims(vectorType);
386+
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
395387
if (reducedRank != reducedVectorType.getRank())
396388
return failure();
397389
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
446438
Value source = transferWriteOp.getSource();
447439
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
448440
// TODO: support tensor type.
449-
if (!sourceType || !sourceType.hasStaticShape())
450-
return failure();
451-
if (sourceType.getNumElements() != vectorType.getNumElements())
441+
if (!sourceType)
452442
return failure();
453443
// TODO: generalize this pattern, relax the requirements here.
454444
if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
461451
return failure();
462452
// Check if the reduced vector shape matches the reduced destination shape.
463453
// Otherwise, this case is not supported yet.
464-
int vectorReducedRank = getReducedRank(vectorType.getShape());
465-
if (reducedRank != vectorReducedRank)
454+
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
455+
if (reducedRank != reducedVectorType.getRank())
466456
return failure();
467457
if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
468458
return getConstantIntValue(v) != static_cast<int64_t>(0);
469459
}))
470460
return failure();
461+
462+
Value maskOp = transferWriteOp.getMask();
463+
if (maskOp) {
464+
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
465+
if (!createMaskOp)
466+
return rewriter.notifyMatchFailure(
467+
transferWriteOp,
468+
"unsupported mask op, only 'vector.create_mask' is "
469+
"currently supported");
470+
FailureOr<Value> rankReducedCreateMask =
471+
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
472+
if (failed(rankReducedCreateMask))
473+
return failure();
474+
maskOp = *rankReducedCreateMask;
475+
}
471476
Value reducedShapeSource =
472477
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
473478
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
474479
SmallVector<Value> zeros(reducedRank, c0);
475480
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
476-
VectorType reducedVectorType = VectorType::get(
477-
getReducedShape(vectorType.getShape()), vectorType.getElementType());
478-
481+
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
479482
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
480483
loc, reducedVectorType, vector);
481484
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
482-
transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
485+
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
486+
identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
483487

484488
return success();
485489
}

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
144144
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
145145
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
146146

147+
func.func @masked_transfer_write_and_vector_rank_reducing(
148+
%arg : memref<1x1x3x1x16x1xf32>,
149+
%vec : vector<1x3x1x16x1xf32>,
150+
%mask_dim1 : index,
151+
%mask_dim2 : index) {
152+
%c0 = arith.constant 0 : index
153+
%c1 = arith.constant 1 : index
154+
%mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
155+
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
156+
vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
157+
return
158+
}
159+
// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
160+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
161+
// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>,
162+
// CHECK-SAME: %[[MASKDIM1:.+]]: index,
163+
// CHECK-SAME: %[[MASKDIM2:.+]]: index
164+
// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
165+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
166+
// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
167+
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
168+
169+
func.func @masked_transfer_write_dynamic_rank_reducing(
170+
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
171+
%vec : vector<[16]x1xi8>,
172+
%mask_dim0 : index) {
173+
%c0 = arith.constant 0 : index
174+
%c1 = arith.constant 1 : index
175+
%pad = arith.constant 0 : i8
176+
%mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
177+
vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
178+
vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
179+
return
180+
}
181+
// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
182+
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
183+
// CHECK-SAME: %{{.*}}: vector<[16]x1xi8>,
184+
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
185+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
186+
// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
187+
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
188+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
189+
// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
190+
147191
/// Only mask operands produced by vector.create_mask are currently supported.
148192
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
149193
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,

0 commit comments

Comments
 (0)