Skip to content

[mlir][vector] Fix patterns for dropping leading unit dims from masks #73525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 27, 2023

Conversation

qedawkins
Copy link
Contributor

Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims.

Previously the pattern only worked when the permutation map was a minor
identity. Infer the new mask type from the new transfer map after
dropping leading unit dims.
@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

Changes

Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims.


Full diff: https://github.com/llvm/llvm-project/pull/73525.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+6)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+23-9)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+40)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..e9dab8f1e44ae68 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,12 @@ getAsConstantIndexOps(ArrayRef<Value> values);
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
 
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
+/// *before* any permutation or broadcasting.
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
 /// as masked operation.
 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
-                                          AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+                                                 AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..3c85606da5ec522 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
   }
 };
 
+static Value processTransferMask(OpBuilder &b, Location loc, Value mask,
+                                 VectorType newType, AffineMap newMap,
+                                 VectorType oldMaskType) {
+  // Infer the type of the new mask from the new map.
+  auto newMaskType = inferTransferOpMaskType(newType, newMap);
+
+  // If the new mask is broadcastable to the old result type, we can safely
+  // use a `vector.extract` to get the new mask. Otherwise the best we can
+  // do is shape cast.
+  if (mlir::vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+      BroadcastableToResult::Success) {
+    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+  }
+  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
 // Turns vector.transfer_read on vector with leading 1 dimensions into
 // vector.shape_cast followed by vector.transfer_read on vector without leading
 // 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
 
     Value mask = Value();
     if (read.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      int64_t dropDim = oldType.getRank() - newType.getRank();
-      mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
-                                                splatZero(dropDim));
+      VectorType maskType = read.getMaskType();
+      mask = processTransferMask(rewriter, read.getLoc(), read.getMask(),
+                                 newType, newMap, maskType);
     }
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
         write.getLoc(), write.getVector(), splatZero(dropDim));
 
     if (write.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      auto newMask = rewriter.create<vector::ExtractOp>(
-          write.getLoc(), write.getMask(), splatZero(dropDim));
+      VectorType maskType = write.getMaskType();
+      Value newMask = processTransferMask(
+          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
           write, newVector, write.getSource(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
   return %0: vector<1x1xf16>
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+                            permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x1x4xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
   return
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+                        permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
 func.func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes

Previously the pattern only worked when the permutation map was a minor identity. Infer the new mask type from the new transfer map after dropping leading unit dims.


Full diff: https://github.com/llvm/llvm-project/pull/73525.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+6)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+23-9)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+40)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..e9dab8f1e44ae68 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,12 @@ getAsConstantIndexOps(ArrayRef<Value> values);
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
 
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
+/// *before* any permutation or broadcasting.
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
 /// as masked operation.
 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
-                                          AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+                                                 AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..3c85606da5ec522 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
   }
 };
 
+static Value processTransferMask(OpBuilder &b, Location loc, Value mask,
+                                 VectorType newType, AffineMap newMap,
+                                 VectorType oldMaskType) {
+  // Infer the type of the new mask from the new map.
+  auto newMaskType = inferTransferOpMaskType(newType, newMap);
+
+  // If the new mask is broadcastable to the old result type, we can safely
+  // use a `vector.extract` to get the new mask. Otherwise the best we can
+  // do is shape cast.
+  if (mlir::vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+      BroadcastableToResult::Success) {
+    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+  }
+  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
 // Turns vector.transfer_read on vector with leading 1 dimensions into
 // vector.shape_cast followed by vector.transfer_read on vector without leading
 // 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
 
     Value mask = Value();
     if (read.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      int64_t dropDim = oldType.getRank() - newType.getRank();
-      mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
-                                                splatZero(dropDim));
+      VectorType maskType = read.getMaskType();
+      mask = processTransferMask(rewriter, read.getLoc(), read.getMask(),
+                                 newType, newMap, maskType);
     }
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
         write.getLoc(), write.getVector(), splatZero(dropDim));
 
     if (write.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      auto newMask = rewriter.create<vector::ExtractOp>(
-          write.getLoc(), write.getMask(), splatZero(dropDim));
+      VectorType maskType = write.getMaskType();
+      Value newMask = processTransferMask(
+          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
           write, newVector, write.getSource(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
   return %0: vector<1x1xf16>
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+                            permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x1x4xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
   return
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+                        permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
 func.func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a few minor changes. THanks!

@qedawkins qedawkins merged commit 6e8f7d5 into llvm:main Nov 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants