Skip to content

Commit 820aa43

Browse files
authored
[mlir][vector] Update tests for xfer permutation lowering (4/N) (#127624)
* Document the remaining test cases, add a note that these are exercising `TransferOpReduceRank` (addresses an existing TODO). * Add missing cases (for fixed-width and scalable vectors). * Remove scalable vectors from the negative test (the masked case) - this test will also fail with fixed-width vectors. For consistency, lets make all negative test use fixed-width vectors.
1 parent dfa3af9 commit 820aa43

File tree

1 file changed

+82
-11
lines changed

1 file changed

+82
-11
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -363,17 +363,39 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
363363
}
364364

365365
///----------------------------------------------------------------------------------------
366-
/// vector.transfer_read
366+
/// [Pattern: TransferOpReduceRank]
367+
///
368+
/// IN: vector.transfer_read (minor identity map + broadcast)
369+
/// OUT: vector.transfer_read + vector.broadcast
367370
///----------------------------------------------------------------------------------------
368-
/// TODO: Review and categorize
371+
372+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
373+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
374+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
375+
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
376+
// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
377+
func.func @xfer_read_minor_identitiy_bcast_dims(
378+
%mem: memref<?x?x?x?xf32>,
379+
%idx: index) -> vector<8x4x2x3xf32> {
380+
381+
%pad = arith.constant 0.000000e+00 : f32
382+
383+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
384+
in_bounds = [true, true, true, true],
385+
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
386+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
387+
388+
return %res : vector<8x4x2x3xf32>
389+
}
369390

370391
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
371392
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
372393
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
373394
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
374395
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
375396
func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
376-
%mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
397+
%mem: memref<?x?x?x?xf32>,
398+
%idx: index) -> vector<8x[4]x2x3xf32> {
377399

378400
%pad = arith.constant 0.000000e+00 : f32
379401

@@ -385,31 +407,80 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
385407
return %res : vector<8x[4]x2x3xf32>
386408
}
387409

410+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
411+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
412+
// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
413+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
414+
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
415+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
416+
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
417+
// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
418+
func.func @xfer_read_minor_identitiy_bcast_dims_with_mask(
419+
%mem: memref<?x?x?x?xf32>,
420+
%mask: vector<4x3xi1>,
421+
%idx: index) -> vector<8x4x2x3xf32> {
422+
423+
%pad = arith.constant 0.000000e+00 : f32
424+
425+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad, %mask {
426+
in_bounds = [true, true, true, true],
427+
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
428+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
429+
430+
return %res : vector<8x4x2x3xf32>
431+
}
432+
433+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable
434+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
435+
// CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
436+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32>
437+
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
438+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
439+
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
440+
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
441+
func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable(
442+
%mem: memref<?x?x?x?xf32>,
443+
%mask: vector<[4]x3xi1>,
444+
%idx: index) -> vector<8x[4]x2x3xf32> {
445+
446+
%pad = arith.constant 0.000000e+00 : f32
447+
448+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad, %mask {
449+
in_bounds = [true, true, true, true],
450+
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
451+
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
452+
453+
return %res : vector<8x[4]x2x3xf32>
454+
}
455+
388456
// Masked version is not supported
389457

390458
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
391459
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
392-
// CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
393-
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
460+
// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
461+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
394462
// CHECK-NOT: vector.broadcast
395-
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
463+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32> } : vector<4x3xi1> -> vector<8x4x2x3xf32>
396464
func.func @xfer_read_minor_identitiy_bcast_dims_masked(
397465
%mem: memref<?x?x?x?xf32>,
398-
%mask: vector<[4]x3xi1>,
399-
%idx: index) -> vector<8x[4]x2x3xf32> {
466+
%mask: vector<4x3xi1>,
467+
%idx: index) -> vector<8x4x2x3xf32> {
400468

401469
%pad = arith.constant 0.000000e+00 : f32
402470

403471
%res = vector.mask %mask {
404472
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
405473
in_bounds = [true, true, true, true],
406474
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
407-
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
408-
} : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
475+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
476+
} : vector<4x3xi1> -> vector<8x4x2x3xf32>
409477

410-
return %res : vector<8x[4]x2x3xf32>
478+
return %res : vector<8x4x2x3xf32>
411479
}
412480

481+
///----------------------------------------------------------------------------------------
482+
// TD sequence
483+
///----------------------------------------------------------------------------------------
413484
module attributes {transform.with_named_sequence} {
414485
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
415486
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)