Skip to content

Commit aecb764

Browse files
authored
[mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded conversions. (llvm#124439)
Using `ConvertToLLVMPatternInterface` allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface. Add `allowed-dialects` option so user can control which dialects can be used to populate conversions.
1 parent 3e54964 commit aecb764

File tree

7 files changed

+114
-39
lines changed

7 files changed

+114
-39
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -572,14 +572,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
572572
];
573573
let options = [
574574
Option<"indexBitwidth", "index-bitwidth", "unsigned",
575-
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
575+
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
576576
"Bitwidth of the index type, 0 to use size of machine word">,
577577
Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
578578
"Target gpu supports redux">,
579579
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
580580
/*default=*/"false",
581581
"Replace memref arguments in GPU functions with bare pointers. "
582-
"All memrefs must have static shape.">
582+
"All memrefs must have static shape.">,
583+
ListOption<"allowedDialects", "allowed-dialects", "std::string",
584+
"Run conversion patterns of only the specified dialects">,
583585
];
584586
}
585587

@@ -600,20 +602,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
600602
/*default=*/"\"gfx000\"",
601603
"Chipset that these operations will run on">,
602604
Option<"indexBitwidth", "index-bitwidth", "unsigned",
603-
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
605+
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
604606
"Bitwidth of the index type, 0 to use size of machine word">,
605607
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
606608
/*default=*/"false",
607609
"Replace memref arguments in GPU functions with bare pointers."
608610
"All memrefs must have static shape">,
609611
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
610-
"::mlir::gpu::amd::Runtime::Unknown",
611-
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
612-
[{::llvm::cl::values(
613-
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
614-
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
615-
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
616-
)}]>
612+
"::mlir::gpu::amd::Runtime::Unknown",
613+
"Runtime code will be run on (default is Unknown, can also use HIP "
614+
"or OpenCL)",
615+
[{::llvm::cl::values(
616+
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
617+
"Unknown (default)"),
618+
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
619+
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
620+
"OpenCL"))}]>,
621+
ListOption<"allowedDialects", "allowed-dialects", "std::string",
622+
"Run conversion patterns of only the specified dialects">,
617623
];
618624
}
619625

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,14 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
15-
16-
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17-
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1814
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
19-
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2016
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
2117
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
18+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
2219
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2320
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2421
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
25-
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
26-
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2722
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
2823
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2924
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -342,10 +337,15 @@ struct AssertOpToAssertfailLowering
342337
///
343338
/// This pass only handles device code and is not meant to be run on GPU host
344339
/// code.
345-
struct LowerGpuOpsToNVVMOpsPass
340+
struct LowerGpuOpsToNVVMOpsPass final
346341
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
347342
using Base::Base;
348343

344+
void getDependentDialects(DialectRegistry &registry) const override {
345+
Base::getDependentDialects(registry);
346+
registerConvertToLLVMDependentDialectLoading(registry);
347+
}
348+
349349
void runOnOperation() override {
350350
gpu::GPUModuleOp m = getOperation();
351351

@@ -376,17 +376,41 @@ struct LowerGpuOpsToNVVMOpsPass
376376
LLVMTypeConverter converter(m.getContext(), options);
377377
configureGpuToNVVMTypeConverter(converter);
378378
RewritePatternSet llvmPatterns(m.getContext());
379+
LLVMConversionTarget target(getContext());
380+
381+
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
382+
allowedDialects.end());
383+
for (Dialect *dialect : getContext().getLoadedDialects()) {
384+
// Skip math patterns as nvvm needs custom math lowering.
385+
if (isa<math::MathDialect>(dialect))
386+
continue;
387+
388+
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
389+
// Empty `allowedDialectsSet` means all dialects are allowed.
390+
if (!allowedDialectsSet.empty() && !allowed)
391+
continue;
392+
393+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
394+
if (!iface) {
395+
// Error out if dialect was explicily specified but doesn't implement
396+
// conversion interface.
397+
if (allowed) {
398+
m.emitError()
399+
<< "dialect does not implement ConvertToLLVMPatternInterface: "
400+
<< dialect->getNamespace();
401+
return signalPassFailure();
402+
}
403+
continue;
404+
}
405+
406+
iface->populateConvertToLLVMConversionPatterns(target, converter,
407+
llvmPatterns);
408+
}
379409

