[mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs #72142
Conversation
This patch extends the vector.transfer_read drop unit dim pattern to support scalable vectors with (non-scalable) unit dims, and dynamic memrefs. The xfer op can also have a mask of type 'vector.create_mask', which gets rewritten as long as the mask of the unit dim is a constant of 1.

For context, this and #72105 enable the lowering of a regular [...]. With this change the unit dim can be dropped such that the transfer_read is rank 1.

I've posted this as a draft as I'm not entirely sure this transformation is semantically correct: when I look at the memref dialect and the subview op [1], it mentions [...], but the vector.transfer_read is in-bounds. Would appreciate any thoughts / feedback.

cc @dcaballe @nicolasvasilache @banach-space @MacDue

[1] https://mlir.llvm.org/docs/Dialects/MemRef/#memrefsubview-memrefsubviewop

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends TransferReadDropUnitDimsPattern to support dropping unit dims from partially-static memrefs, for example:

  %v = vector.transfer_read %base[%c0, %c0], %pad {in_bounds = [true, true]} :
    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>

Is rewritten as:

  %dim0 = memref.dim %base, %c0 : memref<?x1xi8, strided<[?, ?], offset: ?>>
  %subview = memref.subview %base[0, 0] [%dim0, 1] [1, 1] :
    memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, #map1>
  %v = vector.transfer_read %subview[%c0], %pad {in_bounds = [true]} :
    memref<?xi8, #map1>, vector<[16]xi8>

Scalable vectors are now also supported; the scalable dims were previously being dropped when creating the rank-reduced vector type. The xfer op can also have a mask of type 'vector.create_mask', which gets rewritten as long as the mask of the unit dim is a constant of 1.

Full diff: https://github.com/llvm/llvm-project/pull/72142.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a5f1b28152b9bde..95445f2081ec89c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,22 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
opToErase.push_back(read.getOperation());
}
+/// Returns a copy of `shape` without unit dims.
+static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> reducedShape;
+ llvm::copy_if(shape, std::back_inserter(reducedShape),
+ [](int64_t dimSize) { return dimSize != 1; });
+ return reducedShape;
+}
+
/// Drops unit dimensions from the input MemRefType.
-static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
- ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
- SmallVector<int64_t> targetShape = llvm::to_vector(
- llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
+static MemRefType dropUnitDims(MemRefType inputType,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
- targetShape, inputType, offsets, sizes, strides);
+ getReducedShape(inputType.getShape()), inputType, offsets, sizes,
+ strides);
return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}
@@ -277,17 +285,18 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
MemRefType inputType = cast<MemRefType>(input.getType());
- assert(inputType.hasStaticShape());
- SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
- SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
- ArrayRef<int64_t> subViewSizes = inputType.getShape();
- MemRefType resultType =
- dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
+ SmallVector<OpFoldResult> offsets(inputType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
+ SmallVector<OpFoldResult> strides(inputType.getRank(),
+ rewriter.getIndexAttr(1));
+ MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
+
if (canonicalizeStridedLayout(resultType) ==
canonicalizeStridedLayout(inputType))
return input;
- return rewriter.create<memref::SubViewOp>(
- loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
+ return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
+ sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
@@ -295,12 +304,18 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t> reducedShape;
- llvm::copy_if(shape, std::back_inserter(reducedShape),
- [](int64_t dimSize) { return dimSize != 1; });
- return reducedShape;
+/// Trims non-scalable one dimensions from `oldType` and returns the result
+/// type.
+static VectorType trimUnitDims(VectorType oldType) {
+ SmallVector<int64_t> newShape;
+ SmallVector<bool> newScalableDims;
+ for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+ if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
+ continue;
+ newShape.push_back(dimSize);
+ newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
+ }
+ return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
namespace {
@@ -320,9 +335,7 @@ class TransferReadDropUnitDimsPattern
Value source = transferReadOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor types.
- if (!sourceType || !sourceType.hasStaticShape())
- return failure();
- if (sourceType.getNumElements() != vectorType.getNumElements())
+ if (!sourceType)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ auto maskOp = transferReadOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+ auto maskType = maskOp.getType();
+ auto reducedMaskType = trimUnitDims(maskType);
+ if (reducedMaskType.getRank() == maskType.getRank())
+ return failure();
+ SmallVector<Value> maskOperands;
+ for (auto [dim, dimIsScalable, maskOperand] :
+ llvm::zip(maskType.getShape(), maskType.getScalableDims(),
+ createMaskOp.getOperands())) {
+ if (dim == 1 && !dimIsScalable) {
+ // If the mask for the unit dim is not a constant of 1, do nothing.
+ auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constant || (constant.value() != 1))
+ return failure();
+ continue;
+ }
+ maskOperands.push_back(maskOperand);
+ }
+ maskOp = rewriter.create<vector::CreateMaskOp>(loc, reducedMaskType,
+ maskOperands);
+ }
+
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- auto reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+ SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+ loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
+ transferReadOp.getPadding(), maskOp,
+ rewriter.getBoolArrayAttr(inBounds));
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, vectorType, newTransferReadOp);
rewriter.replaceOp(transferReadOp, shapeCast);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 2852e301888cca8..688fcd114041812 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -82,6 +82,92 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
+func.func @transfer_read_dynamic_rank_reducing(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0], %pad {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @transfer_read_dynamic_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %mask_dim0 : index) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-SAME: %[[MASK_DIM0:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
+// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+/// Only vector.create_mask is currently supported.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_1
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: vector.create_mask
+// CHECK-NOT: memref.subview
+// CHECK: vector.transfer_read %[[ARG]]
+
+/// Unit dim mask must be constant of 1.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_2(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %mask_dim0 : index, %mask_dim1 : index) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_2
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: memref.subview
+// CHECK: vector.transfer_read {{.*}} vector<[16]x1xi8>
+
+/// Unit dim must be non-scalable.
+func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %mask_dim0 : index) -> vector<[16]x[1]xi8> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x[1]xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x[1]xi8>
+ return %v : vector<[16]x[1]xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: memref.subview
+// CHECK: vector.transfer_read {{.*}} vector<[16]x[1]xi8>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
This seems reasonable to me :). The docs you link seem to say this is allowed, as this is only removing statically known unit dims. The added shape_cast
won't be something that can be lowered, but should (hopefully) fold away.
https://mlir.llvm.org/docs/Dialects/MemRef/#memrefsubview-memrefsubviewop
A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1.
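For illustration, a sketch of the rewrite on the dynamic test case from the diff above; the subview and read types here are assumed from the pattern's intent rather than copied from actual output:

  %dim0 = memref.dim %arg, %c0 : memref<?x1xi8, strided<[?, ?], offset: ?>>
  %subview = memref.subview %arg[0, 0] [%dim0, 1] [1, 1]
      : memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, strided<[?], offset: ?>>
  %read = vector.transfer_read %subview[%c0], %pad {in_bounds = [true]}
      : memref<?xi8, strided<[?], offset: ?>>, vector<[16]xi8>
  // The pattern then casts the rank-reduced result back to the original type;
  // this is the shape_cast that is expected to fold away.
  %v = vector.shape_cast %read : vector<[16]xi8> to vector<[16]x1xi8>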
        // If the mask for the unit dim is not a constant of 1, do nothing.
        auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
        continue;
Probably done elsewhere, but if any dim in the mask is 0 this whole read folds to a constant splat of the padding.
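For illustration, a hypothetical instance of the case being described (the fold itself is outside this patch):

  // If any create_mask operand is a constant 0, the mask has no active lanes,
  // so the masked read yields only padding:
  %zero = arith.constant 0 : index
  %mask = vector.create_mask %zero, %c1 : vector<[16]x1xi1>
  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]}
      : memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
  // ...which is equivalent to splatting the padding value:
  // %v = vector.broadcast %pad : i8 to vector<[16]x1xi8>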
I haven't checked but I figured that should already exist and isn't something to handle here
can we hoist this logic in a helper with a good name? It seems this is deep enough already.
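A rough sketch of what such a helper might look like; the name, return type, and exact signature are assumptions for illustration (it reuses the trimUnitDims helper from this patch), not the code that eventually landed:

  /// Attempts to rebuild `createMaskOp` without its non-scalable unit dims.
  /// Fails if a dropped dim is not masked by a constant 1. (Sketch only.)
  static FailureOr<Value>
  createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                    vector::CreateMaskOp createMaskOp) {
    auto maskType = cast<VectorType>(createMaskOp.getType());
    VectorType reducedType = trimUnitDims(maskType);
    SmallVector<Value> reducedOperands;
    for (auto [dim, dimIsScalable, operand] :
         llvm::zip_equal(maskType.getShape(), maskType.getScalableDims(),
                         createMaskOp.getOperands())) {
      if (dim == 1 && !dimIsScalable) {
        // Only drop the unit dim if its mask operand is a constant 1.
        auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || constant.value() != 1)
          return failure();
        continue;
      }
      reducedOperands.push_back(operand);
    }
    return rewriter
        .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
        .getResult();
  }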
Yeah, that came up in another review recently. I think we have a canonicalization pattern for that already?
can confirm the shape_cast folds away :)
     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
+        transferReadOp.getPadding(), maskOp,
thanks for adding the previously omitted mask!
Nice!
It would be good to add tests with non-trailing unit dim and when there's more than one unit dim (and perhaps a mix of scalable and non-scalable).
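A sketch of what such a test might look like (hypothetical function name, shapes, and CHECK lines):

  func.func @transfer_read_non_trailing_and_multiple_unit_dims(
      %arg : memref<1x?x1x1xi8, strided<[?, ?, ?, ?], offset: ?>>) -> vector<1x[16]x1x1xi8> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0 : i8
    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %pad {in_bounds = [true, true, true, true]} :
      memref<1x?x1x1xi8, strided<[?, ?, ?, ?], offset: ?>>, vector<1x[16]x1x1xi8>
    return %v : vector<1x[16]x1x1xi8>
  }
  // The expectation would be a rank-1 read through a rank-reducing subview, e.g.:
  // CHECK: memref.subview %{{.*}}[0, 0, 0, 0] [1, %{{.*}}, 1, 1] [1, 1, 1, 1]
  // CHECK: vector.transfer_read {{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>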
Thanks!
* trimUnitDims -> trimNonScalableUnitDims
* add notifyMatchFailure on unsupported mask op
* llvm::zip -> llvm::zip_equal
* add helper to rewrite vector.create_mask to drop non-scalable unit dims.
* add getReducedShape that takes mixedSizes.
* add a more complex test.
I've added a more complex test
This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.
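For reference, a sketch of the analogous write-side case that the follow-up addresses (hypothetical IR mirroring the masked read test above):

  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
  vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]}
      : vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
  // With the follow-up, this can become a rank-1 write through a rank-reducing
  // subview, with the mask reduced to vector<[16]xi1> instead of being
  // silently dropped.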