Skip to content

Commit c346cf6

Browse files
committed
[mlir][vector] Update tests for xfer permutation lowering (4/N)
* Document the remaining test cases, add a note that these are exercising `TransferOpReduceRank` (addresses an existing TODO). * Add missing cases (fixed-width vectors). * Remove scalable from the negative test (the masked case) - this test will also fail with fixed-width vectors. For consistency, lets make all negative test use fixed-width vectors.
1 parent 41be5bb commit c346cf6

File tree

1 file changed

+59
-10
lines changed

1 file changed

+59
-10
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,17 +363,40 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
363363
}
364364

365365
///----------------------------------------------------------------------------------------
366-
/// vector.transfer_read
366+
/// [Pattern: TransferOpReduceRank]
367+
///
368+
/// IN: vector.transfer_read (minor identity map + broadcast)
369+
/// OUT: vector.transfer_read + vector.broadcast
367370
///----------------------------------------------------------------------------------------
368371
/// TODO: Review and categorize
369372

373+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
374+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
375+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
376+
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
377+
// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
378+
func.func @xfer_read_minor_identitiy_bcast_dims(
379+
%mem: memref<?x?x?x?xf32>,
380+
%idx: index) -> vector<8x4x2x3xf32> {
381+
382+
%pad = arith.constant 0.000000e+00 : f32
383+
384+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
385+
in_bounds = [true, true, true, true],
386+
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
387+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
388+
389+
return %res : vector<8x4x2x3xf32>
390+
}
391+
370392
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
371393
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
372394
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
373395
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
374396
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
375397
func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
376-
%mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
398+
%mem: memref<?x?x?x?xf32>,
399+
%idx: index) -> vector<8x[4]x2x3xf32> {
377400

378401
%pad = arith.constant 0.000000e+00 : f32
379402

@@ -385,31 +408,57 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
385408
return %res : vector<8x[4]x2x3xf32>
386409
}
387410

411+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
412+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
413+
// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
414+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
415+
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
416+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
417+
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
418+
// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
419+
func.func @xfer_read_minor_identitiy_bcast_dims_with_mask(
420+
%mem: memref<?x?x?x?xf32>,
421+
%mask: vector<4x3xi1>,
422+
%idx: index) -> vector<8x4x2x3xf32> {
423+
424+
%pad = arith.constant 0.000000e+00 : f32
425+
426+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad, %mask {
427+
in_bounds = [true, true, true, true],
428+
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
429+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
430+
431+
return %res : vector<8x4x2x3xf32>
432+
}
433+
388434
// Masked version is not supported
389435

390436
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
391437
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
392-
// CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
393-
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
438+
// CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
439+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
394440
// CHECK-NOT: vector.broadcast
395-
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
441+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32> } : vector<4x3xi1> -> vector<8x4x2x3xf32>
396442
func.func @xfer_read_minor_identitiy_bcast_dims_masked(
397443
%mem: memref<?x?x?x?xf32>,
398-
%mask: vector<[4]x3xi1>,
399-
%idx: index) -> vector<8x[4]x2x3xf32> {
444+
%mask: vector<4x3xi1>,
445+
%idx: index) -> vector<8x4x2x3xf32> {
400446

401447
%pad = arith.constant 0.000000e+00 : f32
402448

403449
%res = vector.mask %mask {
404450
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
405451
in_bounds = [true, true, true, true],
406452
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
407-
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
408-
} : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
453+
} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32>
454+
} : vector<4x3xi1> -> vector<8x4x2x3xf32>
409455

410-
return %res : vector<8x[4]x2x3xf32>
456+
return %res : vector<8x4x2x3xf32>
411457
}
412458

459+
///----------------------------------------------------------------------------------------
460+
// TD sequence
461+
///----------------------------------------------------------------------------------------
413462
module attributes {transform.with_named_sequence} {
414463
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
415464
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)