[mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. #66278
Merged
Conversation
This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension. Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.
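To illustrate the kind of rewrite this path performs, here is a hypothetical before/after sketch of a subgroup MMA load. The op syntax and type spellings are an assumption based on the `gpu` and `spirv` dialects of that era, not taken from this PR; `%src`, `%ptr`, `%stride`, and `%colMajor` are made-up names.

```mlir
// Hypothetical input: a GPU subgroup MMA load producing an MMAMatrixType.
%m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
    : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">

// After conversion, the same load expressed with the NV cooperative matrix
// extension; the result type is what convertMMAToSPIRVCoopMatrixNVType
// produces from the MMAMatrixType above.
%c = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
    : !spirv.ptr<f16, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
```

The rename in this patch makes it explicit that these patterns emit the NV flavor of cooperative matrix ops, leaving room for a parallel KHR path.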
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Changes

This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension. Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.

Full diff: https://github.com/llvm/llvm-project/pull/66278.diff

4 Files Affected:
```diff
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 3c3281513f60d89..6c4643da1884900 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,15 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
-void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns);
-
-spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the NV Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
+/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixNVType
+convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
 
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index f37c70a771f5916..d0ce58597f980d4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
-public:
+struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
   explicit GPUToSPIRVPass(bool mapMemorySpace)
       : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
@@ -48,7 +47,6 @@ class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
 private:
   bool mapMemorySpace;
 };
-} // namespace
 
 void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
@@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
     typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVType(type);
+      return convertMMAToSPIRVCoopMatrixNVType(type);
     });
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
+    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                         patterns);
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
   }
 }
 
+} // namespace
+
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
   return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 3851fb728b6654b..bf3fff027fe384a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -1,4 +1,4 @@
-//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
-// SPIRV Dialect ops.
+// SPIRV Cooperative Matrix ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -22,7 +22,8 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/TypeUtilities.h"
 
-using namespace mlir;
+namespace mlir::nv {
+namespace {
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
-namespace {
-
-/// This class implements the conversion of GPU MMA loadOp to
-/// CooperativeMatrixLoad op in the SPIRV dialect.
-struct WmmaLoadOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
     Value bufferPtr = spirv::getElementPtr(
         *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
         adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
-    auto coopType = convertMMAToSPIRVType(retType);
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
     int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA StoreOp to
-/// CooperativeMatrixStore op in the SPIRV dialect.
-struct WmmaStoreOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA Compute to
-/// CooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+/// Converts GPU MMA Compute to
+/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -153,9 +152,10 @@ struct WmmaMmaOpToSPIRVLowering
   }
 };
 
-/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
-struct WmmaConstantOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
+/// ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -163,7 +163,7 @@ struct WmmaConstantOpToSPIRVLowering
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
       if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
         return failure();
     }
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
                                        adaptor.getOperands()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
       assert(cc.getConstituents().size() == 1);
       scalar = cc.getConstituents().front();
 
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
 };
 
 } // namespace
+} // namespace mlir::nv
 
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
 mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
   ArrayRef<int64_t> retTypeShape = type.getShape();
   Type elementType = type.getElementType();
   return spirv::CooperativeMatrixNVType::get(
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
-void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
-               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns
+      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
+           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
-                                                          /*benefit=*/2);
+  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
+                                                              context,
+                                                              /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
similarity index 98%
rename from mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
rename to mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index a53eca65fc98699..5811c791f308d1e 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,
```
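For downstream users of these entry points, the wiring after this rename looks like the pass code in the diff above. This is a non-compilable C++ sketch extracted from `GPUToSPIRVPass::runOnOperation()`, not a standalone program; `targetAttr`, `options`, and `context` are assumed to be set up by the surrounding pass boilerplate.

```cpp
// Sketch: building the GPU-to-SPIR-V conversion with the NV coop matrix path.
SPIRVTypeConverter typeConverter(targetAttr, options);
// Map gpu.mma_matrix types to NV cooperative matrix types.
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
  return convertMMAToSPIRVCoopMatrixNVType(type);
});

RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
// The renamed entry point makes the NV coop matrix lowering explicit,
// leaving room for a future KHR-based populate function alongside it.
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter, patterns);
```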
qedawkins approved these changes on Sep 13, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request on Sep 19, 2023:

…lvm#66278) This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension. Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request on Sep 21, 2023:

Local branch amd-gfx b396737: Merged main:25e8105bff4d into amd-gfx:0d570d01ad3e
Remote branch main d6d4a52: [mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. (llvm#66278)