-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Update tests for xfer permutation lowering (4/N) #127624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Update tests for xfer permutation lowering (4/N) #127624
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/127624.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index dfc79a19e6cc6..c816371c84c1f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -363,17 +363,40 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
}
///----------------------------------------------------------------------------------------
-/// vector.transfer_read
+/// [Pattern: TransferOpReduceRank]
+///
+/// IN: vector.transfer_read (minor identity map + broadcast)
+/// OUT: vector.transfer_read + vector.broadcast
///----------------------------------------------------------------------------------------
/// TODO: Review and categorize
+// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
+// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
+// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_dims(
+ %mem: memref<?x?x?x?xf32>,
+ %idx: index) -> vector<8x4x2x3xf32> {
+
+ %pad = arith.constant 0.000000e+00 : f32
+
+ %res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
+ in_bounds = [true, true, true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+ } : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
+
+ return %res : vector<8x4x2x3xf32>
+}
+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
- %mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
+ %mem: memref<?x?x?x?xf32>,
+ %idx: index) -> vector<8x[4]x2x3xf32> {
%pad = arith.constant 0.000000e+00 : f32
@@ -385,18 +408,41 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
return %res : vector<8x[4]x2x3xf32>
}
+// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
+// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
+// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_dims_with_mask(
+ %mem: memref<?x?x?x?xf32>,
+ %mask: vector<4x3xi1>,
+ %idx: index) -> vector<8x4x2x3xf32> {
+
+ %pad = arith.constant 0.000000e+00 : f32
+
+ %res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad, %mask {
+ in_bounds = [true, true, true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+ } : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
+
+ return %res : vector<8x4x2x3xf32>
+}
+
// Masked version is not supported
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
-// CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
-// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
+// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
// CHECK-NOT: vector.broadcast
-// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32> } : vector<4x3xi1> -> vector<8x4x2x3xf32>
func.func @xfer_read_minor_identitiy_bcast_dims_masked(
%mem: memref<?x?x?x?xf32>,
- %mask: vector<[4]x3xi1>,
- %idx: index) -> vector<8x[4]x2x3xf32> {
+ %mask: vector<4x3xi1>,
+ %idx: index) -> vector<8x4x2x3xf32> {
%pad = arith.constant 0.000000e+00 : f32
@@ -404,12 +450,15 @@ func.func @xfer_read_minor_identitiy_bcast_dims_masked(
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
in_bounds = [true, true, true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
- } : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
- } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+ } : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
+ } : vector<4x3xi1> -> vector<8x4x2x3xf32>
- return %res : vector<8x[4]x2x3xf32>
+ return %res : vector<8x4x2x3xf32>
}
+///----------------------------------------------------------------------------------------
+// TD sequence
+///----------------------------------------------------------------------------------------
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32> | ||
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32> | ||
// CHECK: return %[[BC]] : vector<8x4x2x3xf32> | ||
func.func @xfer_read_minor_identitiy_bcast_dims_with_mask( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this test deserve a scalable flavour ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, why not :) And thanks for the review!
* Document the remaining test cases, add a note that these are exercising `TransferOpReduceRank` (addresses an existing TODO). * Add missing cases (fixed-width vectors). * Remove scalable from the negative test (the masked case) - this test will also fail with fixed-width vectors. For consistency, lets make all negative test use fixed-width vectors.
Add a case with scalable vec
c346cf6
to
4c906ac
Compare
exercising
TransferOpReduceRank
(addresses an existing TODO).will also fail with fixed-width vectors. For consistency, lets make
all negative test use fixed-width vectors.