Skip to content

[mlir][nvgpu] Improve WarpgroupAccumulator type to simplify IR #68728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
NVGPU_WarpgroupAccumulator:$matrixC);
let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
Expand All @@ -739,11 +739,11 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
Note that, the op must be run with warp group.
}];

let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);

let assemblyFormat = [{
`[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
$matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
Expand All @@ -755,7 +755,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
This Op generates and initializes the accumulator matrix for
`nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
}];
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
let assemblyFormat = "attr-dict `->` type($matrixC)";
let hasVerifier = 1;
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

constexpr int kWarpSize = 32;

/// Size of the M dimension of the wgmma.mma_async instruction.
constexpr int kWgmmaSizeM = 64;

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"

Expand Down
112 changes: 69 additions & 43 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
VectorType vtype = type.getFragmented();
Type elemType = type.getFragmented().getElementType();
int64_t sizeM = type.getFragmented().getDimSize(0);
int64_t sizeN = type.getFragmented().getDimSize(1);

unsigned numMembers;
if (elemType.isF32() || elemType.isInteger(32))
numMembers = sizeN / 2;
else if (elemType.isF16())
numMembers = sizeN / 4;
else
llvm_unreachable("unsupported type for warpgroup accumulator");

SmallVector<Type> innerStructBody;
for (unsigned i = 0; i < numMembers; i++)
innerStructBody.push_back(elemType);
auto innerStructType =
LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

SmallVector<Type> structBody;
for (unsigned i = 0; i < vtype.getDimSize(0); i++)
structBody.push_back(vtype.getElementType());
for (int i = 0; i < sizeM; i += kWgmmaSizeM)
structBody.push_back(innerStructType);

auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
Expand Down Expand Up @@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
const LLVMTypeConverter &typeConverter;

// Entire shape of the given Op
int64_t totalM, totalN, totalK;
Expand Down Expand Up @@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering

/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
Expand Down Expand Up @@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);

Type resultStructType = typeConverter.convertType(matrixD.getType());

return b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}

/// Generates multiple wgmma instructions to complete the given GEMM shape.
/// For each of the `iterationM` steps one accumulator value is chained through
/// every (j, k) generateWgmma call, and the per-step results are collected.
// NOTE(review): this hunk lists two variants of several statements without
// +/- markers (both signatures, both `matrixC` initializers, both
// generateWgmma calls, both returns). Presumably the `Value`-returning /
// ExtractValueOp / InsertValueOp lines are the post-change code and the
// `SmallVector<Value>`-returning lines were removed — confirm against the
// merged file.
SmallVector<Value> generateWgmmaGroup() {
SmallVector<Value> wgmmaResults;
Value generateWgmmaGroup() {
// Start from an undef value of the adaptor's matrixC type; it is filled in
// by the InsertValueOp loop at the end of the function.
Value wgmmaResult =
b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

// Perform GEMM
SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
Value matrixC = adaptor.getMatrixC()[i];
Value matrixD = op.getMatrixD()[i];
// Member `i` of the accumulator is the starting value for this M step.
Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
// Each generateWgmma call consumes the previous result, so the N and K
// iterations accumulate into the same value.
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC, matrixD);
matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}

return wgmmaResults;
// Re-pack the per-step results into a single value, one member per index.
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
wgmmaResult, matrix, idx);
}
return wgmmaResult;
}

public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
OpAdaptor adaptor)
: op(op), b(b), adaptor(adaptor) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
Expand All @@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
SmallVector<Value> generateWarpgroupMma() {
Value generateWarpgroupMma() {
// NOTE(review): two variants of the signature, the local result and the
// return statement are listed here without +/- markers; presumably the
// `Value` variants are the post-change code — confirm against the merged
// file.
// Emit the fence first, then the wgmma group produced by
// generateWgmmaGroup().
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
Value wgmmaResult = generateWgmmaGroup();
// After the group: group-sync op, then a wait parameterized by the op's
// `waitGroup` attribute.
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
return wgmmaResults;
return wgmmaResult;
}
};

// Lowers nvgpu::WarpgroupMmaOp by delegating to the WarpgroupGemm helper and
// replacing the op with what the helper generates. Always returns success().
// NOTE(review): the helper construction, the result declaration and the
// replaceOp call each appear twice (no +/- markers survived); presumably the
// three-argument constructor and the single `Value` result are the
// post-change code — confirm against the merged file.
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);

// Step 1. Build a helper class
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
WarpgroupGemm warpgroupGemm(op, b, adaptor);

// Step 2. Generate the warpgroup MMA sequence and obtain its result
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

