Commit 52db7e2
[mlir][nvgpu] Improve WarpgroupAccumulator type to simplify IR (#68728)
`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that holds the accumulator matrix used by warp-group level matrix multiplication. Having a dedicated type for it is handy because the matrix is distributed among the threads of the warp group. However, the current transformations have to create and carry multiple `WarpgroupAccumulator` values whenever the GEMM shape is larger than the shape supported by a single `wgmma.mma_async` instruction, which makes the IR dense. This PR improves the handling of the `WarpgroupAccumulator` type in every nvgpu Op that uses it, so a single value covers the entire GEMM shape.

**Example: Current GEMM in NVGPU-IR**
```mlir
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB} :
  !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
  !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  into memref<128x128xf32,3>
```

**Example: This PR simplifies the IR as below:**
```mlir
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r = nvgpu.warpgroup.mma %descA, %descB, %m {transposeB} :
  !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
  !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store %r to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  into memref<128x128xf32,3>
```
1 parent 838f289 commit 52db7e2

7 files changed (+177, -158 lines)

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 5 additions & 5 deletions
```diff
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
                        DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
                        OptionalAttr<UnitAttr>:$transposeA,
                        OptionalAttr<UnitAttr>:$transposeB,
-                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
-  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+                       NVGPU_WarpgroupAccumulator:$matrixC);
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
   let assemblyFormat = [{
     $descriptorA`,` $descriptorB`,` $matrixC attr-dict
     `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,11 +739,11 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     Note that, the op must be run with warp group.
   }];
 
-  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+  let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
                        Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
 
   let assemblyFormat = [{
-    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
@@ -755,7 +755,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
     This Op generates and initializes the accumulator matrix for
     `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
   }];
-  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
   let assemblyFormat = "attr-dict `->` type($matrixC)";
   let hasVerifier = 1;
 }
```
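The bracketed, variadic form of `nvgpu.warpgroup.mma.store` goes away with this `assemblyFormat` change. A rough sketch of the surface syntax before and after, derived literally from the format strings above (operand names are illustrative, not taken from the patch):

```mlir
// Before: a list of accumulators in brackets
nvgpu.warpgroup.mma.store [%d1, %d2], %buffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  to memref<128x128xf32, 3>

// After: a single accumulator covering the whole result
nvgpu.warpgroup.mma.store %d, %buffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  to memref<128x128xf32, 3>
```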

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// M size of wgmma.mma_async instruction
+constexpr int kWgmmaSizeM = 64;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
```
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 69 additions & 43 deletions
```diff
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
     return converter.convertType(IntegerType::get(type.getContext(), 32));
   });
   converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
-    VectorType vtype = type.getFragmented();
+    Type elemType = type.getFragmented().getElementType();
+    int64_t sizeM = type.getFragmented().getDimSize(0);
+    int64_t sizeN = type.getFragmented().getDimSize(1);
+
+    unsigned numMembers;
+    if (elemType.isF32() || elemType.isInteger(32))
+      numMembers = sizeN / 2;
+    else if (elemType.isF16())
+      numMembers = sizeN / 4;
+    else
+      llvm_unreachable("unsupported type for warpgroup accumulator");
+
+    SmallVector<Type> innerStructBody;
+    for (unsigned i = 0; i < numMembers; i++)
+      innerStructBody.push_back(elemType);
+    auto innerStructType =
+        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
     SmallVector<Type> structBody;
-    for (unsigned i = 0; i < vtype.getDimSize(0); i++)
-      structBody.push_back(vtype.getElementType());
+    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+      structBody.push_back(innerStructType);
+
     auto convertedType =
         LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
     return converter.convertType(convertedType);
```
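For an f32 accumulator, the converter packs `sizeN / 2` elements into each inner struct and emits one inner struct per 64-row (`kWgmmaSizeM`) wgmma tile. A small, hypothetical shape (chosen only to keep the struct readable, not taken from the patch) illustrates the mapping:

```mlir
// Hypothetical accumulator type with f32 elements, M = 128, N = 8:
!acc = !nvgpu.warpgroup.accumulator<fragmented = vector<128x8xf32>>

// After conversion: sizeN / 2 = 4 f32 members per inner struct, and
// sizeM / kWgmmaSizeM = 2 inner structs, one per 64-row wgmma tile:
!acc_lowered = !llvm.struct<(struct<(f32, f32, f32, f32)>,
                             struct<(f32, f32, f32, f32)>)>
```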
```diff
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
     nvgpu::WarpgroupMmaOp op;
     ImplicitLocOpBuilder b;
     OpAdaptor adaptor;
-    const LLVMTypeConverter &typeConverter;
 
     // Entire shape of the given Op
     int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
 
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
-    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+    Value generateWgmma(int i, int j, int k, Value matrixC) {
       LLVM_DEBUG(DBGS() << "\t wgmma."
                         << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                         << "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@
       auto overflow = NVVM::MMAIntOverflowAttr::get(
           op->getContext(), NVVM::MMAIntOverflow::wrapped);
 
-      Type resultStructType = typeConverter.convertType(matrixD.getType());
-
       return b.create<NVVM::WgmmaMmaAsyncOp>(
-          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
           itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
-    SmallVector<Value> generateWgmmaGroup() {
-      SmallVector<Value> wgmmaResults;
+    Value generateWgmmaGroup() {
+      Value wgmmaResult =
+          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
 
       // Perform GEMM
+      SmallVector<Value> wgmmaResults;
       for (int i = 0; i < iterationM; ++i) {
-        Value matrixC = adaptor.getMatrixC()[i];
-        Value matrixD = op.getMatrixD()[i];
+        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
         for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
-            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
-
-      return wgmmaResults;
+      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+                                                    wgmmaResult, matrix, idx);
+      }
+      return wgmmaResult;
     }
 
   public:
     WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
-                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
-        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+                  OpAdaptor adaptor)
+        : op(op), b(b), adaptor(adaptor) {
       // Find the entire GEMM Shape
       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@
     /// instructions and group synchronization, as well as waiting
     /// (WgmmaGroupSyncAlignedOp) for group synchronization
     /// (WgmmaWaitGroupSyncOp) after the instructions.
-    SmallVector<Value> generateWarpgroupMma() {
+    Value generateWarpgroupMma() {
       b.create<NVVM::WgmmaFenceAlignedOp>();
-      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      Value wgmmaResult = generateWgmmaGroup();
       b.create<NVVM::WgmmaGroupSyncAlignedOp>();
       b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
-      return wgmmaResults;
+      return wgmmaResult;
     }
   };
-
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
     // Step 1. Build a helper class
-    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+    WarpgroupGemm warpgroupGemm(op, b, adaptor);
 
     // Step 2. Get the entire GEMM Shape
-    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
 
     // Step 3. Replace fragmented result struct with the op results
-    rewriter.replaceOp(op, wgmmaResults);
+    rewriter.replaceOp(op, wgmmaResult);
     return success();
   }
 };
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int offset = 0;
-    ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
-    for (Value matrixD : adaptor.getMatrixD()) {
-      auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
-      storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    Value matriDValue = adaptor.getMatrixD();
+    auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+      auto structType = matrixD.cast<LLVM::LLVMStructType>();
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
@@ -1554,23 +1576,27 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    SmallVector<Value> results;
-    for (OpResult m : op.getMatrixC()) {
-      nvgpu::WarpgroupAccumulatorType mType =
-          m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
-      Type stype = getTypeConverter()->convertType(mType);
-      Value undefStruct = b.create<LLVM::UndefOp>(stype);
-      Type elemType = mType.getFragmented().getElementType();
-      int64_t elemSize = mType.getFragmented().getDimSize(0);
-      Value zero =
-          b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
-      for (int64_t i = 0; i < elemSize; ++i) {
-        undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
-                                                    ArrayRef<int64_t>({i}));
+    LLVM::LLVMStructType structType =
+        getTypeConverter()
+            ->convertType(op.getMatrixC().getType())
+            .cast<LLVM::LLVMStructType>();
+    Type elemType = structType.getBody()
+                        .front()
+                        .cast<LLVM::LLVMStructType>()
+                        .getBody()
+                        .front();
+    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+    Value structValue = b.create<LLVM::UndefOp>(structType);
+    for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+      auto innerStructType = s.cast<LLVM::LLVMStructType>();
+      int ii = idx;
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+      for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+        innerStructValue = b.create<LLVM::InsertValueOp>(
+            innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
       }
-      results.push_back(undefStruct);
     }
-    rewriter.replaceOp(op, results);
+    rewriter.replaceOp(op, structValue);
     return success();
   }
 };
```

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 35 additions & 64 deletions
```diff
@@ -435,7 +435,11 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
   return failure();
 }
 
-LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
+LogicalResult isAllowedSizeM(int sizeM) {
+  if (sizeM % kWgmmaSizeM)
+    return failure();
+  return success();
+}
 
 LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
```
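With the relaxed `isAllowedSizeM`, any M that is a multiple of `kWgmmaSizeM` (64) is accepted, whereas previously only M == 64 passed. For example, the 128x128 accumulator from the commit message now verifies, since 128 % 64 == 0 (a minimal sketch, reusing the type from the commit message):

```mlir
// M = 128 is a multiple of kWgmmaSizeM (64), so a single accumulator can
// cover the full GEMM shape:
%acc = nvgpu.warpgroup.mma.init.accumulator ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
```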
```diff
@@ -458,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
 
 LogicalResult WarpgroupMmaOp::verify() {
   if (getTransposeA() && !getTransposeB())
-    return emitOpError() << "supports non-transpose A (Row Major) "
-                            "and transpose B (Column Major) for the time being";
+    return emitOpError()
+           << "supports non-transpose A (Row Major) "
+              "and transpose B (Column Major) for the time being ";
   MemRefType matrixA = getDescriptorA().getType().getTensor();
   MemRefType matrixB = getDescriptorB().getType().getTensor();
-  VectorType matrixC = getMatrixC()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  VectorType matrixD = getMatrixD()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  unsigned sizeAcc = getMatrixC().size();
-
-  if (getMatrixC().size() != getMatrixD().size())
-    return emitOpError() << "number of matrix C and matrix D must be the same";
-
-  if (llvm::all_of(getMatrixC(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix C must be the same";
-  }
-  if (llvm::all_of(getMatrixD(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix D must be the same as matrix C";
-  }
+  VectorType matrixC = getMatrixC().getType().getFragmented();
+  VectorType matrixD = getMatrixD().getType().getFragmented();
+
+  if (matrixC != matrixD)
+    return emitOpError() << "type of matrix C and matrix D must be the same";
 
   if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
       matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -498,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
     return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                          << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
                          << " )";
-  if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+  if (matrixA.getShape()[0] != matrixC.getShape()[0])
     return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                          << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
                          << " )";
@@ -534,29 +519,16 @@
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
   MemRefType dstMemrefType = getDstMemref().getType();
-  VectorType firstVtype = getMatrixD()
-                              .front()
-                              .getType()
-                              .cast<WarpgroupAccumulatorType>()
-                              .getFragmented();
-
-  int64_t totalFirstDimension = 0;
-  for (Value result : getMatrixD()) {
-    VectorType vtype =
-        result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
-    if (vtype != firstVtype)
-      return emitOpError() << "all fragmented types must be the same";
-    // Limitation
-    if (!vtype.getElementType().isF32()) {
-      return emitOpError()
-             << "hit a limitation: only f32 results for the time being";
-    }
-    totalFirstDimension += vtype.getDimSize(0);
+  VectorType vtype = getMatrixD().getType().getFragmented();
+
+  // Limitation
+  if (!vtype.getElementType().isF32()) {
+    return emitOpError()
+           << "hit a limitation: only f32 results for the time being";
   }
-  if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
-      firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
-    return emitOpError() << "results [" << totalFirstDimension << "]["
-                         << firstVtype.getDimSize(1)
+  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                          << "] values. However, destination memref["
                          << dstMemrefType.getDimSize(0) << "]["
                          << dstMemrefType.getDimSize(1)
@@ -570,19 +542,18 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
-  for (OpResult matrix : getMatrixC()) {
-    VectorType vectorType = matrix.getType()
-                                .cast<nvgpu::WarpgroupAccumulatorType>()
-                                .getFragmented();
-    // Check [M][N] shape
-    if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
-        failed(isAllowedSizeN(vectorType.getDimSize(1),
-                              vectorType.getElementType()))) {
-      return emitOpError() << "has type " << vectorType
-                           << ". It does not fit into warp-group "
-                              "level (wgmma) matrix multiplication instruction "
-                              "(or not supported yet)";
-    }
+
+  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+  int64_t sizeM = accType.getFragmented().getDimSize(0);
+  int64_t sizeN = accType.getFragmented().getDimSize(1);
+  Type elemType = accType.getFragmented().getElementType();
+
+  if (failed(isAllowedSizeM(sizeM)) ||
+      failed(isAllowedSizeN(sizeN, elemType))) {
+    return emitOpError() << "has type " << accType.getFragmented()
+                         << ". It does not fit into warp-group "
+                            "level (wgmma) matrix multiplication instruction "
+                            "(or not supported yet)";
   }
   return success();
 }
```
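Conversely, verification still rejects shapes whose M is not a multiple of 64. A hypothetical example (not from the patch) that would trip the error above:

```mlir
// M = 96 is not a multiple of kWgmmaSizeM (64), so the verifier is expected
// to report that the type does not fit the wgmma instruction:
%bad = nvgpu.warpgroup.mma.init.accumulator ->
  !nvgpu.warpgroup.accumulator<fragmented = vector<96x128xf32>>
```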
