@@ -311,14 +311,21 @@ struct WmmaLoadOpToSPIRVLowering final
311
311
OpAdaptor adaptor,
312
312
ConversionPatternRewriter &rewriter) const override {
313
313
Location loc = subgroupMmaLoadMatrixOp->getLoc ();
314
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
315
+
314
316
gpu::MMAMatrixType retType =
315
317
cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes ().getType ());
316
318
auto memrefType =
317
319
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref ().getType ());
318
- Value bufferPtr = spirv::getElementPtr (
319
- *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
320
- adaptor.getSrcMemref (), adaptor.getIndices (), loc, rewriter);
321
- auto coopType = convertMMAToSPIRVCoopMatrixNVType (retType);
320
+ Value bufferPtr =
321
+ spirv::getElementPtr (typeConverter, memrefType, adaptor.getSrcMemref (),
322
+ adaptor.getIndices (), loc, rewriter);
323
+ auto coopType =
324
+ typeConverter.convertType <spirv::CooperativeMatrixNVType>(retType);
325
+ if (!coopType)
326
+ return rewriter.notifyMatchFailure (subgroupMmaLoadMatrixOp,
327
+ " type conversion failed" );
328
+
322
329
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension ().getSExtValue ();
323
330
auto i32Type = rewriter.getI32Type ();
324
331
auto strideValue = rewriter.create <spirv::ConstantOp>(
@@ -385,30 +392,6 @@ struct WmmaMmaOpToSPIRVLowering final
385
392
} // namespace nv
386
393
} // namespace mlir
387
394
388
- mlir::spirv::CooperativeMatrixNVType
389
- mlir::convertMMAToSPIRVCoopMatrixNVType (gpu::MMAMatrixType type) {
390
- ArrayRef<int64_t > retTypeShape = type.getShape ();
391
- Type elementType = type.getElementType ();
392
- return spirv::CooperativeMatrixNVType::get (
393
- elementType, spirv::Scope::Subgroup, retTypeShape[0 ], retTypeShape[1 ]);
394
- }
395
-
396
- mlir::spirv::CooperativeMatrixType
397
- mlir::convertMMAToSPIRVCoopMatrixType (gpu::MMAMatrixType type) {
398
- ArrayRef<int64_t > retTypeShape = type.getShape ();
399
- Type elementType = type.getElementType ();
400
-
401
- auto use =
402
- llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand ())
403
- .Case (" AOp" , spirv::CooperativeMatrixUseKHR::MatrixA)
404
- .Case (" BOp" , spirv::CooperativeMatrixUseKHR::MatrixB)
405
- .Default (spirv::CooperativeMatrixUseKHR::MatrixAcc);
406
-
407
- return spirv::CooperativeMatrixType::get (elementType, retTypeShape[0 ],
408
- retTypeShape[1 ],
409
- spirv::Scope::Subgroup, use);
410
- }
411
-
412
395
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns (
413
396
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
414
397
using namespace mlir ;
@@ -432,3 +415,31 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
432
415
patterns.add <WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
433
416
/* benefit=*/ 2 );
434
417
}
418
+
419
+ void mlir::populateMMAToSPIRVCoopMatrixTypeConversion (
420
+ mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
421
+ if (useNVTypes) {
422
+ typeConverter.addConversion ([](gpu::MMAMatrixType type) {
423
+ ArrayRef<int64_t > retTypeShape = type.getShape ();
424
+ Type elementType = type.getElementType ();
425
+ return spirv::CooperativeMatrixNVType::get (
426
+ elementType, spirv::Scope::Subgroup, retTypeShape[0 ],
427
+ retTypeShape[1 ]);
428
+ });
429
+ return ;
430
+ }
431
+
432
+ typeConverter.addConversion ([](gpu::MMAMatrixType type) {
433
+ ArrayRef<int64_t > retTypeShape = type.getShape ();
434
+ Type elementType = type.getElementType ();
435
+ auto use =
436
+ llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand ())
437
+ .Case (" AOp" , spirv::CooperativeMatrixUseKHR::MatrixA)
438
+ .Case (" BOp" , spirv::CooperativeMatrixUseKHR::MatrixB)
439
+ .Default (spirv::CooperativeMatrixUseKHR::MatrixAcc);
440
+
441
+ return spirv::CooperativeMatrixType::get (elementType, retTypeShape[0 ],
442
+ retTypeShape[1 ],
443
+ spirv::Scope::Subgroup, use);
444
+ });
445
+ }
0 commit comments