Skip to content

Commit 42bba97

Browse files
authored
[mlir] Extend CombineTransferReadOpTranspose pattern to handle extf ops. (llvm#74754)
This patch modifies the CombineTransferReadOpTranspose pattern to handle extf ops. Also adds a test which shows the transpose getting folded into the transfer_read.
1 parent bfd41c3 commit 42bba97

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,8 @@ struct CombineTransferReadOpTranspose final
455455
Type resultType = op.getType();
456456
Operation *extOp;
457457
if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
458-
(extOp = source.getDefiningOp<arith::ExtUIOp>())) {
458+
(extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
459+
(extOp = source.getDefiningOp<arith::ExtFOp>())) {
459460
source = extOp->getOperand(0);
460461
resultType =
461462
VectorType::get(cast<VectorType>(resultType).getShape(),
@@ -493,9 +494,12 @@ struct CombineTransferReadOpTranspose final
493494
if (isa<arith::ExtSIOp>(extOp))
494495
result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
495496
.getResult();
496-
else
497+
else if (isa<arith::ExtUIOp>(extOp))
497498
result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
498499
.getResult();
500+
else
501+
result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
502+
.getResult();
499503
}
500504

501505
rewriter.replaceOp(op, result);

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,33 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
460460
vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
461461
return
462462
}
463+
464+
// -----
465+
466+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
467+
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
468+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
469+
470+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
471+
// CHECK-LABEL: func @fold_transpose_into_transfer_read(
472+
// CHECK-SAME: %[[ALLOC:.+]]: memref<64x128xf16>
473+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
474+
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
475+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true], permutation_map = #[[$MAP]]}
476+
// CHECK: %[[EXTF1:.+]] = arith.extf %[[READ]]
477+
// CHECK-NOT: vector.transpose
478+
// CHECK: %[[RESULT:.+]] = vector.contract
479+
func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector: vector<32x128xf16>, %alloc2: memref<32x64xf32>) {
480+
%c0 = arith.constant 0 : index
481+
%cst = arith.constant 0.000000e+00 : f16
482+
%init = arith.constant dense<0.000000e+00> : vector<32x64xf32>
483+
%0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
484+
%1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
485+
%2 = arith.extf %vector : vector<32x128xf16> to vector<32x128xf32>
486+
%3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>
487+
%4 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %init : vector<32x128xf32>, vector<128x64xf32> into vector<32x64xf32>
488+
vector.transfer_write %4, %alloc2[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32>
489+
return
490+
}
491+
492+
// -----

0 commit comments

Comments
 (0)