@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
183
183
return false ;
184
184
}
185
185
186
+ static int computeWidth (mlir::Location loc, mlir::Type type,
187
+ fir::KindMapping &kindMap) {
188
+ auto eleTy = fir::unwrapSequenceType (type);
189
+ int width = 0 ;
190
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
191
+ width = t.getWidth () / 8 ;
192
+ } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
193
+ width = t.getWidth () / 8 ;
194
+ } else if (eleTy.isInteger (1 )) {
195
+ width = 1 ;
196
+ } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
197
+ int kind = t.getFKind ();
198
+ width = kindMap.getLogicalBitsize (kind) / 8 ;
199
+ } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
200
+ int kind = t.getFKind ();
201
+ int elemSize = kindMap.getRealBitsize (kind) / 8 ;
202
+ width = 2 * elemSize;
203
+ } else {
204
+ llvm::report_fatal_error (" unsupported type" );
205
+ }
206
+ return width;
207
+ }
208
+
186
209
struct CufAllocOpConversion : public mlir ::OpRewritePattern<cuf::AllocOp> {
187
210
using OpRewritePattern::OpRewritePattern;
188
211
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
193
216
mlir::LogicalResult
194
217
matchAndRewrite (cuf::AllocOp op,
195
218
mlir::PatternRewriter &rewriter) const override {
196
- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
197
-
198
- // Only convert cuf.alloc that allocates a descriptor.
199
- if (!boxTy)
200
- return failure ();
201
219
202
220
if (inDeviceContext (op.getOperation ())) {
203
221
// In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
212
230
auto mod = op->getParentOfType <mlir::ModuleOp>();
213
231
fir::FirOpBuilder builder (rewriter, mod);
214
232
mlir::Location loc = op.getLoc ();
233
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
234
+
235
+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ())) {
236
+ // Convert scalar and known size array allocations.
237
+ mlir::Value bytes;
238
+ fir::KindMapping kindMap{fir::getKindMapping (mod)};
239
+ if (fir::isa_trivial (op.getInType ())) {
240
+ int width = computeWidth (loc, op.getInType (), kindMap);
241
+ bytes =
242
+ builder.createIntegerConstant (loc, builder.getIndexType (), width);
243
+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
244
+ op.getInType ())) {
245
+ mlir::Value width = builder.createIntegerConstant (
246
+ loc, builder.getIndexType (),
247
+ computeWidth (loc, seqTy.getEleTy (), kindMap));
248
+ mlir::Value nbElem;
249
+ if (fir::sequenceWithNonConstantShape (seqTy)) {
250
+ assert (!op.getShape ().empty () && " expect shape with dynamic arrays" );
251
+ nbElem = builder.loadIfRef (loc, op.getShape ()[0 ]);
252
+ for (unsigned i = 1 ; i < op.getShape ().size (); ++i) {
253
+ nbElem = rewriter.create <mlir::arith::MulIOp>(
254
+ loc, nbElem, builder.loadIfRef (loc, op.getShape ()[i]));
255
+ }
256
+ } else {
257
+ nbElem = builder.createIntegerConstant (loc, builder.getIndexType (),
258
+ seqTy.getConstantArraySize ());
259
+ }
260
+ bytes = rewriter.create <mlir::arith::MulIOp>(loc, nbElem, width);
261
+ }
262
+ mlir::func::FuncOp func =
263
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFMemAlloc)>(loc, builder);
264
+ auto fTy = func.getFunctionType ();
265
+ mlir::Value sourceLine =
266
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (3 ));
267
+ mlir::Value memTy = builder.createIntegerConstant (
268
+ loc, builder.getI32Type (), getMemType (op.getDataAttr ()));
269
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
270
+ builder, loc, fTy , bytes, memTy, sourceFile, sourceLine)};
271
+ auto callOp = builder.create <fir::CallOp>(loc, func, args);
272
+ auto convOp = builder.createConvert (loc, op.getResult ().getType (),
273
+ callOp.getResult (0 ));
274
+ rewriter.replaceOp (op, convOp);
275
+ return mlir::success ();
276
+ }
277
+
278
+ // Convert descriptor allocations to function call.
279
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
215
280
mlir::func::FuncOp func =
216
281
fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocDesciptor)>(loc, builder);
217
-
218
282
auto fTy = func.getFunctionType ();
219
- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
220
283
mlir::Value sourceLine =
221
284
fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
222
285
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
245
308
mlir::LogicalResult
246
309
matchAndRewrite (cuf::FreeOp op,
247
310
mlir::PatternRewriter &rewriter) const override {
248
- // Only convert cuf.free on descriptor.
249
- if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
250
- return failure ();
251
- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
252
- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
253
- return failure ();
254
-
255
311
if (inDeviceContext (op.getOperation ())) {
256
312
rewriter.eraseOp (op);
257
313
return mlir::success ();
258
314
}
259
315
316
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
317
+ return failure ();
318
+
260
319
auto mod = op->getParentOfType <mlir::ModuleOp>();
261
320
fir::FirOpBuilder builder (rewriter, mod);
262
321
mlir::Location loc = op.getLoc ();
322
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
323
+
324
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
325
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ())) {
326
+ mlir::func::FuncOp func =
327
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFMemFree)>(loc, builder);
328
+ auto fTy = func.getFunctionType ();
329
+ mlir::Value sourceLine =
330
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (3 ));
331
+ mlir::Value memTy = builder.createIntegerConstant (
332
+ loc, builder.getI32Type (), getMemType (op.getDataAttr ()));
333
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
334
+ builder, loc, fTy , op.getDevptr (), memTy, sourceFile, sourceLine)};
335
+ builder.create <fir::CallOp>(loc, func, args);
336
+ rewriter.eraseOp (op);
337
+ return mlir::success ();
338
+ }
339
+
340
+ // Convert cuf.free on descriptors.
263
341
mlir::func::FuncOp func =
264
342
fir::runtime::getRuntimeFunc<mkRTKey (CUFFreeDesciptor)>(loc, builder);
265
-
266
343
auto fTy = func.getFunctionType ();
267
- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
268
344
mlir::Value sourceLine =
269
345
fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
270
346
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
275
351
}
276
352
};
277
353
278
- static int computeWidth (mlir::Location loc, mlir::Type type,
279
- fir::KindMapping &kindMap) {
280
- auto eleTy = fir::unwrapSequenceType (type);
281
- int width = 0 ;
282
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
283
- width = t.getWidth () / 8 ;
284
- } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
285
- width = t.getWidth () / 8 ;
286
- } else if (eleTy.isInteger (1 )) {
287
- width = 1 ;
288
- } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
289
- int kind = t.getFKind ();
290
- width = kindMap.getLogicalBitsize (kind) / 8 ;
291
- } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
292
- int kind = t.getFKind ();
293
- int elemSize = kindMap.getRealBitsize (kind) / 8 ;
294
- width = 2 * elemSize;
295
- } else {
296
- llvm::report_fatal_error (" unsupported type" );
297
- }
298
- return width;
299
- }
300
-
301
354
static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
302
355
mlir::Location loc, mlir::Type toTy,
303
356
mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
456
509
fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
457
510
fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
458
511
/* forceUnifiedTBAATree=*/ false , *dl);
459
- target.addDynamicallyLegalOp <cuf::AllocOp>([](::cuf::AllocOp op) {
460
- return !mlir::isa<fir::BaseBoxType>(op.getInType ());
461
- });
462
- target.addDynamicallyLegalOp <cuf::FreeOp>([](::cuf::FreeOp op) {
463
- if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
464
- op.getDevptr ().getType ())) {
465
- return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy ());
466
- }
467
- return true ;
468
- });
469
512
target.addDynamicallyLegalOp <cuf::DataTransferOp>(
470
513
[](::cuf::DataTransferOp op) {
471
514
mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
0 commit comments