Skip to content

[mlir][vector] Support more mask types in foldTransferFullMask() #96761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 27, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jun 26, 2024

Using the existing getMaskFormat() this can be extended to support arith.constant masks.

Using the existing `getMaskFormat()` this can be extended to support
`arith.constant` masks.
@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Using the existing getMaskFormat() this can be extended to support arith.constant masks.


Full diff: https://github.com/llvm/llvm-project/pull/96761.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-5)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+10-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..149723f51cc12 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4172,11 +4172,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
   if (!mask)
     return failure();
 
-  auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
-  if (!constantMask)
-    return failure();
-
-  if (!constantMask.isAllOnesMask())
+  if (getMaskFormat(mask) != MaskFormat::AllTrue)
     return failure();
 
   op.getMaskMutable().clear();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..ecd49df3b2141 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -868,7 +868,7 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 // -----
 
 // CHECK-LABEL: fold_vector_transfer_masks
-func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
@@ -876,6 +876,8 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
 
   %mask = vector.constant_mask [8, 4] : vector<8x4xi1>
 
+  %mask_splat = arith.constant dense<true> : vector<4x[4]xi1>
+
   // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
   %1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
   vector.transfer_write %1, %A[%c0, %c0], %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
 
+  // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
+  %2 = vector.transfer_read %A[%c0, %c0], %f0, %mask_splat : memref<?x?xf32>, vector<4x[4]xf32>
+
+  // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
+  vector.transfer_write %2, %A[%c0, %c0], %mask_splat : vector<4x[4]xf32>, memref<?x?xf32>
+
   // CHECK: return
-  return %1 : vector<4x8xf32>
+  return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
 }
 
 // -----

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! LGTM just one minor nit, cheers

@MacDue MacDue merged commit 2731d26 into llvm:main Jun 27, 2024
4 of 5 checks passed
@MacDue MacDue deleted the fold_more branch June 27, 2024 11:14
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…m#96761)

Using the existing `getMaskFormat()` this can be extended to support
`arith.constant` masks.
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…m#96761)

Using the existing `getMaskFormat()` this can be extended to support
`arith.constant` masks.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants