Skip to content

Commit ddd6acd

Browse files
authored
[mlir][GPU] Expand LLVM function attribute copies (#76755)
Expand the copying of attributes on GPU kernel arguments during LLVM lowering. Support copying attributes from arguments that are already LLVM pointers, not only from memrefs. Support copying attributes, like `noundef`, that are not pointer-specific; these are applied to every LLVM argument that the original argument is lowered to.
1 parent f64d1c8 commit ddd6acd

File tree

2 files changed

+67
-26
lines changed

2 files changed

+67
-26
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
2626

2727
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
2828
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
29-
for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
30-
BlockArgument attribution = en.value();
31-
29+
for (const auto [idx, attribution] :
30+
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
3231
auto type = dyn_cast<MemRefType>(attribution.getType());
3332
assert(type && type.hasStaticShape() && "unexpected type in attribution");
3433

@@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
3736
auto elementType =
3837
cast<Type>(typeConverter->convertType(type.getElementType()));
3938
auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
40-
std::string name = std::string(
41-
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
39+
std::string name =
40+
std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
4241
uint64_t alignment = 0;
4342
if (auto alignAttr =
4443
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
45-
en.index(), LLVM::LLVMDialect::getAlignAttrName())))
44+
idx, LLVM::LLVMDialect::getAlignAttrName())))
4645
alignment = alignAttr.getInt();
4746
auto globalOp = rewriter.create<LLVM::GlobalOp>(
4847
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
@@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
105104
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
106105
unsigned numProperArguments = gpuFuncOp.getNumArguments();
107106

108-
for (const auto &en : llvm::enumerate(workgroupBuffers)) {
109-
LLVM::GlobalOp global = en.value();
107+
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
110108
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
111109
global.getAddrSpace());
112110
Value address = rewriter.create<LLVM::AddressOfOp>(
@@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
119117
// existing memref infrastructure. This may use more registers than
120118
// otherwise necessary given that memref sizes are fixed, but we can try
121119
// and canonicalize that away later.
122-
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
120+
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
123121
auto type = cast<MemRefType>(attribution.getType());
124122
auto descr = MemRefDescriptor::fromStaticShape(
125123
rewriter, loc, *getTypeConverter(), type, memory);
126-
signatureConversion.remapInput(numProperArguments + en.index(), descr);
124+
signatureConversion.remapInput(numProperArguments + idx, descr);
127125
}
128126

129127
// Rewrite private memory attributions to alloca'ed buffers.
130128
unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
131129
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
132-
for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
133-
Value attribution = en.value();
130+
for (const auto [idx, attribution] :
131+
llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
134132
auto type = cast<MemRefType>(attribution.getType());
135133
assert(type && type.hasStaticShape() && "unexpected type in attribution");
136134

@@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
145143
uint64_t alignment = 0;
146144
if (auto alignAttr =
147145
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
148-
en.index(), LLVM::LLVMDialect::getAlignAttrName())))
146+
idx, LLVM::LLVMDialect::getAlignAttrName())))
149147
alignment = alignAttr.getInt();
150148
Value allocated = rewriter.create<LLVM::AllocaOp>(
151149
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
152150
auto descr = MemRefDescriptor::fromStaticShape(
153151
rewriter, loc, *getTypeConverter(), type, allocated);
154152
signatureConversion.remapInput(
155-
numProperArguments + numWorkgroupAttributions + en.index(), descr);
153+
numProperArguments + numWorkgroupAttributions + idx, descr);
156154
}
157155
}
158156

