Skip to content

Commit 97c7966

Browse files
committed
[mlir][GPU] Expand LLVM function attribute copies
Expand the copying of attributes on GPU kernel arguments during LLVM lowering. Support copying attributes from arguments that are already LLVM pointers, not only from memref arguments. Also support copying attributes, such as `noundef`, that are not pointer-specific and therefore apply to every value an argument is lowered to, rather than only to its pointer parts.
1 parent c91fab5 commit 97c7966

File tree

2 files changed

+67
-26
lines changed

2 files changed

+67
-26
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
2626

2727
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
2828
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
29-
for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
30-
BlockArgument attribution = en.value();
31-
29+
for (const auto [idx, attribution] :
30+
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
3231
auto type = dyn_cast<MemRefType>(attribution.getType());
3332
assert(type && type.hasStaticShape() && "unexpected type in attribution");
3433

@@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
3736
auto elementType =
3837
cast<Type>(typeConverter->convertType(type.getElementType()));
3938
auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
40-
std::string name = std::string(
41-
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
39+
std::string name =
40+
std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
4241
uint64_t alignment = 0;
4342
if (auto alignAttr =
4443
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
45-
en.index(), LLVM::LLVMDialect::getAlignAttrName())))
44+
idx, LLVM::LLVMDialect::getAlignAttrName())))
4645
alignment = alignAttr.getInt();
4746
auto globalOp = rewriter.create<LLVM::GlobalOp>(
4847
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
@@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
105104
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
106105
unsigned numProperArguments = gpuFuncOp.getNumArguments();
107106

108-
for (const auto &en : llvm::enumerate(workgroupBuffers)) {
109-
LLVM::GlobalOp global = en.value();
107+
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
110108
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
111109
global.getAddrSpace());
112110
Value address = rewriter.create<LLVM::AddressOfOp>(
@@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
119117
// existing memref infrastructure. This may use more registers than
120118
// otherwise necessary given that memref sizes are fixed, but we can try
121119
// and canonicalize that away later.
122-
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
120+
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
123121
auto type = cast<MemRefType>(attribution.getType());
124122
auto descr = MemRefDescriptor::fromStaticShape(
125123
rewriter, loc, *getTypeConverter(), type, memory);
126-
signatureConversion.remapInput(numProperArguments + en.index(), descr);
124+
signatureConversion.remapInput(numProperArguments + idx, descr);
127125
}
128126

129127
// Rewrite private memory attributions to alloca'ed buffers.
130128
unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
131129
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
132-
for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
133-
Value attribution = en.value();
130+
for (const auto [idx, attribution] :
131+
llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
134132
auto type = cast<MemRefType>(attribution.getType());
135133
assert(type && type.hasStaticShape() && "unexpected type in attribution");
136134

@@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
145143
uint64_t alignment = 0;
146144
if (auto alignAttr =
147145
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
148-
en.index(), LLVM::LLVMDialect::getAlignAttrName())))
146+
idx, LLVM::LLVMDialect::getAlignAttrName())))
149147
alignment = alignAttr.getInt();
150148
Value allocated = rewriter.create<LLVM::AllocaOp>(
151149
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
152150
auto descr = MemRefDescriptor::fromStaticShape(
153151
rewriter, loc, *getTypeConverter(), type, allocated);
154152
signatureConversion.remapInput(
155-
numProperArguments + numWorkgroupAttributions + en.index(), descr);
153+
numProperArguments + numWorkgroupAttributions + idx, descr);
156154
}
157155
}
158156

