Skip to content

[mlir][vector] Add tests xfer-permute-lowering (nfc)(2/n) #96033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// TODO: Align naming with e.g. vector-transfer-flatten.mlir
// TODO: Replace %arg0 with %vec
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, do you think it would make sense to refactor all args the same way you did in vector-transfer-flatten.mlir ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huge +1 - I will update the comment to reflect that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you updated the file AND accepted my suggestion generating repetition 😕

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes! Fixed in the latest commit :)

// TODO: Replace index args as %idx
// TODO: Align argument definition in CHECKS with function body.

///----------------------------------------------------------------------------------------
/// vector.transfer_write -> vector.transpose + vector.transfer_write
/// [Pattern: TransferWritePermutationLowering]
Expand All @@ -12,8 +17,8 @@
/// _is_ a minor identity

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
Expand All @@ -31,6 +36,31 @@ func.func @xfer_write_transposing_permutation_map(
return
}

// Even with out-of-bounds, it is safe to apply this pattern
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
// Expect the in_bounds attribute to be preserved. Since we don't print it when
// all flags are "false", it should not appear in the output.
// CHECK-NOT: in_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
// CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
func.func @xfer_write_transposing_permutation_map_out_of_bounds(
%arg0: vector<4x8xi16>,
%mem: memref<2x2x?x?xi16>) {

%c0 = arith.constant 0 : index
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
in_bounds = [false, false],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x8xi16>, memref<2x2x?x?xi16>

return
}

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
Expand Down Expand Up @@ -83,19 +113,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
/// which _is_ a permutation of a minor identity

// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
%base2 : index) {

%fn1 = arith.constant -2.0 : f32
%vf0 = vector.splat %fn1 : vector<7xf32>
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
vector.transfer_write %vf0, %mem[%base1, %base2], %mask
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index) {
// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[IDX_1]], %[[IDX_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @xfer_write_non_transposing_permutation_map(
%mem : memref<?x?xf32>,
%arg0 : vector<7xf32>,
%idx_1 : index,
%idx_2 : index) {

vector.transfer_write %arg0, %mem[%idx_1, %idx_2]
{permutation_map = affine_map<(d0, d1) -> (d0)>}
: vector<7xf32>, memref<?x?xf32>
return
}

// Even with out-of-bounds, it is safe to apply this pattern
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MASK:.*]]: vector<7xi1>) {
// CHECK: %[[BC_VEC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
// CHECK: %[[BC_MASK:.*]] = vector.broadcast %[[MASK]] : vector<7xi1> to vector<1x7xi1>
// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[IDX_1]], %[[IDX_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
%mem : memref<?x?xf32>,
%arg0 : vector<7xf32>,
%idx_1 : index,
%idx_2 : index,
%mask : vector<7xi1>) {

vector.transfer_write %arg0, %mem[%idx_1, %idx_2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
: vector<7xf32>, memref<?x?xf32>
return
Expand Down
Loading