Commit c2fc8d9

[mlir][GPU] Allow bare pointer memrefs when calling GPU kernels
In the ROCm runtime (and probably CUDA as well), all kernel arguments are aligned. Therefore, enable using bare pointers for memref arguments to kernels when these memrefs have a static shape and a trivial layout.

This is a substantial optimization for launching kernels whose memrefs have known, static sizes: the kernel launch packet no longer carries information the kernel already knows, which can allow the kernel arguments to be packed directly into the launch packet instead of allocating a separate structure to hold unneeded memref metadata.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D130716
1 parent 95a9299 commit c2fc8d9
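
To make the saving concrete, here is a rough C++ picture of what the host passes per memref kernel argument under each convention. This is an illustration only: the struct and field names are invented for exposition, and MLIR materializes the descriptor as an LLVM struct during lowering, not as a C struct.

#include <cstdint>

// Default convention: a 1-D memref such as memref<8xf32> expands into a full
// descriptor of five fields, even though everything but the pointers is a
// compile-time constant for a statically shaped memref with identity layout.
struct MemRefDescriptor1D {
  float *allocated; // pointer returned by the allocator
  float *aligned;   // pointer actually used for loads and stores
  int64_t offset;   // statically 0 here
  int64_t size;     // statically 8 here
  int64_t stride;   // statically 1 here
};

// Bare-pointer convention: the same argument collapses to a single pointer,
// because the kernel already knows the static shape and layout.
using BareMemRefArg = float *;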

File tree

9 files changed: +153 −27 lines

mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h

Lines changed: 4 additions & 2 deletions

@@ -50,13 +50,15 @@ using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
 /// This pass does not generate code to call GPU runtime APIs directly but
 /// instead uses a small wrapper library that exports a stable and conveniently
 /// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).
-std::unique_ptr<OperationPass<ModuleOp>> createGpuToLLVMConversionPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createGpuToLLVMConversionPass(bool kernelBarePtrCallConv = false);
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
 void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns,
-                                         StringRef gpuBinaryAnnotation = {});
+                                         StringRef gpuBinaryAnnotation = {},
+                                         bool kernelBarePtrCallConv = false);
 
 } // namespace mlir
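
The new parameter defaults to false, so existing callers keep the descriptor-based convention. A minimal sketch, assuming downstream pipeline-setup code (not part of this patch), of opting the host side in programmatically:

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/PassManager.h"

void addHostLowering(mlir::PassManager &pm) {
  // Must match the setting used when lowering the kernel bodies; otherwise
  // the launch packet layout and the kernel ABI disagree.
  pm.addPass(mlir::createGpuToLLVMConversionPass(
      /*kernelBarePtrCallConv=*/true));
}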

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 createLowerGpuOpsToROCDLOpsPass(
     const std::string &chipset = "gfx900",
     unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    bool useBarePtrCallConv = false,
     gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
 
 } // namespace mlir
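
And the matching kernel-side setup, again as a hedged sketch: the GPU dialect header path and the choice of the HIP runtime here are assumptions for illustration.

#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Pass/PassManager.h"

void addKernelLowering(mlir::PassManager &pm) {
  // An indexBitwidth of 0 means kDeriveIndexBitwidthFromDataLayout.
  pm.addNestedPass<mlir::gpu::GPUModuleOp>(
      mlir::createLowerGpuOpsToROCDLOpsPass(
          /*chipset=*/"gfx900", /*indexBitwidth=*/0,
          /*useBarePtrCallConv=*/true, mlir::gpu::amd::Runtime::HIP));
}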

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 8 additions & 1 deletion

@@ -80,6 +80,13 @@ class LLVMTypeConverter : public TypeConverter {
 
   const LowerToLLVMOptions &getOptions() const { return options; }
 
+  /// Set the lowering options to `newOptions`. Note: using this after some
+  /// conversions have been performed can lead to inconsistencies in the
+  /// IR.
+  void dangerousSetOptions(LowerToLLVMOptions newOptions) {
+    options = std::move(newOptions);
+  }
+
   /// Promote the LLVM representation of all operands including promoting MemRef
   /// descriptors to stack and use pointers to struct to avoid the complexity
   /// of the platform-specific C/C++ ABI lowering related to struct argument

@@ -126,7 +133,7 @@
                     const DataLayout &layout);
 
   /// Check if a memref type can be converted to a bare pointer.
-  bool canConvertToBarePtr(BaseMemRefType type);
+  static bool canConvertToBarePtr(BaseMemRefType type);
 
 protected:
   /// Pointer to the LLVM dialect.
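
Making canConvertToBarePtr static lets callers test eligibility before any converter exists, which is what the ROCDL pass below relies on. A small sketch under that assumption, where argTypes stands in for a function's argument types:

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

static bool allBareCompatible(llvm::ArrayRef<mlir::Type> argTypes) {
  return llvm::all_of(argTypes, [](mlir::Type t) {
    auto memref = t.dyn_cast<mlir::BaseMemRefType>();
    // Non-memref arguments never block the bare-pointer convention.
    return !memref || mlir::LLVMTypeConverter::canConvertToBarePtr(memref);
  });
}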

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions

@@ -373,6 +373,10 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
            "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
+           /*default=*/"false",
+           "Replace memref arguments in GPU functions with bare pointers. "
+           "All memrefs must have static shape">,
     Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
            "::mlir::gpu::amd::Runtime::Unknown",
            "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",

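The tablegen option surfaces as use-bare-ptr-memref-call-conv on convert-gpu-to-rocdl, so it can also be enabled from a textual pipeline. A sketch; the same pipeline string works with mlir-opt's -pass-pipeline flag:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

mlir::LogicalResult buildPipeline(mlir::PassManager &pm) {
  return mlir::parsePassPipeline(
      "gpu.module(convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true})",
      pm);
}
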
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 29 additions & 0 deletions

@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;

@@ -137,6 +138,34 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                  &signatureConversion)))
     return failure();
 
+  // If bare memref pointers are being used, remap them back to memref
+  // descriptors. This must be done after signature conversion to get rid of
+  // the unrealized casts.
+  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
+    for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
+      auto memrefTy = en.value().dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+      assert(memrefTy.hasStaticShape() &&
+             "Bare pointer conversion used with dynamically-shaped memrefs");
+      // Use a placeholder when replacing uses of the memref argument to
+      // prevent circular replacements.
+      auto remapping = signatureConversion.getInputMapping(en.index());
+      assert(remapping && remapping->size == 1 &&
+             "Type converter should produce 1-to-1 mapping for bare memrefs");
+      BlockArgument newArg =
+          llvmFuncOp.getBody().getArgument(remapping->inputNo);
+      auto placeholder = rewriter.create<LLVM::UndefOp>(
+          loc, getTypeConverter()->convertType(memrefTy));
+      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), memrefTy, newArg);
+      rewriter.replaceOp(placeholder, {desc});
+    }
+  }
+
   rewriter.eraseOp(gpuFuncOp);
   return success();
 }
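
The placeholder step above is the subtle part: the descriptor is built from newArg itself, so replacing the argument's uses after building the descriptor would also rewrite the uses inside the descriptor's own defining ops. A distilled sketch of the broken ordering the UndefOp placeholder avoids (illustration only, do not use):

// Builds `desc` out of `newArg`...
Value desc = MemRefDescriptor::fromStaticShape(
    rewriter, loc, *getTypeConverter(), memrefTy, newArg);
// ...then this would rewrite newArg's uses *inside* desc's defining ops,
// yielding a self-referential descriptor.
rewriter.replaceUsesOfBlockArgument(newArg, desc);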

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 42 additions & 11 deletions

@@ -49,6 +49,12 @@ class GpuToLLVMConversionPass
 public:
   GpuToLLVMConversionPass() = default;
 
+  GpuToLLVMConversionPass(bool kernelBarePtrCallConv)
+      : GpuToLLVMConversionPass() {
+    if (this->kernelBarePtrCallConv.getNumOccurrences() == 0)
+      this->kernelBarePtrCallConv = kernelBarePtrCallConv;
+  }
+
   GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
       : GpuToLLVMConversionPassBase(other) {}

@@ -60,6 +66,11 @@ class GpuToLLVMConversionPass
       *this, "gpu-binary-annotation",
       llvm::cl::desc("Annotation attribute string for GPU binary"),
       llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
+  Option<bool> kernelBarePtrCallConv{
+      *this, "use-bare-pointers-for-kernels",
+      llvm::cl::desc("Use bare pointers to pass memref arguments to kernels. "
+                     "The kernel must use the same setting for this option."),
+      llvm::cl::init(false)};
 };
 
 struct FunctionCallBuilder {

@@ -290,9 +301,11 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
-                                             StringRef gpuBinaryAnnotation)
+                                             StringRef gpuBinaryAnnotation,
+                                             bool kernelBarePtrCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
-        gpuBinaryAnnotation(gpuBinaryAnnotation) {}
+        gpuBinaryAnnotation(gpuBinaryAnnotation),
+        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
 
 private:
   Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,

@@ -305,6 +318,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 
   llvm::SmallString<32> gpuBinaryAnnotation;
+  bool kernelBarePtrCallConv;
 };
 
 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {

@@ -377,7 +391,8 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFuncToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
+  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
+                                      kernelBarePtrCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))

@@ -635,9 +650,24 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
   auto loc = launchOp.getLoc();
   auto numKernelOperands = launchOp.getNumKernelOperands();
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getOperands().take_back(numKernelOperands),
-      adaptor.getOperands().take_back(numKernelOperands), builder);
+  SmallVector<Value, 4> arguments;
+  if (kernelBarePtrCallConv) {
+    // Hack: temporarily switch the converter to the bare-pointer convention,
+    // just for this argument promotion.
+    LLVMTypeConverter *converter = getTypeConverter();
+    LowerToLLVMOptions options = converter->getOptions();
+    LowerToLLVMOptions overrideToMatchKernelOpts = options;
+    overrideToMatchKernelOpts.useBarePtrCallConv = true;
+    converter->dangerousSetOptions(overrideToMatchKernelOpts);
+    arguments = converter->promoteOperands(
+        loc, launchOp.getOperands().take_back(numKernelOperands),
+        adaptor.getOperands().take_back(numKernelOperands), builder);
+    converter->dangerousSetOptions(options);
+  } else {
+    arguments = getTypeConverter()->promoteOperands(
+        loc, launchOp.getOperands().take_back(numKernelOperands),
+        adaptor.getOperands().take_back(numKernelOperands), builder);
+  }
+
   auto numArguments = arguments.size();
   SmallVector<Type, 4> argumentTypes;
   argumentTypes.reserve(numArguments);

@@ -870,13 +900,14 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
 }
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createGpuToLLVMConversionPass() {
-  return std::make_unique<GpuToLLVMConversionPass>();
+mlir::createGpuToLLVMConversionPass(bool kernelBarePtrCallConv) {
+  return std::make_unique<GpuToLLVMConversionPass>(kernelBarePtrCallConv);
 }
 
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
-                                               StringRef gpuBinaryAnnotation) {
+                                               StringRef gpuBinaryAnnotation,
+                                               bool kernelBarePtrCallConv) {
   converter.addConversion(
       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
         return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

@@ -890,7 +921,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                ConvertWaitOpToGpuRuntimeCallPattern,
                ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
-  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
-                                                           gpuBinaryAnnotation);
+  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
+      converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
 }
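
A design note on the save/override/restore around promoteOperands: it depends on nothing returning early between the two dangerousSetOptions calls. A hypothetical RAII wrapper (not in the patch) shows how such an override could be kept balanced automatically:

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

struct ScopedLoweringOptions {
  mlir::LLVMTypeConverter &converter;
  mlir::LowerToLLVMOptions saved;

  ScopedLoweringOptions(mlir::LLVMTypeConverter &c,
                        const mlir::LowerToLLVMOptions &temporary)
      : converter(c), saved(c.getOptions()) {
    converter.dangerousSetOptions(temporary);
  }
  // Restores the original options on scope exit, even on early return.
  ~ScopedLoweringOptions() { converter.dangerousSetOptions(saved); }
};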

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 39 additions & 5 deletions

@@ -41,6 +41,16 @@
 
 using namespace mlir;
 
+/// Returns true if the given `gpu.func` can be safely called using the bare
+/// pointer calling convention.
+static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
+  bool canBeBare = true;
+  for (Type type : func.getArgumentTypes())
+    if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
+  return canBeBare;
+}
+
 namespace {
 
 /// Import the GPU Ops to ROCDL Patterns.

@@ -55,10 +65,16 @@ struct LowerGpuOpsToROCDLOpsPass
     : public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
   LowerGpuOpsToROCDLOpsPass() = default;
   LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
+                            bool useBarePtrCallConv,
                             gpu::amd::Runtime runtime) {
-    this->chipset = chipset;
-    this->indexBitwidth = indexBitwidth;
-    this->runtime = runtime;
+    if (this->chipset.getNumOccurrences() == 0)
+      this->chipset = chipset;
+    if (this->indexBitwidth.getNumOccurrences() == 0)
+      this->indexBitwidth = indexBitwidth;
+    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
+      this->useBarePtrCallConv = useBarePtrCallConv;
+    if (this->runtime.getNumOccurrences() == 0)
+      this->runtime = runtime;
   }
 
   void runOnOperation() override {

@@ -82,6 +98,23 @@ struct LowerGpuOpsToROCDLOpsPass
         ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
+
+    if (useBarePtrCallConv) {
+      options.useBarePtrCallConv = true;
+      WalkResult canUseBarePointers =
+          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
+            if (canBeCalledWithBarePointers(func))
+              return WalkResult::advance();
+            return WalkResult::interrupt();
+          });
+      if (canUseBarePointers.wasInterrupted()) {
+        emitError(UnknownLoc::get(ctx),
+                  "bare pointer calling convention requires all memrefs to "
+                  "have static shape and use the identity map");
+        return signalPassFailure();
+      }
+    }
+
     LLVMTypeConverter converter(ctx, options);
 
     RewritePatternSet patterns(ctx);

@@ -189,7 +222,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                       unsigned indexBitwidth,
+                                      bool useBarePtrCallConv,
                                       gpu::amd::Runtime runtime) {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(chipset, indexBitwidth,
-                                                     runtime);
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
+      chipset, indexBitwidth, useBarePtrCallConv, runtime);
 }
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s \
+// RUN: -convert-gpu-to-rocdl=use-bare-ptr-memref-call-conv=true \
+// RUN: -split-input-file \
+// RUN: | FileCheck %s --check-prefix=BARE
+
+gpu.module @memref_conversions {
+  // CHECK: llvm.func @kern
+  // CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64)
+  // BARE: llvm.func @kern
+  // BARE-SAME: (%{{.*}}: !llvm.ptr<f32>)
+  gpu.func @kern(%arg0: memref<8xf32>) kernel {
+    gpu.return
+  }
+}

mlir/test/Integration/GPU/ROCM/vecadd.mlir

Lines changed: 11 additions & 8 deletions

@@ -1,24 +1,24 @@
 // RUN: mlir-opt %s \
 // RUN: -convert-scf-to-cf \
 // RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl,gpu-to-hsaco{chip=%chip})' \
-// RUN: -gpu-to-llvm \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true},gpu-to-hsaco{chip=%chip})' \
+// RUN: -gpu-to-llvm=use-bare-pointers-for-kernels=true \
 // RUN: | mlir-cpu-runner \
 // RUN: --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \
 // RUN: --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
 // RUN: --entry-point-result=void \
 // RUN: | FileCheck %s
 
-func.func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>, %arg2 : memref<?xf32>) {
+func.func @vecadd(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %arg2 : memref<5xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %block_dim = memref.dim %arg0, %c0 : memref<?xf32>
+  %block_dim = arith.constant 5 : index
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %block_dim, %block_y = %c1, %block_z = %c1) {
-    %a = memref.load %arg0[%tx] : memref<?xf32>
-    %b = memref.load %arg1[%tx] : memref<?xf32>
+    %a = memref.load %arg0[%tx] : memref<5xf32>
+    %b = memref.load %arg1[%tx] : memref<5xf32>
     %c = arith.addf %a, %b : f32
-    memref.store %c, %arg2[%tx] : memref<?xf32>
+    memref.store %c, %arg2[%tx] : memref<5xf32>
     gpu.terminator
   }
   return

@@ -49,8 +49,11 @@ func.func @main() {
   %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
   %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
   %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
+  %12 = memref.cast %9 : memref<?xf32> to memref<5xf32>
+  %13 = memref.cast %10 : memref<?xf32> to memref<5xf32>
+  %14 = memref.cast %11 : memref<?xf32> to memref<5xf32>
 
-  call @vecadd(%9, %10, %11) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
+  call @vecadd(%12, %13, %14) : (memref<5xf32>, memref<5xf32>, memref<5xf32>) -> ()
   call @printMemrefF32(%8) : (memref<*xf32>) -> ()
   return
 }
