Skip to content

Commit 985f7ff

Browse files
committed
[mlir][gpu] Add support for integer types in gpu.subgroup_mma ops
The signedness is carried by `!gpu.mma_matrix` types to most closely match the Cooperative Matrix specification which determines signedness with the type (and sometimes the operation). See: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.html To handle the lowering from vector to gpu, ops such as arith.extsi are pattern matched next to `vector.transfer_read` and `vector.contract` to determine the signedness of the matrix type. Enables s8 and u8 WMMA types in NVVM for the GPUToNVVM conversion. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D143223
1 parent 622be09 commit 985f7ff

File tree

11 files changed

+265
-41
lines changed

11 files changed

+265
-41
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def GPU_MMAMatrix : DialectType<
101101
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
102102

103103
// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
104-
def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
104+
def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
105105

106106
class MMAMatrixOf<list<Type> allowedTypes> :
107107
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,10 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
11501150
matrix which eventually allows the lowering to determine the size of each
11511151
row. If the `transpose` attribute is present then the op does a transposed load.
11521152

1153+
For integer types, the resulting `!gpu.mma_matrix` type needs to specify the
1154+
signedness of the data if the matrix type is an `A` or `B` operand for
1155+
`gpu.subgroup_mma_compute`.
1156+
11531157
This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
11541158
`gpu.subgroup_mma_compute`.
11551159

@@ -1201,7 +1205,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
12011205
```
12021206
}];
12031207

1204-
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
1208+
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
12051209
Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
12061210
Variadic<Index>:$indices,
12071211
IndexAttr:$leadDimension,
@@ -1227,11 +1231,15 @@ def GPU_SubgroupMmaComputeOp
12271231
as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
12281232
the operation held by all threads in a subgroup. `a_transpose` or
12291233
`b_transpose` if present, signify that the respective operand was loaded in a
1230-
transposed manner. The transpose opernads are required to map to correct
1234+
transposed manner. The transpose operands are required to map to correct
12311235
underlying intrisics but they currently do not seem to affect correctness
12321236
even if they are absent given that the operands were loaded correctly using
12331237
the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op.
12341238

1239+
For integer types, the `A` and `B` matrices carry their signedness with their
1240+
types. The accumulator type is expected to be signless and imply a signed integer
1241+
with a greater width than the other two operands.
1242+
12351243
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
12361244
`gpu.subgroup_mma_load_matrix` ops.
12371245

@@ -1244,9 +1252,9 @@ def GPU_SubgroupMmaComputeOp
12441252
```
12451253
}];
12461254

1247-
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
1248-
Arg<MMAMatrixOf<[F16, F32]>>:$opB,
1249-
Arg<MMAMatrixOf<[F16, F32]>>:$opC,
1255+
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
1256+
Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
1257+
Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
12501258
OptionalAttr<UnitAttr>:$a_transpose,
12511259
OptionalAttr<UnitAttr>:$b_transpose);
12521260

