@@ -62,10 +62,28 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
62
62
});
63
63
llvmTypeConverter.addConversion (
64
64
[&](nvgpu::WarpgroupAccumulatorType type) -> Type {
65
- VectorType vtype = type.getFragmented ();
65
+ Type elemType = type.getFragmented ().getElementType ();
66
+ int64_t sizeM = type.getFragmented ().getDimSize (0 );
67
+ int64_t sizeN = type.getFragmented ().getDimSize (1 );
68
+
69
+ unsigned numMembers;
70
+ if (elemType.isF32 () || elemType.isInteger (32 ))
71
+ numMembers = sizeN / 2 ;
72
+ else if (elemType.isF16 ())
73
+ numMembers = sizeN / 4 ;
74
+ else
75
+ llvm_unreachable (" unsupported type for warpgroup accumulator" );
76
+
77
+ SmallVector<Type> innerStructBody;
78
+ for (unsigned i = 0 ; i < numMembers; i++)
79
+ innerStructBody.push_back (elemType);
80
+ auto innerStructType = LLVM::LLVMStructType::getLiteral (
81
+ type.getContext (), innerStructBody);
82
+
66
83
SmallVector<Type> structBody;
67
- for (unsigned i = 0 ; i < vtype.getDimSize (0 ); i++)
68
- structBody.push_back (vtype.getElementType ());
84
+ for (int i = 0 ; i < sizeM; i += kWgmmaSizeM )
85
+ structBody.push_back (innerStructType);
86
+
69
87
auto convertedType =
70
88
LLVM::LLVMStructType::getLiteral (type.getContext (), structBody);
71
89
return llvmTypeConverter.convertType (convertedType);
0 commit comments