|
15 | 15 | #include "flang/Optimizer/HLFIR/HLFIROps.h"
|
16 | 16 | #include "flang/Optimizer/Support/DataLayout.h"
|
17 | 17 | #include "flang/Runtime/CUDA/descriptor.h"
|
| 18 | +#include "flang/Runtime/CUDA/memory.h" |
18 | 19 | #include "flang/Runtime/allocatable.h"
|
19 | 20 | #include "mlir/Pass/Pass.h"
|
20 | 21 | #include "mlir/Transforms/DialectConversion.h"
|
@@ -255,6 +256,171 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
|
255 | 256 | }
|
256 | 257 | };
|
257 | 258 |
|
| 259 | +static int computeWidth(mlir::Location loc, mlir::Type type, |
| 260 | + fir::KindMapping &kindMap) { |
| 261 | + auto eleTy = fir::unwrapSequenceType(type); |
| 262 | + int width = 0; |
| 263 | + if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) { |
| 264 | + width = t.getWidth() / 8; |
| 265 | + } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) { |
| 266 | + width = t.getWidth() / 8; |
| 267 | + } else if (eleTy.isInteger(1)) { |
| 268 | + width = 1; |
| 269 | + } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) { |
| 270 | + int kind = t.getFKind(); |
| 271 | + width = kindMap.getLogicalBitsize(kind) / 8; |
| 272 | + } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) { |
| 273 | + int kind = t.getFKind(); |
| 274 | + int elemSize = kindMap.getRealBitsize(kind) / 8; |
| 275 | + width = 2 * elemSize; |
| 276 | + } else { |
| 277 | + llvm::report_fatal_error("unsupported type"); |
| 278 | + } |
| 279 | + return width; |
| 280 | +} |
| 281 | + |
| 282 | +static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, |
| 283 | + mlir::Location loc, mlir::Type toTy, |
| 284 | + mlir::Value val) { |
| 285 | + if (val.getType() != toTy) |
| 286 | + return rewriter.create<fir::ConvertOp>(loc, toTy, val); |
| 287 | + return val; |
| 288 | +} |
| 289 | + |
| 290 | +struct CufDataTransferOpConversion |
| 291 | + : public mlir::OpRewritePattern<cuf::DataTransferOp> { |
| 292 | + using OpRewritePattern::OpRewritePattern; |
| 293 | + |
| 294 | + mlir::LogicalResult |
| 295 | + matchAndRewrite(cuf::DataTransferOp op, |
| 296 | + mlir::PatternRewriter &rewriter) const override { |
| 297 | + |
| 298 | + mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); |
| 299 | + mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); |
| 300 | + |
| 301 | + // Only convert cuf.data_transfer with at least one descripor. |
| 302 | + if (!mlir::isa<fir::BaseBoxType>(srcTy) && |
| 303 | + !mlir::isa<fir::BaseBoxType>(dstTy)) |
| 304 | + return failure(); |
| 305 | + |
| 306 | + unsigned mode; |
| 307 | + if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) { |
| 308 | + mode = kHostToDevice; |
| 309 | + } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) { |
| 310 | + mode = kDeviceToHost; |
| 311 | + } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) { |
| 312 | + mode = kDeviceToDevice; |
| 313 | + } |
| 314 | + |
| 315 | + auto mod = op->getParentOfType<mlir::ModuleOp>(); |
| 316 | + fir::FirOpBuilder builder(rewriter, mod); |
| 317 | + mlir::Location loc = op.getLoc(); |
| 318 | + |
| 319 | + if (mlir::isa<fir::BaseBoxType>(srcTy) && |
| 320 | + mlir::isa<fir::BaseBoxType>(dstTy)) { |
| 321 | + // Transfer between two descriptor. |
| 322 | + mlir::func::FuncOp func = |
| 323 | + fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>( |
| 324 | + loc, builder); |
| 325 | + |
| 326 | + auto fTy = func.getFunctionType(); |
| 327 | + mlir::Value modeValue = |
| 328 | + builder.createIntegerConstant(loc, builder.getI32Type(), mode); |
| 329 | + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); |
| 330 | + mlir::Value sourceLine = |
| 331 | + fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); |
| 332 | + mlir::Value dst = builder.loadIfRef(loc, op.getDst()); |
| 333 | + mlir::Value src = builder.loadIfRef(loc, op.getSrc()); |
| 334 | + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( |
| 335 | + builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; |
| 336 | + builder.create<fir::CallOp>(loc, func, args); |
| 337 | + rewriter.eraseOp(op); |
| 338 | + } else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) { |
| 339 | + // Scalar to descriptor transfer. |
| 340 | + mlir::Value val = op.getSrc(); |
| 341 | + if (op.getSrc().getDefiningOp() && |
| 342 | + mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) { |
| 343 | + mlir::Value alloc = builder.createTemporary(loc, srcTy); |
| 344 | + builder.create<fir::StoreOp>(loc, op.getSrc(), alloc); |
| 345 | + val = alloc; |
| 346 | + } |
| 347 | + |
| 348 | + mlir::func::FuncOp func = |
| 349 | + fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc, |
| 350 | + builder); |
| 351 | + auto fTy = func.getFunctionType(); |
| 352 | + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); |
| 353 | + mlir::Value sourceLine = |
| 354 | + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); |
| 355 | + mlir::Value dst = builder.loadIfRef(loc, op.getDst()); |
| 356 | + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( |
| 357 | + builder, loc, fTy, dst, val, sourceFile, sourceLine)}; |
| 358 | + builder.create<fir::CallOp>(loc, func, args); |
| 359 | + rewriter.eraseOp(op); |
| 360 | + } else { |
| 361 | + mlir::Value modeValue = |
| 362 | + builder.createIntegerConstant(loc, builder.getI32Type(), mode); |
| 363 | + // Type used to compute the width. |
| 364 | + mlir::Type computeType = dstTy; |
| 365 | + auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy); |
| 366 | + bool dstIsDesc = false; |
| 367 | + if (mlir::isa<fir::BaseBoxType>(dstTy)) { |
| 368 | + dstIsDesc = true; |
| 369 | + computeType = srcTy; |
| 370 | + seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy); |
| 371 | + } |
| 372 | + fir::KindMapping kindMap{fir::getKindMapping(mod)}; |
| 373 | + int width = computeWidth(loc, computeType, kindMap); |
| 374 | + |
| 375 | + mlir::Value nbElement; |
| 376 | + mlir::Type idxTy = rewriter.getIndexType(); |
| 377 | + if (!op.getShape()) { |
| 378 | + nbElement = rewriter.create<mlir::arith::ConstantOp>( |
| 379 | + loc, idxTy, |
| 380 | + rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize())); |
| 381 | + } else { |
| 382 | + auto shapeOp = |
| 383 | + mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp()); |
| 384 | + nbElement = |
| 385 | + createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]); |
| 386 | + for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) { |
| 387 | + auto operand = |
| 388 | + createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]); |
| 389 | + nbElement = |
| 390 | + rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand); |
| 391 | + } |
| 392 | + } |
| 393 | + |
| 394 | + mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>( |
| 395 | + loc, idxTy, rewriter.getIntegerAttr(idxTy, width)); |
| 396 | + mlir::Value bytes = |
| 397 | + rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue); |
| 398 | + |
| 399 | + mlir::func::FuncOp func = |
| 400 | + dstIsDesc |
| 401 | + ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>( |
| 402 | + loc, builder) |
| 403 | + : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>( |
| 404 | + loc, builder); |
| 405 | + auto fTy = func.getFunctionType(); |
| 406 | + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); |
| 407 | + mlir::Value sourceLine = |
| 408 | + fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); |
| 409 | + mlir::Value dst = |
| 410 | + dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst(); |
| 411 | + mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy) |
| 412 | + ? builder.loadIfRef(loc, op.getSrc()) |
| 413 | + : op.getSrc(); |
| 414 | + llvm::SmallVector<mlir::Value> args{ |
| 415 | + fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes, |
| 416 | + modeValue, sourceFile, sourceLine)}; |
| 417 | + builder.create<fir::CallOp>(loc, func, args); |
| 418 | + rewriter.eraseOp(op); |
| 419 | + } |
| 420 | + return mlir::success(); |
| 421 | + } |
| 422 | +}; |
| 423 | + |
258 | 424 | class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
|
259 | 425 | public:
|
260 | 426 | void runOnOperation() override {
|
@@ -285,10 +451,17 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
|
285 | 451 | [](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
|
286 | 452 | target.addDynamicallyLegalOp<cuf::DeallocateOp>(
|
287 | 453 | [](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
|
288 |
| - target.addLegalDialect<fir::FIROpsDialect>(); |
| 454 | + target.addDynamicallyLegalOp<cuf::DataTransferOp>( |
| 455 | + [](::cuf::DataTransferOp op) { |
| 456 | + mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); |
| 457 | + mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); |
| 458 | + return !mlir::isa<fir::BaseBoxType>(srcTy) && |
| 459 | + !mlir::isa<fir::BaseBoxType>(dstTy); |
| 460 | + }); |
| 461 | + target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>(); |
289 | 462 | patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
|
290 | 463 | patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
|
291 |
| - CufFreeOpConversion>(ctx); |
| 464 | + CufFreeOpConversion, CufDataTransferOpConversion>(ctx); |
292 | 465 | if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
|
293 | 466 | std::move(patterns)))) {
|
294 | 467 | mlir::emitError(mlir::UnknownLoc::get(ctx),
|
|
0 commit comments