
Commit 22c6e7b

[mlir][nvvm] Fix support for tf32 data type in mma.sync
The NVVM dialect test coverage for all possible type/shape combinations in the `nvvm.mma.sync` op is mostly complete; however, tests were missing for TF32 data type support. This change adds tests for the one relevant shape/type combination. This uncovered a small bug in the op verifier, which this change also fixes.

Differential Revision: https://reviews.llvm.org/D124975
1 parent 6385c03 commit 22c6e7b
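
For reference, a minimal sketch of the op form that the corrected verifier accepts for the m16n8k4 TF32 shape. It mirrors the tests added below; the operand values (%a0, %a1, %b0 : i32 and %c0 through %c3 : f32) are assumed to be defined in an enclosing function. The TF32 A/B fragments are passed in i32 registers, while the accumulator and result remain f32:

  %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
         {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
          multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
          shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}}
         : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>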

3 files changed: +28, -3 lines

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 6 additions & 3 deletions
@@ -81,8 +81,10 @@ Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
     return NVVM::MMATypes::f64;
   if (operandElType.isF16() || operandElType == half2Type)
     return NVVM::MMATypes::f16;
-  if (operandElType.isF32())
+  if (operandElType.isF32() && isAccumulator)
     return NVVM::MMATypes::f32;
+  if (operandElType.isF32() && !isAccumulator)
+    return NVVM::MMATypes::tf32;
   if (operandElType.isa<IntegerType>()) {
     if (isAccumulator)
       return NVVM::MMATypes::s32;
@@ -291,7 +293,7 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
         parser.getNameLoc(),
         "expected one type for each operand segment but got " +
             Twine(operandTypes.size()) + " types");
-  for (const auto& iter : llvm::enumerate(operandTypes)) {
+  for (const auto &iter : llvm::enumerate(operandTypes)) {
     auto &frag = frags[iter.index()];
     frag.regTypes.resize(frag.regs.size(), iter.value());
     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
@@ -376,8 +378,9 @@ LogicalResult MmaOp::verify() {
   switch (multiplicandAPtxType().getValue()) {
   case MMATypes::tf32:
     kFactor = 4;
+    multiplicandFragType = i32Ty;
     expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
-        context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+        context, {f32Ty, f32Ty, f32Ty, f32Ty}));
     break;
   case MMATypes::f16:
   case MMATypes::bf16:

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 11 additions & 0 deletions
@@ -152,6 +152,17 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 11 additions & 0 deletions
@@ -203,6 +203,17 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
 
+llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
 // CHECK-LABEL: @gpu_wmma_load_op
