@@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
26
26
27
27
SmallVector<LLVM::GlobalOp, 3 > workgroupBuffers;
28
28
workgroupBuffers.reserve (gpuFuncOp.getNumWorkgroupAttributions ());
29
- for (const auto &en : llvm::enumerate (gpuFuncOp.getWorkgroupAttributions ())) {
30
- BlockArgument attribution = en.value ();
31
-
29
+ for (const auto [idx, attribution] :
30
+ llvm::enumerate (gpuFuncOp.getWorkgroupAttributions ())) {
32
31
auto type = dyn_cast<MemRefType>(attribution.getType ());
33
32
assert (type && type.hasStaticShape () && " unexpected type in attribution" );
34
33
@@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
37
36
auto elementType =
38
37
cast<Type>(typeConverter->convertType (type.getElementType ()));
39
38
auto arrayType = LLVM::LLVMArrayType::get (elementType, numElements);
40
- std::string name = std::string (
41
- llvm::formatv (" __wg_{0}_{1}" , gpuFuncOp.getName (), en. index () ));
39
+ std::string name =
40
+ std::string ( llvm::formatv (" __wg_{0}_{1}" , gpuFuncOp.getName (), idx ));
42
41
uint64_t alignment = 0 ;
43
42
if (auto alignAttr =
44
43
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr (
45
- en. index () , LLVM::LLVMDialect::getAlignAttrName ())))
44
+ idx , LLVM::LLVMDialect::getAlignAttrName ())))
46
45
alignment = alignAttr.getInt ();
47
46
auto globalOp = rewriter.create <LLVM::GlobalOp>(
48
47
gpuFuncOp.getLoc (), arrayType, /* isConstant=*/ false ,
@@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
105
104
rewriter.setInsertionPointToStart (&gpuFuncOp.front ());
106
105
unsigned numProperArguments = gpuFuncOp.getNumArguments ();
107
106
108
- for (const auto &en : llvm::enumerate (workgroupBuffers)) {
109
- LLVM::GlobalOp global = en.value ();
107
+ for (const auto [idx, global] : llvm::enumerate (workgroupBuffers)) {
110
108
auto ptrType = LLVM::LLVMPointerType::get (rewriter.getContext (),
111
109
global.getAddrSpace ());
112
110
Value address = rewriter.create <LLVM::AddressOfOp>(
@@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
119
117
// existing memref infrastructure. This may use more registers than
120
118
// otherwise necessary given that memref sizes are fixed, but we can try
121
119
// and canonicalize that away later.
122
- Value attribution = gpuFuncOp.getWorkgroupAttributions ()[en. index () ];
120
+ Value attribution = gpuFuncOp.getWorkgroupAttributions ()[idx ];
123
121
auto type = cast<MemRefType>(attribution.getType ());
124
122
auto descr = MemRefDescriptor::fromStaticShape (
125
123
rewriter, loc, *getTypeConverter (), type, memory);
126
- signatureConversion.remapInput (numProperArguments + en. index () , descr);
124
+ signatureConversion.remapInput (numProperArguments + idx , descr);
127
125
}
128
126
129
127
// Rewrite private memory attributions to alloca'ed buffers.
130
128
unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions ();
131
129
auto int64Ty = IntegerType::get (rewriter.getContext (), 64 );
132
- for (const auto &en : llvm::enumerate (gpuFuncOp. getPrivateAttributions ())) {
133
- Value attribution = en. value ();
130
+ for (const auto [idx, attribution] :
131
+ llvm::enumerate (gpuFuncOp. getPrivateAttributions ())) {
134
132
auto type = cast<MemRefType>(attribution.getType ());
135
133
assert (type && type.hasStaticShape () && " unexpected type in attribution" );
136
134
@@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
145
143
uint64_t alignment = 0 ;
146
144
if (auto alignAttr =
147
145
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr (
148
- en. index () , LLVM::LLVMDialect::getAlignAttrName ())))
146
+ idx , LLVM::LLVMDialect::getAlignAttrName ())))
149
147
alignment = alignAttr.getInt ();
150
148
Value allocated = rewriter.create <LLVM::AllocaOp>(
151
149
gpuFuncOp.getLoc (), ptrType, elementType, numElements, alignment);
152
150
auto descr = MemRefDescriptor::fromStaticShape (
153
151
rewriter, loc, *getTypeConverter (), type, allocated);
154
152
signatureConversion.remapInput (
155
- numProperArguments + numWorkgroupAttributions + en. index () , descr);
153
+ numProperArguments + numWorkgroupAttributions + idx , descr);
156
154
}
157
155
}
158
156
@@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
169
167
if (getTypeConverter ()->getOptions ().useBarePtrCallConv ) {
170
168
OpBuilder::InsertionGuard guard (rewriter);
171
169
rewriter.setInsertionPointToStart (&llvmFuncOp.getBody ().front ());
172
- for (const auto &en : llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
173
- auto memrefTy = dyn_cast<MemRefType>(en.value ());
170
+ for (const auto [idx, argTy] :
171
+ llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
172
+ auto memrefTy = dyn_cast<MemRefType>(argTy);
174
173
if (!memrefTy)
175
174
continue ;
176
175
assert (memrefTy.hasStaticShape () &&
177
176
" Bare pointer convertion used with dynamically-shaped memrefs" );
178
177
// Use a placeholder when replacing uses of the memref argument to prevent
179
178
// circular replacements.
180
- auto remapping = signatureConversion.getInputMapping (en. index () );
179
+ auto remapping = signatureConversion.getInputMapping (idx );
181
180
assert (remapping && remapping->size == 1 &&
182
181
" Type converter should produce 1-to-1 mapping for bare memrefs" );
183
182
BlockArgument newArg =
@@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
193
192
194
193
// Get memref type from function arguments and set the noalias to
195
194
// pointer arguments.
196
- for (const auto &en : llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
197
- auto memrefTy = en.value ().dyn_cast <MemRefType>();
198
- NamedAttrList argAttr = argAttrs
199
- ? argAttrs[en.index ()].cast <DictionaryAttr>()
200
- : NamedAttrList ();
201
-
195
+ for (const auto [idx, argTy] :
196
+ llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
197
+ auto remapping = signatureConversion.getInputMapping (idx);
198
+ NamedAttrList argAttr =
199
+ argAttrs ? argAttrs[idx].cast <DictionaryAttr>() : NamedAttrList ();
200
+ auto copyAttribute = [&](StringRef attrName) {
201
+ Attribute attr = argAttr.erase (attrName);
202
+ if (!attr)
203
+ return ;
204
+ for (size_t i = 0 , e = remapping->size ; i < e; ++i)
205
+ llvmFuncOp.setArgAttr (remapping->inputNo + i, attrName, attr);
206
+ };
202
207
auto copyPointerAttribute = [&](StringRef attrName) {
203
208
Attribute attr = argAttr.erase (attrName);
204
209
205
- // This is a proxy for the bare pointer calling convention.
206
210
if (!attr)
207
211
return ;
208
- auto remapping = signatureConversion.getInputMapping (en.index ());
209
212
if (remapping->size > 1 &&
210
213
attrName == LLVM::LLVMDialect::getNoAliasAttrName ()) {
211
214
emitWarning (llvmFuncOp.getLoc (),
@@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
224
227
if (argAttr.empty ())
225
228
continue ;
226
229
227
- if (memrefTy) {
230
+ copyAttribute (LLVM::LLVMDialect::getReturnedAttrName ());
231
+ copyAttribute (LLVM::LLVMDialect::getNoUndefAttrName ());
232
+ copyAttribute (LLVM::LLVMDialect::getInRegAttrName ());
233
+ bool lowersToPointer = false ;
234
+ for (size_t i = 0 , e = remapping->size ; i < e; ++i) {
235
+ lowersToPointer |= isa<LLVM::LLVMPointerType>(
236
+ llvmFuncOp.getArgument (remapping->inputNo + i).getType ());
237
+ }
238
+
239
+ if (lowersToPointer) {
228
240
copyPointerAttribute (LLVM::LLVMDialect::getNoAliasAttrName ());
241
+ copyPointerAttribute (LLVM::LLVMDialect::getNoCaptureAttrName ());
242
+ copyPointerAttribute (LLVM::LLVMDialect::getNoFreeAttrName ());
243
+ copyPointerAttribute (LLVM::LLVMDialect::getAlignAttrName ());
229
244
copyPointerAttribute (LLVM::LLVMDialect::getReadonlyAttrName ());
230
245
copyPointerAttribute (LLVM::LLVMDialect::getWriteOnlyAttrName ());
246
+ copyPointerAttribute (LLVM::LLVMDialect::getReadnoneAttrName ());
231
247
copyPointerAttribute (LLVM::LLVMDialect::getNonNullAttrName ());
232
248
copyPointerAttribute (LLVM::LLVMDialect::getDereferenceableAttrName ());
233
249
copyPointerAttribute (
0 commit comments