Skip to content

Commit 04582c0

Browse files
committed
fix transform dialect
1 parent acce6ab commit 04582c0

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,28 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
62 62
});
63 63
llvmTypeConverter.addConversion(
64 64
[&](nvgpu::WarpgroupAccumulatorType type) -> Type {
65-
VectorType vtype = type.getFragmented();
65+
Type elemType = type.getFragmented().getElementType();
66+
int64_t sizeM = type.getFragmented().getDimSize(0);
67+
int64_t sizeN = type.getFragmented().getDimSize(1);
68+
69+
unsigned numMembers;
70+
if (elemType.isF32() || elemType.isInteger(32))
71+
numMembers = sizeN / 2;
72+
else if (elemType.isF16())
73+
numMembers = sizeN / 4;
74+
else
75+
llvm_unreachable("unsupported type for warpgroup accumulator");
76+
77+
SmallVector<Type> innerStructBody;
78+
for (unsigned i = 0; i < numMembers; i++)
79+
innerStructBody.push_back(elemType);
80+
auto innerStructType = LLVM::LLVMStructType::getLiteral(
81+
type.getContext(), innerStructBody);
82+
66 83
SmallVector<Type> structBody;
67-
for (unsigned i = 0; i < vtype.getDimSize(0); i++)
68-
structBody.push_back(vtype.getElementType());
84+
for (int i = 0; i < sizeM; i += kWgmmaSizeM)
85+
structBody.push_back(innerStructType);
86+
69 87
auto convertedType =
70 88
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
71 89
return llvmTypeConverter.convertType(convertedType);

0 commit comments

Comments
 (0)