@@ -1288,7 +1296,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
12881296
```
12891297
}];
12901298

1291-
let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
1299+
let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value);
12921300

12931301
let results = (outs GPU_MMAMatrix:$res);
12941302

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ enum NVVMMemorySpace {
3737
/// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
3838
/// WMMA_REGS structure.
3939
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
40-
mlir::NVVM::MMAFrag frag,
40+
mlir::NVVM::MMAFrag frag, int nRow,
41+
int nCol,
4142
mlir::MLIRContext *context);
4243
} // namespace NVVM
4344
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,16 +385,20 @@ class NVVM_MMA_OPS {
385385
list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
386386
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
387387
["f16"], [], ["f16", "f32"], []>.ret;
388+
list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
389+
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
390+
["s8","u8"], [], ["s32"], []>.ret;
388391
list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
389392
tf32_wmma_ops,
390-
fp_wmma_ops);
393+
fp_wmma_ops,
394+
i8_wmma_ops);
391395

392396
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
393397
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
394-
["a", "b"], ["f16"]>.ret;
398+
["a", "b"], ["f16","s8","u8"]>.ret;
395399
list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
396400
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
397-
["c", "d"], ["f16", "f32"]>.ret;
401+
["c", "d"], ["f16", "f32","s32"]>.ret;
398402
list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
399403
[GEOM<16, 16, 8>],
400404
["a", "b"], ["tf32"]>.ret;

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
5757
if (type.getElementType().isF32())
5858
return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
5959
: NVVM::MMATypes::tf32;
60+
61+
if (type.getElementType().isSignedInteger(8))
62+
return NVVM::MMATypes::s8;
63+
// Accumulator type is signless and implies signed.
64+
if (type.getElementType().isInteger(32))
65+
return NVVM::MMATypes::s32;
6066
llvm_unreachable("Unsupported type");
6167
}
6268

@@ -106,8 +112,11 @@ struct WmmaLoadOpToNVVMLowering
106112
}
107113
NVVM::MMAFrag frag = convertOperand(retType.getOperand());
108114
// Check that there is an exisiting instruction for the combination we need.
109-
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
115+
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
116+
llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
117+
<< "\n";
110118
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
119+
}
111120

112121
Type resType = convertMMAToLLVMType(retType);
113122
Location loc = op->getLoc();
@@ -366,8 +375,10 @@ struct WmmaElementwiseOpToNVVMLowering
366375
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
367376
NVVM::MMAFrag frag = convertOperand(type.getOperand());
368377
NVVM::MMATypes eltType = getElementType(type);
378+
auto nRow = type.getShape()[0];
379+
auto nCol = type.getShape()[1];
369380
std::pair<Type, unsigned> typeInfo =
370-
NVVM::inferMMAType(eltType, frag, type.getContext());
381+
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
371382
return LLVM::LLVMStructType::getLiteral(
372383
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
373384
}

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
140140
return false;
141141
if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
142142
return false;
143+
144+
// Only allow integer types if the signedness can be inferred.
145+
if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
146+
if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
147+
return false;
148+
143149
AffineMap map = readOp.getPermutationMap();
144150
OpBuilder b(readOp.getContext());
145151
AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
@@ -185,8 +191,16 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
185191

186192
/// Return true if this is a broadcast from scalar to a 2D vector.
187193
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
188-
return broadcastOp.getVectorType().getRank() == 2 &&
189-
broadcastOp.getSource().getType().isa<FloatType>();
194+
return broadcastOp.getVectorType().getRank() == 2;
195+
}
196+
197+
/// Return true if this signed extend op can be folded into a contract op.
198+
static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
199+
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
200+
return false;
201+
return llvm::all_of(extOp->getUsers(), [](Operation *user) {
202+
return isa<vector::ContractionOp>(user);
203+
});
190204
}
191205

192206
/// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +282,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
268282
return constantSupportsMMAMatrixType(constant);
269283
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
270284
return broadcastSupportsMMAMatrixType(broadcast);
285+
if (auto extend = dyn_cast<arith::ExtSIOp>(op))
286+
return signedExtendSupportsMMAMatrixType(extend);
271287
return elementwiseSupportsMMAMatrixType(op);
272288
}
273289

@@ -411,8 +427,18 @@ struct CombineTransferReadOpTranspose final
411427

412428
LogicalResult matchAndRewrite(vector::TransposeOp op,
413429
PatternRewriter &rewriter) const override {
414-
auto transferReadOp =
415-
op.getVector().getDefiningOp<vector::TransferReadOp>();
430+
// Look through integer extend ops.
431+
Value source = op.getVector();
432+
auto extOp = source.getDefiningOp<arith::ExtSIOp>();
433+
auto resultType = op.getVectorType();
434+
if (extOp) {
435+
source = extOp.getOperand();
436+
resultType =
437+
VectorType::get(resultType.getShape(),
438+
source.getType().cast<VectorType>().getElementType());
439+
}
440+
441+
auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
416442
if (!transferReadOp)
417443
return failure();
418444

@@ -431,11 +457,23 @@ struct CombineTransferReadOpTranspose final
431457
AffineMap::getPermutationMap(permU, op.getContext());
432458
AffineMap newMap =
433459
permutationMap.compose(transferReadOp.getPermutationMap());
434-
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
435-
op, op.getType(), transferReadOp.getSource(),
436-
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
437-
transferReadOp.getPadding(), transferReadOp.getMask(),
438-
transferReadOp.getInBoundsAttr());
460+
461+
auto loc = op.getLoc();
462+
Value result =
463+
rewriter
464+
.create<vector::TransferReadOp>(
465+
loc, resultType, transferReadOp.getSource(),
466+
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
467+
transferReadOp.getPadding(), transferReadOp.getMask(),
468+
transferReadOp.getInBoundsAttr())
469+
.getResult();
470+
471+
// Fuse through the integer extend op.
472+
if (extOp)
473+
result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
474+
.getResult();
475+
476+
rewriter.replaceOp(op, result);
439477
return success();
440478
}
441479
};
@@ -479,14 +517,26 @@ static void convertTransferReadOp(vector::TransferReadOp op,
479517
stride = 0;
480518
}
481519
assert(stride);
520+
Value mappingResult = op.getResult();
521+
auto elType = op.getVectorType().getElementType();
482522
const char *fragType = inferFragType(op);
523+
if (op->hasOneUse()) {
524+
auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
525+
// Infer the signedness of the mma type from the signed extend.
526+
if (extOp) {
527+
elType = IntegerType::get(op.getContext(),
528+
elType.cast<IntegerType>().getWidth(),
529+
IntegerType::Signed);
530+
mappingResult = extOp.getResult();
531+
fragType = inferFragType(extOp);
532+
}
533+
}
483534
gpu::MMAMatrixType type =
484-
gpu::MMAMatrixType::get(op.getVectorType().getShape(),
485-
op.getVectorType().getElementType(), fragType);
535+
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
486536
Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
487537
op.getLoc(), type, op.getSource(), op.getIndices(),
488538
b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
489-
valueMapping[op.getResult()] = load;
539+
valueMapping[mappingResult] = load;
490540
}
491541

492542
static void convertTransferWriteOp(vector::TransferWriteOp op,

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
7878
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
7979

8080
bool MMAMatrixType::isValidElementType(Type elementType) {
81-
return elementType.isF16() || elementType.isF32();
81+
return elementType.isF16() || elementType.isF32() ||
82+
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
83+
elementType.isInteger(32);
8284
}
8385

8486
LogicalResult
@@ -93,7 +95,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
9395
return emitError() << "MMAMatrixType must have exactly two dimensions";
9496

9597
if (!MMAMatrixType::isValidElementType(elementType))
96-
return emitError() << "MMAMatrixType elements must be F16 or F32";
98+
return emitError()
99+
<< "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
97100

98101
return success();
99102
}

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,8 @@ LogicalResult ShflOp::verify() {
537537
}
538538

539539
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
540-
NVVM::MMAFrag frag,
540+
NVVM::MMAFrag frag, int nRow,
541+
int nCol,
541542
MLIRContext *context) {
542543
unsigned numberElements = 0;
543544
Type elementType;
@@ -555,11 +556,48 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
555556
} else if (type == NVVM::MMATypes::tf32) {
556557
elementType = builder.getI32Type();
557558
numberElements = 4;
559+
} else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
560+
elementType = builder.getI32Type();
561+
int parallelSize = 0;
562+
if (frag == NVVM::MMAFrag::a)
563+
parallelSize = nRow;
564+
if (frag == NVVM::MMAFrag::b)
565+
parallelSize = nCol;
566+
567+
// m == 16 && n == 16 && k == 16
568+
if (parallelSize == 16)
569+
numberElements = 2;
570+
// m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
571+
else if (parallelSize == 8)
572+
numberElements = 1;
573+
else if (parallelSize == 32)
574+
numberElements = 4;
575+
} else if (type == NVVM::MMATypes::s32) {
576+
elementType = builder.getI32Type();
577+
numberElements = 8;
558578
}
559579
assert(numberElements != 0 && elementType != nullptr);
560580
return std::make_pair(elementType, numberElements);
561581
}
562582

583+
static std::pair<mlir::Type, unsigned>
584+
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
585+
int k, MLIRContext *context) {
586+
int nRow, nCol;
587+
if (frag == NVVM::MMAFrag::a) {
588+
nRow = m;
589+
nCol = k;
590+
} else if (frag == NVVM::MMAFrag::b) {
591+
nRow = k;
592+
nCol = n;
593+
} else {
594+
nRow = m;
595+
nCol = n;
596+
}
597+
assert(nRow && nCol);
598+
return inferMMAType(type, frag, nRow, nCol, context);
599+
}
600+
563601
LogicalResult NVVM::WMMALoadOp::verify() {
564602
unsigned addressSpace =
565603
getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
@@ -570,8 +608,8 @@ LogicalResult NVVM::WMMALoadOp::verify() {
570608
if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
571609
getEltype(), getFrag()) == 0)
572610
return emitOpError() << "invalid attribute combination";
573-
std::pair<Type, unsigned> typeInfo =
574-
inferMMAType(getEltype(), getFrag(), getContext());
611+
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
612+
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
575613
Type dstType = LLVM::LLVMStructType::getLiteral(
576614
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
577615
if (getType() != dstType)
@@ -590,8 +628,8 @@ LogicalResult NVVM::WMMAStoreOp::verify() {
590628
if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
591629
getEltype()) == 0)
592630
return emitOpError() << "invalid attribute combination";
593-
std::pair<Type, unsigned> typeInfo =
594-
inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
631+
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
632+
getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
595633
if (getArgs().size() != typeInfo.second)
596634
return emitOpError() << "expected " << typeInfo.second << " data operands";
597635
if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
@@ -606,12 +644,12 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
606644
getLayoutB(), getEltypeA(),
607645
getEltypeB()) == 0)
608646
return emitOpError() << "invalid attribute combination";
609-
std::pair<Type, unsigned> typeInfoA =
610-
inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
611-
std::pair<Type, unsigned> typeInfoB =
612-
inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
613-
std::pair<Type, unsigned> typeInfoC =
614-
inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
647+
std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
648+
getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
649+
std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
650+
getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
651+
std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
652+
getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
615653
SmallVector<Type, 32> arguments;
616654
arguments.append(typeInfoA.second, typeInfoA.first);
617655
arguments.append(typeInfoB.second, typeInfoB.first);

0 commit comments

Comments
 (0)