380-
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
381-
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
382-
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
383-
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
384410
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
385411
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
386-
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
387412
if (this->hasRedux)
388413
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
389-
LLVMConversionTarget target(getContext());
390414
configureGpuToNVVMConversionLegality(target);
391415
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
392416
signalPassFailure();
@@ -397,6 +421,7 @@ struct LowerGpuOpsToNVVMOpsPass
397421

398422
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
399423
target.addIllegalOp<func::FuncOp>();
424+
target.addIllegalOp<cf::AssertOp>();
400425
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
401426
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
402427
target.addIllegalDialect<gpu::GPUDialect>();
@@ -472,8 +497,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
472497
using gpu::index_lowering::IndexKind;
473498
using gpu::index_lowering::IntrType;
474499
populateWithGenerated(patterns);
500+
501+
// Set higher benefit, so patterns will run before generic LLVM lowering.
475502
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
476-
converter);
503+
converter, /*benefit*/ 10);
477504
patterns.add<
478505
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
479506
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,22 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1514
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
1615
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1716
#include "mlir/Pass/Pass.h"
1817
#include "mlir/Pass/PassManager.h"
1918
#include "mlir/Transforms/Passes.h"
2019

2120
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
22-
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
23-
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2423
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
2524
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2625
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2726
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2827
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2928
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
3029
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
31-
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
32-
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3330
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
3431
#include "mlir/Dialect/Func/IR/FuncOps.h"
3532
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -202,7 +199,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
202199
//
203200
// This pass only handles device code and is not meant to be run on GPU host
204201
// code.
205-
struct LowerGpuOpsToROCDLOpsPass
202+
struct LowerGpuOpsToROCDLOpsPass final
206203
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
207204
LowerGpuOpsToROCDLOpsPass() = default;
208205
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
218215
this->runtime = runtime;
219216
}
220217

218+
void getDependentDialects(DialectRegistry &registry) const override {
219+
Base::getDependentDialects(registry);
220+
registerConvertToLLVMDependentDialectLoading(registry);
221+
}
222+
221223
void runOnOperation() override {
222224
gpu::GPUModuleOp m = getOperation();
223225
MLIRContext *ctx = m.getContext();
@@ -289,18 +291,36 @@ struct LowerGpuOpsToROCDLOpsPass
289291
});
290292

291293
RewritePatternSet llvmPatterns(ctx);
294+
LLVMConversionTarget target(getContext());
295+
296+
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
297+
allowedDialects.end());
298+
for (Dialect *dialect : ctx->getLoadedDialects()) {
299+
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
300+
// Empty `allowedDialectsSet` means all dialects are allowed.
301+
if (!allowedDialectsSet.empty() && !allowed)
302+
continue;
303+
304+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
305+
if (!iface) {
306+
// Error out if dialect was explicily specified but doesn't implement
307+
// conversion interface.
308+
if (allowed) {
309+
m.emitError()
310+
<< "dialect does not implement ConvertToLLVMPatternInterface: "
311+
<< dialect->getNamespace();
312+
return signalPassFailure();
313+
}
314+
continue;
315+
}
316+
317+
iface->populateConvertToLLVMConversionPatterns(target, converter,
318+
llvmPatterns);
319+
}
292320

293-
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
294321
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
295322
*maybeChipset);
296-
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
297-
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
298-
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
299-
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
300-
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
301-
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
302323
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
303-
LLVMConversionTarget target(getContext());
304324
configureGpuToROCDLConversionLegality(target);
305325
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
306326
signalPassFailure();
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-nvvm='allowed-dialects=test' -verify-diagnostics
2+
3+
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
4+
gpu.module @test_module_1 {
5+
func.func @test(%0 : index) -> index {
6+
%1 = test.increment %0 : index
7+
func.return %1 : index
8+
}
9+
}
10+

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s
23
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
34
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
45

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=test' -verify-diagnostics
2+
3+
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
4+
gpu.module @test_module_1 {
5+
func.func @test(%0 : index) -> index {
6+
%1 = test.increment %0 : index
7+
func.return %1 : index
8+
}
9+
}
10+

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
23
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
34

45
// CHECK-LABEL: @test_module

0 commit comments

Comments
 (0)