@@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
169167
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
170168
OpBuilder::InsertionGuard guard(rewriter);
171169
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
172-
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
173-
auto memrefTy = dyn_cast<MemRefType>(en.value());
170+
for (const auto [idx, argTy] :
171+
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
172+
auto memrefTy = dyn_cast<MemRefType>(argTy);
174173
if (!memrefTy)
175174
continue;
176175
assert(memrefTy.hasStaticShape() &&
177176
"Bare pointer convertion used with dynamically-shaped memrefs");
178177
// Use a placeholder when replacing uses of the memref argument to prevent
179178
// circular replacements.
180-
auto remapping = signatureConversion.getInputMapping(en.index());
179+
auto remapping = signatureConversion.getInputMapping(idx);
181180
assert(remapping && remapping->size == 1 &&
182181
"Type converter should produce 1-to-1 mapping for bare memrefs");
183182
BlockArgument newArg =
@@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
193192

194193
// Get memref type from function arguments and set the noalias to
195194
// pointer arguments.
196-
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
197-
auto memrefTy = en.value().dyn_cast<MemRefType>();
198-
NamedAttrList argAttr = argAttrs
199-
? argAttrs[en.index()].cast<DictionaryAttr>()
200-
: NamedAttrList();
201-
195+
for (const auto [idx, argTy] :
196+
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
197+
auto remapping = signatureConversion.getInputMapping(idx);
198+
NamedAttrList argAttr =
199+
argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
200+
auto copyAttribute = [&](StringRef attrName) {
201+
Attribute attr = argAttr.erase(attrName);
202+
if (!attr)
203+
return;
204+
for (size_t i = 0, e = remapping->size; i < e; ++i)
205+
llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
206+
};
202207
auto copyPointerAttribute = [&](StringRef attrName) {
203208
Attribute attr = argAttr.erase(attrName);
204209

205-
// This is a proxy for the bare pointer calling convention.
206210
if (!attr)
207211
return;
208-
auto remapping = signatureConversion.getInputMapping(en.index());
209212
if (remapping->size > 1 &&
210213
attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
211214
emitWarning(llvmFuncOp.getLoc(),
@@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
224227
if (argAttr.empty())
225228
continue;
226229

227-
if (memrefTy) {
230+
copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
231+
copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
232+
copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
233+
bool lowersToPointer = false;
234+
for (size_t i = 0, e = remapping->size; i < e; ++i) {
235+
lowersToPointer |= isa<LLVM::LLVMPointerType>(
236+
llvmFuncOp.getArgument(remapping->inputNo + i).getType());
237+
}
238+
239+
if (lowersToPointer) {
228240
copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
241+
copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
242+
copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
243+
copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
229244
copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
230245
copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
246+
copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
231247
copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
232248
copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
233249
copyPointerAttribute(

mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ gpu.module @kernel {
2424
// ROCDL-SAME: !llvm.ptr {llvm.writeonly}
2525
// NVVM-SAME: !llvm.ptr {llvm.writeonly}
2626

27+
// -----
28+
29+
gpu.module @kernel {
30+
gpu.func @test_func_readonly_ptr(%arg0 : !llvm.ptr {llvm.readonly} ) {
31+
gpu.return
32+
}
33+
}
34+
35+
// CHECK-LABEL: llvm.func @test_func_readonly_ptr
36+
// ROCDL-SAME: !llvm.ptr {llvm.readonly}
37+
// NVVM-SAME: !llvm.ptr {llvm.readonly}
2738

2839
// -----
2940

@@ -62,3 +73,17 @@ gpu.module @kernel {
6273
// CHECK-LABEL: llvm.func @test_func_dereferenceable_or_null
6374
// ROCDL-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
6475
// NVVM-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
76+
77+
// -----
78+
79+
gpu.module @kernel {
80+
gpu.func @test_func_noundef(%arg0 : memref<f32> {llvm.noundef} ) {
81+
gpu.return
82+
}
83+
}
84+
85+
// CHECK-LABEL: llvm.func @test_func_noundef
86+
// ROCDL-SAME: !llvm.ptr {llvm.noundef}
87+
// ROCDL-SAME: i64 {llvm.noundef}
88+
// NVVM-SAME: !llvm.ptr {llvm.noundef}
89+
// NVVM-SAME: i64 {llvm.noundef}

0 commit comments

Comments
 (0)