Skip to content

Commit 0bbebf6

Browse files
authored
[flang][cuda] Convert cuf.data_transfer with descriptors (#108890)
Convert cuf.data_transfer operations involving descriptors to the newly introduced entry points (#108244).
1 parent 39a4b32 commit 0bbebf6

File tree

2 files changed

+315
-2
lines changed

2 files changed

+315
-2
lines changed

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1616
#include "flang/Optimizer/Support/DataLayout.h"
1717
#include "flang/Runtime/CUDA/descriptor.h"
18+
#include "flang/Runtime/CUDA/memory.h"
1819
#include "flang/Runtime/allocatable.h"
1920
#include "mlir/Pass/Pass.h"
2021
#include "mlir/Transforms/DialectConversion.h"
@@ -255,6 +256,171 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
255256
}
256257
};
257258

259+
/// Return the size, in bytes, of a single element of \p type.
///
/// If \p type is a sequence (array) type it is first unwrapped to its element
/// type. Integer and floating-point types use their bit width divided by 8;
/// logical and complex types look up their bit size in \p kindMap from the
/// Fortran kind (a complex counts two real components).
///
/// \param loc     Source location of the operation being converted
///                (currently unused).
/// \param type    Scalar or sequence type whose element width is requested.
/// \param kindMap Kind mapping used to resolve logical/real bit sizes.
/// \return The element width in bytes. Any other element type triggers
///         llvm::report_fatal_error("unsupported type").
static int computeWidth(mlir::Location loc, mlir::Type type,
                        fir::KindMapping &kindMap) {
  mlir::Type eleTy = fir::unwrapSequenceType(type);

  // i1 is itself an mlir::IntegerType, so it has to be tested before the
  // generic integer case; otherwise 1 / 8 truncates to a width of 0 bytes.
  if (eleTy.isInteger(1))
    return 1;
  if (auto intTy{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
    return intTy.getWidth() / 8;
  if (auto floatTy{mlir::dyn_cast<mlir::FloatType>(eleTy)})
    return floatTy.getWidth() / 8;
  if (auto logicalTy{mlir::dyn_cast<fir::LogicalType>(eleTy)})
    return kindMap.getLogicalBitsize(logicalTy.getFKind()) / 8;
  if (auto complexTy{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
    // A complex value is stored as two reals of the same kind.
    int elemSize = kindMap.getRealBitsize(complexTy.getFKind()) / 8;
    return 2 * elemSize;
  }
  llvm::report_fatal_error("unsupported type");
}
281+
282+
/// Return \p val as a value of type \p toTy.
///
/// When \p val already has type \p toTy it is returned unchanged and no
/// operation is created; otherwise a fir.convert from \p val to \p toTy is
/// created at \p loc through \p rewriter and its result is returned.
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                   mlir::Location loc, mlir::Type toTy,
                                   mlir::Value val) {
  const bool alreadyTargetType = (val.getType() == toTy);
  if (alreadyTargetType)
    return val;
  return rewriter.create<fir::ConvertOp>(loc, toTy, val);
}
289+
290+
struct CufDataTransferOpConversion
291+
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
292+
using OpRewritePattern::OpRewritePattern;
293+
294+
mlir::LogicalResult
295+
matchAndRewrite(cuf::DataTransferOp op,
296+
mlir::PatternRewriter &rewriter) const override {
297+
298+
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
299+
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
300+
301+
// Only convert cuf.data_transfer with at least one descripor.
302+
if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
303+
!mlir::isa<fir::BaseBoxType>(dstTy))
304+
return failure();
305+
306+
unsigned mode;
307+
if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
308+
mode = kHostToDevice;
309+
} else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
310+
mode = kDeviceToHost;
311+
} else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
312+
mode = kDeviceToDevice;
313+
}
314+
315+
auto mod = op->getParentOfType<mlir::ModuleOp>();
316+
fir::FirOpBuilder builder(rewriter, mod);
317+
mlir::Location loc = op.getLoc();
318+
319+
if (mlir::isa<fir::BaseBoxType>(srcTy) &&
320+
mlir::isa<fir::BaseBoxType>(dstTy)) {
321+
// Transfer between two descriptor.
322+
mlir::func::FuncOp func =
323+
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
324+
loc, builder);
325+
326+
auto fTy = func.getFunctionType();
327+
mlir::Value modeValue =
328+
builder.createIntegerConstant(loc, builder.getI32Type(), mode);
329+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
330+
mlir::Value sourceLine =
331+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
332+
mlir::Value dst = builder.loadIfRef(loc, op.getDst());
333+
mlir::Value src = builder.loadIfRef(loc, op.getSrc());
334+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
335+
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
336+
builder.create<fir::CallOp>(loc, func, args);
337+
rewriter.eraseOp(op);
338+
} else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
339+
// Scalar to descriptor transfer.
340+
mlir::Value val = op.getSrc();
341+
if (op.getSrc().getDefiningOp() &&
342+
mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
343+
mlir::Value alloc = builder.createTemporary(loc, srcTy);
344+
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
345+
val = alloc;
346+
}
347+
348+
mlir::func::FuncOp func =
349+
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
350+
builder);
351+
auto fTy = func.getFunctionType();
352+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
353+
mlir::Value sourceLine =
354+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
355+
mlir::Value dst = builder.loadIfRef(loc, op.getDst());
356+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
357+
builder, loc, fTy, dst, val, sourceFile, sourceLine)};
358+
builder.create<fir::CallOp>(loc, func, args);
359+
rewriter.eraseOp(op);
360+
} else {
361+
mlir::Value modeValue =
362+
builder.createIntegerConstant(loc, builder.getI32Type(), mode);
363+
// Type used to compute the width.
364+
mlir::Type computeType = dstTy;
365+
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
366+
bool dstIsDesc = false;
367+
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
368+
dstIsDesc = true;
369+
computeType = srcTy;
370+
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
371+
}
372+
fir::KindMapping kindMap{fir::getKindMapping(mod)};
373+
int width = computeWidth(loc, computeType, kindMap);
374+
375+
mlir::Value nbElement;
376+
mlir::Type idxTy = rewriter.getIndexType();
377+
if (!op.getShape()) {
378+
nbElement = rewriter.create<mlir::arith::ConstantOp>(
379+
loc, idxTy,
380+
rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
381+
} else {
382+
auto shapeOp =
383+
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
384+
nbElement =
385+
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
386+
for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
387+
auto operand =
388+
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
389+
nbElement =
390+
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
391+
}
392+
}
393+
394+
mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
395+
loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
396+
mlir::Value bytes =
397+
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
398+
399+
mlir::func::FuncOp func =
400+
dstIsDesc
401+
? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
402+
loc, builder)
403+
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
404+
loc, builder);
405+
auto fTy = func.getFunctionType();
406+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
407+
mlir::Value sourceLine =
408+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
409+
mlir::Value dst =
410+
dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
411+
mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
412+
? builder.loadIfRef(loc, op.getSrc())
413+
: op.getSrc();
414+
llvm::SmallVector<mlir::Value> args{
415+
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
416+
modeValue, sourceFile, sourceLine)};
417+
builder.create<fir::CallOp>(loc, func, args);
418+
rewriter.eraseOp(op);
419+
}
420+
return mlir::success();
421+
}
422+
};
423+
258424
class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
259425
public:
260426
void runOnOperation() override {
@@ -285,10 +451,17 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
285451
[](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
286452
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
287453
[](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
288-
target.addLegalDialect<fir::FIROpsDialect>();
454+
target.addDynamicallyLegalOp<cuf::DataTransferOp>(
455+
[](::cuf::DataTransferOp op) {
456+
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
457+
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
458+
return !mlir::isa<fir::BaseBoxType>(srcTy) &&
459+
!mlir::isa<fir::BaseBoxType>(dstTy);
460+
});
461+
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
289462
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
290463
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
291-
CufFreeOpConversion>(ctx);
464+
CufFreeOpConversion, CufDataTransferOpConversion>(ctx);
292465
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
293466
std::move(patterns)))) {
294467
mlir::emitError(mlir::UnknownLoc::get(ctx),

0 commit comments

Comments
 (0)