[mlir][vector] Distribute all non-permutation or broadcasted masked transfer reads #73539

Merged

Conversation

qedawkins
Contributor

The primary difficulty with distributing masked transfers arises when the permutation map actually permutes the vector dimensions, because the distribution logic then has to ensure the correct mask elements end up with the distributed transfer. When the map contains no true permutation, this problem does not arise, so the condition for distribution can be relaxed.
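
To illustrate the relaxed condition, here is a minimal sketch (not part of this PR) using the MLIR AffineMap API. It shows that the map (d0, d1) -> (d0) used in the test below fails the old isMinorIdentity() check but becomes the identity once unused dims are compressed, while a genuine transpose still fails the relaxed check:

// Minimal sketch (not part of the PR), assuming the standard MLIR AffineMap
// API, of why the check was relaxed from isMinorIdentity() to
// compressUnusedDims(...).isIdentity().
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);

  // (d0, d1) -> (d0): reads along the outer dimension; d1 is unused. This is
  // not a *minor* identity (that would be (d0, d1) -> (d1)), so the old check
  // rejected it even though no mask reshuffling is needed.
  AffineMap rowMap = AffineMap::get(2, 0, {d0}, &ctx);
  llvm::outs() << (rowMap.isMinorIdentity() ? "true" : "false") << "\n"; // false

  // Compressing the unused dim yields (d0) -> (d0), a plain identity, so the
  // relaxed check accepts it: mask elements map 1:1 onto the result vector.
  llvm::outs() << (compressUnusedDims(rowMap).isIdentity() ? "true" : "false")
               << "\n"; // true

  // (d0, d1) -> (d1, d0): a genuine permutation. Both checks reject it, since
  // distributing it would require rearranging mask elements across lanes.
  AffineMap transposeMap = AffineMap::get(2, 0, {d1, d0}, &ctx);
  llvm::outs() << (compressUnusedDims(transposeMap).isIdentity() ? "true"
                                                                 : "false")
               << "\n"; // false
  return 0;
}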

@llvmbot
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Full diff: https://github.com/llvm/llvm-project/pull/73539.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-1)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+26)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 0ad2c71cf3a6a11..07ecd8857520338 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // of which lane is responsible for which element is captured strictly
       // by shape information on the warp op, and thus requires materializing
       // the permutation in IR.
-      if (!read.getPermutationMap().isMinorIdentity())
+      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
         return failure();
       VectorType maskType =
           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8056260f4610977..ab175effa3dfb80 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 
 // -----
 
+func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
+    vector.yield %0 : vector<128xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     vector.yield %[[M0]] : vector<128xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
+//  CHECK-PROP-SAME:   permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
+
+// -----
+
 func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
   %f0 = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index

@qedawkins qedawkins merged commit f385f6c into llvm:main Nov 27, 2023