@@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
169167
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
170168
OpBuilder::InsertionGuard guard(rewriter);
171169
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
172-
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
173-
auto memrefTy = dyn_cast<MemRefType>(en.value());
170+
for (const auto [idx, argTy] :
171+
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
172+
auto memrefTy = dyn_cast<MemRefType>(argTy);
174173
if (!memrefTy)
175174
continue;
176175
assert(memrefTy.hasStaticShape() &&
177176
"Bare pointer convertion used with dynamically-shaped memrefs");
178177
// Use a placeholder when replacing uses of the memref argument to prevent
179178
// circular replacements.
180-
auto remapping = signatureConversion.getInputMapping(en.index());
179+
auto remapping = signatureConversion.getInputMapping(idx);
181180
assert(remapping && remapping->size == 1 &&
182181
"Type converter should produce 1-to-1 mapping for bare memrefs");
183182
BlockArgument newArg =
@@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
193192

194193
// Get memref type from function arguments and set the noalias to
195194
// pointer arguments.
196-
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
197-
auto memrefTy = en.value().dyn_cast<MemRefType>();
198-
NamedAttrList argAttr = argAttrs
199-
? argAttrs[en.index()].cast<DictionaryAttr>()
200-
: NamedAttrList();
201-
195+
for (const auto [idx, argTy] :
196+
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
197+
auto remapping = signatureConversion.getInputMapping(idx);
198+
NamedAttrList argAttr =
199+
argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
200+
auto copyAttribute = [&](StringRef attrName) {
201+
Attribute attr = argAttr.erase(attrName);
202+
if (!attr)
203+
return;
204+
for (size_t i = 0, e = remapping->size; i < e; ++i)
205+
llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
206+
};
202207
auto copyPointerAttribute = [&](StringRef attrName) {
203208
Attribute attr = argAttr.erase(attrName);
204209

205-
// This is a proxy for the bare pointer calling convention.
206210
if (!attr)
207211
return;
208-
auto remapping = signatureConversion.getInputMapping(en.index());
209212
if (remapping->size > 1 &&
210213
attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
211214
emitWarning(llvmFuncOp.getLoc(),
@@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
224227
if (argAttr.empty())
225228
continue;
226229

227-
if (memrefTy) {
230+
copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
231+
copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
232+
copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
233+
bool lowersToPointer = false;
234+
for (size_t i = 0, e = remapping->size; i < e; ++i) {
235+
lowersToPointer |= isa<LLVM::LLVMPointerType>(
236+
llvmFuncOp.getArgument(remapping->inputNo + i).getType());
237+
}
238+
239+
if (lowersToPointer) {
228240
copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
241+
copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
242+
copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
243+
copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
229244
copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
230245
copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
246+
copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
231247
copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
232248
copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
233249
copyPointerAttribute(

mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ gpu.module @kernel {
2424
// ROCDL-SAME: !llvm.ptr {llvm.writeonly}
2525
// NVVM-SAME: !llvm.ptr {llvm.writeonly}
2626

27+
// -----
28+
29+
gpu.module @kernel {
30+
gpu.func @test_func_readonly_ptr(%arg0 : !llvm.ptr {llvm.readonly} ) {
31+
gpu.return
32+
}
33+
}
34+
35+
// CHECK-LABEL: llvm.func @test_func_readonly_ptr
36+
// ROCDL-SAME: !llvm.ptr {llvm.readonly}
37+
// NVVM-SAME: !llvm.ptr {llvm.readonly}
2738

2839
// -----
2940

@@ -62,3 +73,17 @@ gpu.module @kernel {
6273
// CHECK-LABEL: llvm.func @test_func_dereferenceable_or_null
6374
// ROCDL-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
6475
// NVVM-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
76+
77+
// -----
78+
79+
gpu.module @kernel {
80+
gpu.func @test_func_noundef(%arg0 : memref<f32> {llvm.noundef} ) {
81+
gpu.return
82+
}
83+
}
84+
85+
// CHECK-LABEL: llvm.func @test_func_noundef
86+
// ROCDL-SAME: !llvm.ptr {llvm.noundef}
87+
// ROCDL-SAME: i64 {llvm.noundef}
88+
// NVVM-SAME: !llvm.ptr {llvm.noundef}
89+
// NVVM-SAME: i64 {llvm.noundef}

0 commit comments

Comments
 (0)