Skip to content

[mlir][gpu] Support extf before contract when converting to MMA ops #91988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2024

Conversation

antiagainst
Copy link
Member

This commit allows inferFragType to see through all arith.ext op users before reaching contract op for figuring out the fragment type.

This commit allows `inferFragType` to see through all arith.ext
op users before reaching contract op for figuring out the fragment
type.
@llvmbot
Copy link
Member

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Lei Zhang (antiagainst)

Changes

This commit allows inferFragType to see through all arith.ext op users before reaching contract op for figuring out the fragment type.


Full diff: https://github.com/llvm/llvm-project/pull/91988.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+10-4)
  • (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir (+27)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 782cc92f83fee..ad7408bb06fc1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,13 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
+  // We can have arith.ext ops before reaching contract ops. See through them.
+  if (op->hasOneUse()) {
+    Operation *extOp = *op->user_begin();
+    if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(extOp))
+      return inferFragType(extOp);
+  }
+
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -560,13 +567,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   if (op->hasOneUse()) {
     auto *user = *op->user_begin();
     // Infer the signedness of the mma type from the integer extend.
-    bool isSignedExtend = isa<arith::ExtSIOp>(user);
-    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
           op.getContext(), cast<IntegerType>(elType).getWidth(),
-          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+                                    : IntegerType::Unsigned);
       mappingResult = user->getResult(0);
-      fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 962ed7de584a2..8526ff1392599 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }
 
 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+//       CHECK:    %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//       CHECK:    %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:    %[[AE:.+]] = gpu.subgroup_mma_elementwise  extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+//       CHECK:    %[[CE:.+]] = gpu.subgroup_mma_elementwise  extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+//       CHECK:    %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//       CHECK:    %[[BE:.+]] = gpu.subgroup_mma_elementwise  extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+//       CHECK:    gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                        %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@antiagainst antiagainst merged commit a037d88 into llvm:main May 13, 2024
@antiagainst antiagainst deleted the mma-cast branch May 13, 2024 19:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants