-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][gpu] Support extf before contract when converting to MMA ops #91988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This commit allows `inferFragType` to see through all arith.ext op users before reaching the contract op when figuring out the fragment type.
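@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Lei Zhang (antiagainst) Changes: This commit allows `inferFragType` to see through all arith.ext op users before reaching contract op for figuring out the fragment type. Full diff: https://github.com/llvm/llvm-project/pull/91988.diff 2 Files Affected: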
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 782cc92f83fee..ad7408bb06fc1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,13 @@ struct CombineTransferReadOpTranspose final
// TODO: Change the GPU dialect to abstract the layout at the this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
+ // We can have arith.ext ops before reaching contract ops. See through them.
+ if (op->hasOneUse()) {
+ Operation *extOp = *op->user_begin();
+ if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(extOp))
+ return inferFragType(extOp);
+ }
+
for (Operation *users : op->getUsers()) {
auto contract = dyn_cast<vector::ContractionOp>(users);
if (!contract)
@@ -560,13 +567,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
if (op->hasOneUse()) {
auto *user = *op->user_begin();
// Infer the signedness of the mma type from the integer extend.
- bool isSignedExtend = isa<arith::ExtSIOp>(user);
- if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+ if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
elType = IntegerType::get(
op.getContext(), cast<IntegerType>(elType).getWidth(),
- isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+ isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+ : IntegerType::Unsigned);
mappingResult = user->getResult(0);
- fragType = inferFragType(user);
}
}
gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 962ed7de584a2..8526ff1392599 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
}
// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+// CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+// CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+ %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+ %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+ vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit allows `inferFragType` to see through all arith.ext op users before reaching contract op for figuring out the fragment type.