-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Fix patterns for dropping leading unit dims from masks #73525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Fix patterns for dropping leading unit dims from masks #73525
Conversation
Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims.
@llvm/pr-subscribers-mlir-vector Author: Quinn Dawkins (qedawkins) Changes: Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims. Full diff: https://github.com/llvm/llvm-project/pull/73525.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..e9dab8f1e44ae68 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,12 @@ getAsConstantIndexOps(ArrayRef<Value> values);
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
+/// *before* any permutation or broadcasting.
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
- AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+ AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..3c85606da5ec522 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
}
};
+static Value processTransferMask(OpBuilder &b, Location loc, Value mask,
+ VectorType newType, AffineMap newMap,
+ VectorType oldMaskType) {
+ // Infer the type of the new mask from the new map.
+ auto newMaskType = inferTransferOpMaskType(newType, newMap);
+
+ // If the new mask is broadcastable to the old result type, we can safely
+ // use a `vector.extract` to get the new mask. Otherwise the best we can
+ // do is shape cast.
+ if (mlir::vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+ BroadcastableToResult::Success) {
+ int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+ return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+ }
+ return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
Value mask = Value();
if (read.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- int64_t dropDim = oldType.getRank() - newType.getRank();
- mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
- splatZero(dropDim));
+ VectorType maskType = read.getMaskType();
+ mask = processTransferMask(rewriter, read.getLoc(), read.getMask(),
+ newType, newMap, maskType);
}
auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
write.getLoc(), write.getVector(), splatZero(dropDim));
if (write.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- auto newMask = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.getMask(), splatZero(dropDim));
+ VectorType maskType = write.getMaskType();
+ Value newMask = processTransferMask(
+ rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
return %0: vector<1x1xf16>
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x1x4xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
return
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func.func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
|
@llvm/pr-subscribers-mlir Author: Quinn Dawkins (qedawkins) Changes: Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims. Full diff: https://github.com/llvm/llvm-project/pull/73525.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..e9dab8f1e44ae68 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,12 @@ getAsConstantIndexOps(ArrayRef<Value> values);
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
+/// *before* any permutation or broadcasting.
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
- AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+ AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..3c85606da5ec522 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
}
};
+static Value processTransferMask(OpBuilder &b, Location loc, Value mask,
+ VectorType newType, AffineMap newMap,
+ VectorType oldMaskType) {
+ // Infer the type of the new mask from the new map.
+ auto newMaskType = inferTransferOpMaskType(newType, newMap);
+
+ // If the new mask is broadcastable to the old result type, we can safely
+ // use a `vector.extract` to get the new mask. Otherwise the best we can
+ // do is shape cast.
+ if (mlir::vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+ BroadcastableToResult::Success) {
+ int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+ return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+ }
+ return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
Value mask = Value();
if (read.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- int64_t dropDim = oldType.getRank() - newType.getRank();
- mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
- splatZero(dropDim));
+ VectorType maskType = read.getMaskType();
+ mask = processTransferMask(rewriter, read.getLoc(), read.getMask(),
+ newType, newMap, maskType);
}
auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
write.getLoc(), write.getVector(), splatZero(dropDim));
if (write.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- auto newMask = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.getMask(), splatZero(dropDim));
+ VectorType maskType = write.getMaskType();
+ Value newMask = processTransferMask(
+ rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
return %0: vector<1x1xf16>
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x1x4xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
return
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func.func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just a few minor changes. Thanks!
Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims.