Skip to content

[mlir][vector] Update tests for xfer permutation lowering (1/N) #123076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 77 additions & 73 deletions mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,43 @@

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
func.func @xfer_write_transposing_permutation_map(
%vec: vector<4x8xi16>,
%mem: memref<2x2x8x4xi16>) {
%mem: memref<2x2x8x4xi16>,
%idx: index) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
in_bounds = [true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x8xi16>, memref<2x2x8x4xi16>

return
}

// Even with out-of-bounds, it is safe to apply this pattern
// Even with out-of-bounds accesses, it is safe to apply this pattern

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
// Expect the in_bounds attribute to be preserved. Since we don't print it when
// all flags are "false", it should not appear in the output.
// CHECK-NOT: in_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
// CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
// CHECK-SAME: %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
func.func @xfer_write_transposing_permutation_map_out_of_bounds(
%vec: vector<4x8xi16>,
%mem: memref<2x2x?x?xi16>) {
%mem: memref<2x2x?x?xi16>,
%idx: index) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
in_bounds = [false, false],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x8xi16>, memref<2x2x?x?xi16>
Expand All @@ -59,18 +60,19 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
%vec: vector<4x[8]xi16>,
%mem: memref<2x2x?x4xi16>,
%mask: vector<[8]x4xi1>) {
%mask: vector<[8]x4xi1>,
%idx: index) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
in_bounds = [true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x[8]xi16>, memref<2x2x?x4xi16>
Expand All @@ -79,16 +81,18 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
}

// Masked version is not supported

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
// CHECK-NOT: vector.transpose
func.func @xfer_write_transposing_permutation_map_masked(
%vec: vector<4x8xi16>,
%mem: memref<2x2x8x4xi16>,
%mask: vector<8x4xi1>) {
%mask: vector<8x4xi1>,
%idx: index) {

%c0 = arith.constant 0 : index
vector.mask %mask {
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
in_bounds = [true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x8xi16>, memref<2x2x8x4xi16>
Expand Down Expand Up @@ -128,7 +132,8 @@ func.func @xfer_write_non_transposing_permutation_map(
return
}

// Even with out-of-bounds, it is safe to apply this pattern
// Even with out-of-bounds accesses, it is safe to apply this pattern

// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
Expand Down Expand Up @@ -157,8 +162,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
// CHECK: %[[BC_1:.*]] = vector.broadcast %[[VEC]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
// CHECK: %[[BC_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
Expand All @@ -167,18 +171,19 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
func.func @permutation_with_mask_xfer_write_scalable(
%vec: vector<4x[8]xi16>,
%mem: memref<1x4x?x1xi16>,
%mask: vector<4x[8]xi1>){
%mask: vector<4x[8]xi1>,
%idx: index){

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
in_bounds = [true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
} : vector<4x[8]xi16>, memref<1x4x?x1xi16>

return
}

// transfer_write in MaskOp case not supported.
// Masked version is not supported

// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<16xf32>,
Expand All @@ -204,18 +209,19 @@ func.func @masked_permutation_xfer_write_fixed_width(
// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
// CHECK-NOT: vector.transpose
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
func.func @masked_permutation_xfer_write_scalable(
%vec: vector<4x[8]xi16>,
%dest: tensor<?x?x?x?xf32>,
%mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
%mask: vector<4x[8]xi1>,
%idx: index) -> tensor<?x?x?x?xf32> {

%c0 = arith.constant 0 : index
%res = vector.mask %mask {
vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
in_bounds = [true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
} : vector<4x[8]xi16>, tensor<?x?x?x?xf32>
Expand All @@ -224,22 +230,23 @@ func.func @masked_permutation_xfer_write_scalable(
return %res : tensor<?x?x?x?xf32>
}

// transfer_write in MaskOp case not supported.
// Masked version is not supported

// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
// CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
// CHECK-NOT: vector.broadcast
// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
func.func @masked_non_permutation_xfer_write_fixed_width(
%dest : tensor<?x?x?x?xf32>,
%vec : vector<14x8x16xf32>,
%dim : index) -> tensor<?x?x?x?xf32> {
%dim : index,
%idx: index) -> tensor<?x?x?x?xf32> {

%c0 = arith.constant 0 : index
%mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
%res = vector.mask %mask {
vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
in_bounds = [false, false, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
Expand All @@ -259,25 +266,23 @@ func.func @masked_non_permutation_xfer_write_fixed_width(

// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[IDX_1:.*]]: index,
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
func.func @permutation_with_mask_xfer_read_fixed_width(
%mem: memref<?x?xf32>,
%dim_1: index,
%dim_2: index) -> (vector<8x4x2xf32>) {
%dim_2: index,
%idx: index) -> (vector<8x4x2xf32>) {

%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%pad = arith.constant 0.000000e+00 : f32

%mask = vector.create_mask %dim_2, %dim_1 : vector<2x4xi1>
%res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
} : memref<?x?xf32>, vector<8x4x2xf32>
Expand All @@ -287,46 +292,45 @@ func.func @permutation_with_mask_xfer_read_fixed_width(

// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[IDX_1:.*]]: index,
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1>
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
func.func @permutation_with_mask_xfer_read_scalable(
%mem: memref<?x?xf32>,
%dim_1: index,
%dim_2: index) -> (vector<8x[4]x2xf32>) {
%dim_2: index,
%idx: index) -> (vector<8x[4]x2xf32>) {

%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%pad = arith.constant 0.000000e+00 : f32

%mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
%res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
} : memref<?x?xf32>, vector<8x[4]x2xf32>

return %res : vector<8x[4]x2xf32>
}

// transfer_read in MaskOp case not supported.
// Masked version is not supported

// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
// CHECK-NOT: vector.transpose
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
func.func @masked_permutation_xfer_read_fixed_width(
%dest: tensor<?x1xf32>,
%mask : vector<4x1xi1>) {
%mask : vector<4x1xi1>,
%idx: index) {

%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%pad = arith.constant 0.000000e+00 : f32
%3 = vector.mask %mask {
vector.transfer_read %dest[%c0, %c0], %cst {
vector.transfer_read %dest[%idx, %idx], %pad {
permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
} : tensor<?x1xf32>, vector<1x4x4xf32>
} : vector<4x1xi1> -> vector<1x4x4xf32>
Expand All @@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(

// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
// CHECK-NOT: vector.transpose
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
func.func @masked_permutation_xfer_read_scalable(
%dest: tensor<?x?xf32>,
%mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
%mask : vector<2x[4]xi1>,
%idx: index) -> vector<8x[4]x2xf32> {

%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%pad = arith.constant 0.000000e+00 : f32

%res = vector.mask %mask {
vector.transfer_read %dest[%c0, %c0], %cst_0 {
vector.transfer_read %dest[%idx, %idx], %pad {
in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
} : tensor<?x?xf32>, vector<8x[4]x2xf32>
Expand Down Expand Up @@ -377,41 +381,41 @@ module attributes {transform.with_named_sequence} {

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK: func.func @transfer_read_reduce_rank_scalable(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: This line should use `CHECK-LABEL` rather than a plain `CHECK`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed the same for `permutation_with_mask_xfer_write_scalable`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressing this in #123237, alongside other improvements :)

// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @transfer_read_reduce_rank_scalable(
%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
%mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {

%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%pad = arith.constant 0.000000e+00 : f32

%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
in_bounds = [true, true, true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>

return %res : vector<8x[4]x2x3xf32>
}

// Masked case not supported.
// Masked version is not supported

// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK-SAME: %[[DIM:.*]]: index,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK-NOT: vector.broadcast
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
func.func @masked_transfer_read_reduce_rank(
%mem: memref<?x?x?x?xf32>,
%dim: index) -> vector<8x[4]x2x3xf32> {
%dim: index,
%idx: index) -> vector<8x[4]x2x3xf32> {

%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%pad = arith.constant 0.000000e+00 : f32
%mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>

%res = vector.mask %mask {
vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
in_bounds = [true, true, true, true],
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
Expand Down
Loading