// Step 3. Replace the op with the generated result
rewriter.replaceOp(op, wgmmaResults);
rewriter.replaceOp(op, wgmmaResult);
return success();
}
};
Expand Down Expand Up @@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
for (Value matrixD : adaptor.getMatrixD()) {
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value matriDValue = adaptor.getMatrixD();
auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
Expand All @@ -1554,23 +1576,27 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
// Lowers nvgpu::WarpgroupMmaInitAccumulatorOp to a struct value whose members
// are meant to be zero-initialized, and replaces the op with it.
// NOTE(review): this hunk interleaves two implementations without +/- markers.
// Presumably the `results` vector / `for (OpResult m : ...)` loop is the
// removed pre-change code and the `structType` / `structValue` code is the
// post-change code — confirm against the merged file.
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
SmallVector<Value> results;
for (OpResult m : op.getMatrixC()) {
nvgpu::WarpgroupAccumulatorType mType =
m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
Type stype = getTypeConverter()->convertType(mType);
Value undefStruct = b.create<LLVM::UndefOp>(stype);
Type elemType = mType.getFragmented().getElementType();
int64_t elemSize = mType.getFragmented().getDimSize(0);
Value zero =
b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
for (int64_t i = 0; i < elemSize; ++i) {
undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
ArrayRef<int64_t>({i}));
// Converted accumulator type: a struct whose members are themselves structs.
LLVM::LLVMStructType structType =
getTypeConverter()
->convertType(op.getMatrixC().getType())
.cast<LLVM::LLVMStructType>();
// The element type is read from the first member of the first inner struct,
// so both bodies must be non-empty here.
Type elemType = structType.getBody()
.front()
.cast<LLVM::LLVMStructType>()
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
Value structValue = b.create<LLVM::UndefOp>(structType);
for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
auto innerStructType = s.cast<LLVM::LLVMStructType>();
int ii = idx;
Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
// Overwrite every member of the extracted inner struct with zero.
for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
innerStructValue = b.create<LLVM::InsertValueOp>(
innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
}
// NOTE(review): the zero-filled `innerStructValue` is never inserted back
// into `structValue` (no InsertValueOp at index `idx`), and `structValue` is
// not reassigned anywhere in this function. As written, the value handed to
// replaceOp below is still the UndefOp result, so the accumulator would not
// be zero-initialized. An insert of `innerStructValue` into `structValue`
// looks to be missing here — confirm and fix.
results.push_back(undefStruct);
}
rewriter.replaceOp(op, results);
rewriter.replaceOp(op, structValue);
return success();
}
};
Expand Down
99 changes: 35 additions & 64 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,11 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}

LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
/// Returns success when `sizeM` is a whole multiple of kWgmmaSizeM, and
/// failure otherwise.
LogicalResult isAllowedSizeM(int sizeM) {
  const bool isMultipleOfWgmmaM = (sizeM % kWgmmaSizeM) == 0;
  return success(isMultipleOfWgmmaM);
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
Expand All @@ -458,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {

LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
return emitOpError() << "supports non-transpose A (Row Major) "
"and transpose B (Column Major) for the time being";
return emitOpError()
<< "supports non-transpose A (Row Major) "
"and transpose B (Column Major) for the time being ";
MemRefType matrixA = getDescriptorA().getType().getTensor();
MemRefType matrixB = getDescriptorB().getType().getTensor();
VectorType matrixC = getMatrixC()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();
VectorType matrixD = getMatrixD()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();
unsigned sizeAcc = getMatrixC().size();

if (getMatrixC().size() != getMatrixD().size())
return emitOpError() << "number of matrix C and matrix D must be the same";

if (llvm::all_of(getMatrixC(),
[&](Value rhs) { return rhs.getType() == matrixC; })) {
return emitOpError()
<< "types of all operands in matrix C must be the same";
}
if (llvm::all_of(getMatrixD(),
[&](Value rhs) { return rhs.getType() == matrixC; })) {
return emitOpError()
<< "types of all operands in matrix D must be the same as matrix C";
}
VectorType matrixC = getMatrixC().getType().getFragmented();
VectorType matrixD = getMatrixD().getType().getFragmented();

if (matrixC != matrixD)
return emitOpError() << "type of matrix C and matrix D must be the same";

if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
matrixC.getRank() != 2 || matrixD.getRank() != 2) {
Expand All @@ -498,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
if (matrixA.getShape()[0] != matrixC.getShape()[0])
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
Expand Down Expand Up @@ -534,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {

LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
VectorType firstVtype = getMatrixD()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();

int64_t totalFirstDimension = 0;
for (Value result : getMatrixD()) {
VectorType vtype =
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
if (vtype != firstVtype)
return emitOpError() << "all fragmented types must be the same";
// Limitation
if (!vtype.getElementType().isF32()) {
return emitOpError()
<< "hit a limitation: only f32 results for the time being";
}
totalFirstDimension += vtype.getDimSize(0);
VectorType vtype = getMatrixD().getType().getFragmented();

// Limitation
if (!vtype.getElementType().isF32()) {
return emitOpError()
<< "hit a limitation: only f32 results for the time being";
}
if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
return emitOpError() << "results [" << totalFirstDimension << "]["
<< firstVtype.getDimSize(1)
if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
<< "] values. However, destination memref["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1)
Expand All @@ -570,19 +542,18 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//

// Verifies that the accumulator's fragmented vector type has an [M][N] shape
// accepted by isAllowedSizeM / isAllowedSizeN; emits an op error otherwise.
// NOTE(review): this hunk interleaves two implementations without +/- markers.
// Presumably the `for (OpResult matrix : getMatrixC())` loop is the removed
// pre-change code and the `accType` code is the post-change code — confirm
// against the merged file.
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
for (OpResult matrix : getMatrixC()) {
VectorType vectorType = matrix.getType()
.cast<nvgpu::WarpgroupAccumulatorType>()
.getFragmented();
// Check [M][N] shape
if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
failed(isAllowedSizeN(vectorType.getDimSize(1),
vectorType.getElementType()))) {
return emitOpError() << "has type " << vectorType
<< ". It does not fit into warp-group "
"level (wgmma) matrix multiplication instruction "
"(or not supported yet)";
}

// Read M, N and the element type from the accumulator's fragmented vector.
nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
int64_t sizeM = accType.getFragmented().getDimSize(0);
int64_t sizeN = accType.getFragmented().getDimSize(1);
Type elemType = accType.getFragmented().getElementType();

// Reject the op if either dimension fails its size check; the N check also
// receives the element type.
if (failed(isAllowedSizeM(sizeM)) ||
failed(isAllowedSizeN(sizeN, elemType))) {
return emitOpError() << "has type " << accType.getFragmented()
<< ". It does not fit into warp-group "
"level (wgmma) matrix multiplication instruction "
"(or not supported yet)";
}
return success();
}
Expand Down
Loading