Skip to content

Commit a037d88

Browse files
authored
[mlir][gpu] Support extf before contract when converting to MMA ops (#91988)
This commit allows `inferFragType` to see through `arith.ext*` ops and other elementwise users before reaching the contract op when inferring the fragment type.
1 parent 5944579 commit a037d88

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,14 @@ struct CombineTransferReadOpTranspose final
515515
// TODO: Change the GPU dialect to abstract the layout at this level and
516516
// only care about it during lowering to NVVM.
517517
static const char *inferFragType(Operation *op) {
518+
// We can have arith.ext ops before reaching contract ops. See through them
519+
// and other kinds of elementwise ops.
520+
if (op->hasOneUse()) {
521+
Operation *userOp = *op->user_begin();
522+
if (userOp->hasTrait<OpTrait::Elementwise>())
523+
return inferFragType(userOp);
524+
}
525+
518526
for (Operation *users : op->getUsers()) {
519527
auto contract = dyn_cast<vector::ContractionOp>(users);
520528
if (!contract)
@@ -560,13 +568,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
560568
if (op->hasOneUse()) {
561569
auto *user = *op->user_begin();
562570
// Infer the signedness of the mma type from the integer extend.
563-
bool isSignedExtend = isa<arith::ExtSIOp>(user);
564-
if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
571+
if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
565572
elType = IntegerType::get(
566573
op.getContext(), cast<IntegerType>(elType).getWidth(),
567-
isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
574+
isa<arith::ExtSIOp>(user) ? IntegerType::Signed
575+
: IntegerType::Unsigned);
568576
mappingResult = user->getResult(0);
569-
fragType = inferFragType(user);
570577
}
571578
}
572579
gpu::MMAMatrixType type =

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
490490
}
491491

492492
// -----
493+
494+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
495+
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
496+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
497+
498+
// CHECK-LABEL: func @cast_f16_to_f32_read
499+
// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
500+
// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
501+
// CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
502+
// CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
503+
// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
504+
// CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
505+
// CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
506+
func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
507+
%c0 = arith.constant 0 : index
508+
%cst = arith.constant 0.000000e+00 : f16
509+
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
510+
%B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
511+
%C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
512+
%Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
513+
%Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
514+
%Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
515+
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
516+
%Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
517+
vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
518+
return
519+
}

0 commit comments

Comments
 (0)