Skip to content

Commit 1e92e25

Browse files
committed
[MLIR][OpenMP] Added MLIR translation support for use_device clauses
Added MLIR support for translating the use_device_ptr and use_device_addr clauses for LLVM IR lowering.

- use_device_ptr: The mapped variables marked with use_device_ptr are accessed through a copy of the base pointer mappers. The mapper is copied onto a new temporary pointer variable.
- use_device_addr: The mapped variables marked with use_device_addr are accessed directly through the base pointer mappers.
- If mapping information is not provided explicitly, then a default map_type of alloc/release is assumed and the map_size is set to 0.

Depends on D152554

Reviewed By: kiranchandramohan, raghavendhra

Differential Revision: https://reviews.llvm.org/D146648
1 parent 48bcaeb commit 1e92e25

File tree

4 files changed

+406
-58
lines changed

4 files changed

+406
-58
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4848,14 +4848,12 @@ void OpenMPIRBuilder::emitOffloadingArrays(
48484848
Builder.restoreIP(AllocaIP);
48494849
Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
48504850
Builder.restoreIP(CodeGenIP);
4851-
assert(DeviceAddrCB &&
4852-
"DeviceAddrCB missing for DevicePtr code generation");
4853-
DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
4851+
if (DeviceAddrCB)
4852+
DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
48544853
} else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
48554854
Info.DevicePtrInfoMap[BPVal] = {BP, BP};
4856-
assert(DeviceAddrCB &&
4857-
"DeviceAddrCB missing for DevicePtr code generation");
4858-
DeviceAddrCB(I, BP);
4855+
if (DeviceAddrCB)
4856+
DeviceAddrCB(I, BP);
48594857
}
48604858
}
48614859

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4894,6 +4894,8 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
48944894

48954895
CombinedInfo.BasePointers.emplace_back(Val1);
48964896
CombinedInfo.Pointers.emplace_back(Val1);
4897+
CombinedInfo.DevicePointers.emplace_back(
4898+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
48974899
CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
48984900
CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(1));
48994901
uint32_t temp;
@@ -4951,6 +4953,8 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) {
49514953

49524954
CombinedInfo.BasePointers.emplace_back(Val1);
49534955
CombinedInfo.Pointers.emplace_back(Val1);
4956+
CombinedInfo.DevicePointers.emplace_back(
4957+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
49544958
CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
49554959
CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(2));
49564960
uint32_t temp;
@@ -4996,44 +5000,79 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
49965000
Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1));
49975001
ASSERT_NE(Val1, nullptr);
49985002

5003+
AllocaInst *Val2 = Builder.CreateAlloca(Builder.getPtrTy());
5004+
ASSERT_NE(Val2, nullptr);
5005+
5006+
AllocaInst *Val3 = Builder.CreateAlloca(Builder.getPtrTy());
5007+
ASSERT_NE(Val3, nullptr);
5008+
49995009
IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
50005010
F->getEntryBlock().getFirstInsertionPt());
50015011

5012+
using DeviceInfoTy = llvm::OpenMPIRBuilder::DeviceInfoTy;
50025013
llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo;
50035014
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
50045015
auto GenMapInfoCB =
50055016
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
50065017
// Get map clause information.
50075018
Builder.restoreIP(codeGenIP);
5019+
uint32_t temp;
50085020

50095021
CombinedInfo.BasePointers.emplace_back(Val1);
50105022
CombinedInfo.Pointers.emplace_back(Val1);
5023+
CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::None);
50115024
CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
50125025
CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(3));
5013-
uint32_t temp;
5026+
CombinedInfo.Names.emplace_back(
5027+
OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
5028+
5029+
CombinedInfo.BasePointers.emplace_back(Val2);
5030+
CombinedInfo.Pointers.emplace_back(Val2);
5031+
CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
5032+
CombinedInfo.Sizes.emplace_back(Builder.getInt64(8));
5033+
CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(67));
5034+
CombinedInfo.Names.emplace_back(
5035+
OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
5036+
5037+
CombinedInfo.BasePointers.emplace_back(Val3);
5038+
CombinedInfo.Pointers.emplace_back(Val3);
5039+
CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Address);
5040+
CombinedInfo.Sizes.emplace_back(Builder.getInt64(8));
5041+
CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(67));
50145042
CombinedInfo.Names.emplace_back(
50155043
OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
50165044
return CombinedInfo;
50175045
};
50185046

50195047
llvm::OpenMPIRBuilder::TargetDataInfo Info(
5020-
/*RequiresDevicePointerInfo=*/false,
5048+
/*RequiresDevicePointerInfo=*/true,
50215049
/*SeparateBeginEndCalls=*/true);
50225050

50235051
OMPBuilder.Config.setIsGPU(true);
50245052

5025-
auto BodyCB = [&](InsertPointTy CodeGenIP, int BodyGenType) {
5026-
if (BodyGenType == 3) {
5053+
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5054+
auto BodyCB = [&](InsertPointTy CodeGenIP, BodyGenTy BodyGenType) {
5055+
if (BodyGenType == BodyGenTy::Priv) {
5056+
EXPECT_EQ(Info.DevicePtrInfoMap.size(), 2u);
50275057
Builder.restoreIP(CodeGenIP);
5028-
CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back());
5058+
CallInst *TargetDataCall =
5059+
dyn_cast<CallInst>(BB->back().getPrevNode()->getPrevNode());
50295060
EXPECT_NE(TargetDataCall, nullptr);
50305061
EXPECT_EQ(TargetDataCall->arg_size(), 9U);
50315062
EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(),
50325063
"__tgt_target_data_begin_mapper");
50335064
EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64));
50345065
EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32));
50355066
EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy());
5036-
Builder.restoreIP(CodeGenIP);
5067+
5068+
LoadInst *LI = dyn_cast<LoadInst>(BB->back().getPrevNode());
5069+
EXPECT_NE(LI, nullptr);
5070+
StoreInst *SI = dyn_cast<StoreInst>(&BB->back());
5071+
EXPECT_NE(SI, nullptr);
5072+
EXPECT_EQ(SI->getValueOperand(), LI);
5073+
EXPECT_EQ(SI->getPointerOperand(), Info.DevicePtrInfoMap[Val2].second);
5074+
EXPECT_TRUE(isa<AllocaInst>(Info.DevicePtrInfoMap[Val2].second));
5075+
EXPECT_TRUE(isa<GetElementPtrInst>(Info.DevicePtrInfoMap[Val3].second));
50375076
Builder.CreateStore(Builder.getInt32(99), Val1);
50385077
}
50395078
return Builder.saveIP();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 121 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,44 +1367,92 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
13671367
return success();
13681368
}
13691369

1370-
int64_t getSizeInBytes(DataLayout &DL, const mlir::Type &type) {
1370+
int64_t getSizeInBytes(DataLayout &DL, const Type &type) {
13711371
if (isa<LLVM::LLVMPointerType>(type))
13721372
return DL.getTypeSize(cast<LLVM::LLVMPointerType>(type).getElementType());
13731373

13741374
return 0;
13751375
}
13761376

1377+
// Generate all map related information and fill the combinedInfo.
13771378
static void genMapInfos(llvm::IRBuilderBase &builder,
13781379
LLVM::ModuleTranslation &moduleTranslation,
13791380
DataLayout &DL,
13801381
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
13811382
const SmallVector<Value> &mapOperands,
1382-
const ArrayAttr &mapTypes) {
1383-
// Get map clause information.
1383+
const ArrayAttr &mapTypes,
1384+
const SmallVector<Value> &devPtrOperands = {},
1385+
const SmallVector<Value> &devAddrOperands = {}) {
13841386
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1387+
1388+
auto fail = [&combinedInfo]() -> void {
1389+
combinedInfo.BasePointers.clear();
1390+
combinedInfo.Pointers.clear();
1391+
combinedInfo.DevicePointers.clear();
1392+
combinedInfo.Sizes.clear();
1393+
combinedInfo.Types.clear();
1394+
combinedInfo.Names.clear();
1395+
};
1396+
1397+
auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
1398+
index = 0;
1399+
for (auto basePtr : combinedInfo.BasePointers) {
1400+
if (basePtr == val)
1401+
return true;
1402+
index++;
1403+
}
1404+
return false;
1405+
};
1406+
13851407
unsigned index = 0;
13861408
for (const auto &mapOp : mapOperands) {
1387-
if (!mapOp.getType().isa<LLVM::LLVMPointerType>()) {
1388-
// TODO: Only LLVMPointerTypes are handled.
1389-
combinedInfo.BasePointers.clear();
1390-
combinedInfo.Pointers.clear();
1391-
combinedInfo.Sizes.clear();
1392-
combinedInfo.Types.clear();
1393-
combinedInfo.Names.clear();
1394-
return;
1395-
}
1409+
// TODO: Only LLVMPointerTypes are handled.
1410+
if (!mapOp.getType().isa<LLVM::LLVMPointerType>())
1411+
return fail();
13961412

13971413
llvm::Value *mapOpValue = moduleTranslation.lookupValue(mapOp);
13981414
combinedInfo.BasePointers.emplace_back(mapOpValue);
13991415
combinedInfo.Pointers.emplace_back(mapOpValue);
1416+
combinedInfo.DevicePointers.emplace_back(
1417+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
14001418
combinedInfo.Names.emplace_back(
1401-
mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder));
1419+
LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder));
14021420
combinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(
1403-
mapTypes[index].dyn_cast<mlir::IntegerAttr>().getInt()));
1421+
mapTypes[index].dyn_cast<IntegerAttr>().getInt()));
14041422
combinedInfo.Sizes.emplace_back(
14051423
builder.getInt64(getSizeInBytes(DL, mapOp.getType())));
14061424
index++;
14071425
}
1426+
1427+
auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
1428+
for (const auto &devOp : devOperands) {
1429+
// TODO: Only LLVMPointerTypes are handled.
1430+
if (!devOp.getType().template isa<LLVM::LLVMPointerType>())
1431+
return fail();
1432+
1433+
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
1434+
1435+
// Check if map info is already present for this entry.
1436+
unsigned infoIndex;
1437+
if (findMapInfo(mapOpValue, infoIndex)) {
1438+
combinedInfo.Types[infoIndex] |=
1439+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1440+
combinedInfo.DevicePointers[infoIndex] = devOpType;
1441+
} else {
1442+
combinedInfo.BasePointers.emplace_back(mapOpValue);
1443+
combinedInfo.Pointers.emplace_back(mapOpValue);
1444+
combinedInfo.DevicePointers.emplace_back(devOpType);
1445+
combinedInfo.Names.emplace_back(
1446+
LLVM::createMappingInformation(devOp.getLoc(), *ompBuilder));
1447+
combinedInfo.Types.emplace_back(
1448+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
1449+
combinedInfo.Sizes.emplace_back(builder.getInt64(0));
1450+
}
1451+
}
1452+
};
1453+
1454+
addDevInfos(devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
1455+
addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
14081456
}
14091457

14101458
static LogicalResult
@@ -1413,6 +1461,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
14131461
llvm::Value *ifCond = nullptr;
14141462
int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
14151463
SmallVector<Value> mapOperands;
1464+
SmallVector<Value> useDevPtrOperands;
1465+
SmallVector<Value> useDevAddrOperands;
14161466
ArrayAttr mapTypes;
14171467
llvm::omp::RuntimeFunction RTLFn;
14181468
DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
@@ -1422,23 +1472,20 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
14221472
LogicalResult result =
14231473
llvm::TypeSwitch<Operation *, LogicalResult>(op)
14241474
.Case([&](omp::DataOp dataOp) {
1425-
if (!dataOp.getUseDeviceAddr().empty() ||
1426-
!dataOp.getUseDevicePtr().empty())
1427-
return failure();
1428-
14291475
if (auto ifExprVar = dataOp.getIfExpr())
14301476
ifCond = moduleTranslation.lookupValue(ifExprVar);
14311477

14321478
if (auto devId = dataOp.getDevice())
1433-
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1434-
devId.getDefiningOp()))
1435-
if (auto intAttr =
1436-
dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
1479+
if (auto constOp =
1480+
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
1481+
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
14371482
deviceID = intAttr.getInt();
14381483

14391484
mapOperands = dataOp.getMapOperands();
14401485
if (dataOp.getMapTypes())
14411486
mapTypes = dataOp.getMapTypes().value();
1487+
useDevPtrOperands = dataOp.getUseDevicePtr();
1488+
useDevAddrOperands = dataOp.getUseDeviceAddr();
14421489
return success();
14431490
})
14441491
.Case([&](omp::EnterDataOp enterDataOp) {
@@ -1449,10 +1496,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
14491496
ifCond = moduleTranslation.lookupValue(ifExprVar);
14501497

14511498
if (auto devId = enterDataOp.getDevice())
1452-
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1453-
devId.getDefiningOp()))
1454-
if (auto intAttr =
1455-
dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
1499+
if (auto constOp =
1500+
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
1501+
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
14561502
deviceID = intAttr.getInt();
14571503
RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
14581504
mapOperands = enterDataOp.getMapOperands();
@@ -1467,10 +1513,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
14671513
ifCond = moduleTranslation.lookupValue(ifExprVar);
14681514

14691515
if (auto devId = exitDataOp.getDevice())
1470-
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1471-
devId.getDefiningOp()))
1472-
if (auto intAttr =
1473-
dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
1516+
if (auto constOp =
1517+
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
1518+
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
14741519
deviceID = intAttr.getInt();
14751520

14761521
RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
@@ -1493,38 +1538,68 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
14931538
auto genMapInfoCB =
14941539
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
14951540
builder.restoreIP(codeGenIP);
1496-
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapOperands,
1497-
mapTypes);
1541+
if (auto DataOp = dyn_cast<omp::DataOp>(op)) {
1542+
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapOperands,
1543+
mapTypes, useDevPtrOperands, useDevAddrOperands);
1544+
} else {
1545+
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapOperands,
1546+
mapTypes);
1547+
}
14981548
return combinedInfo;
14991549
};
15001550

1501-
LogicalResult bodyGenStatus = success();
1551+
llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
1552+
/*SeparateBeginEndCalls=*/true);
1553+
15021554
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
1555+
LogicalResult bodyGenStatus = success();
15031556
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
1557+
assert(isa<omp::DataOp>(op) && "BodyGen requested for non DataOp");
1558+
Region &region = cast<omp::DataOp>(op).getRegion();
15041559
switch (bodyGenType) {
15051560
case BodyGenTy::Priv:
1561+
// Check if any device ptr/addr info is available
1562+
if (!info.DevicePtrInfoMap.empty()) {
1563+
builder.restoreIP(codeGenIP);
1564+
unsigned argIndex = 0;
1565+
for (auto &devPtrOp : useDevPtrOperands) {
1566+
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
1567+
const auto &arg = region.front().getArgument(argIndex);
1568+
moduleTranslation.mapValue(arg,
1569+
info.DevicePtrInfoMap[mapOpValue].second);
1570+
argIndex++;
1571+
}
1572+
1573+
for (auto &devAddrOp : useDevAddrOperands) {
1574+
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
1575+
const auto &arg = region.front().getArgument(argIndex);
1576+
auto *LI = builder.CreateLoad(
1577+
builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
1578+
moduleTranslation.mapValue(arg, LI);
1579+
argIndex++;
1580+
}
1581+
1582+
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
1583+
builder, moduleTranslation);
1584+
}
15061585
break;
15071586
case BodyGenTy::DupNoPriv:
15081587
break;
1509-
case BodyGenTy::NoPriv: {
1510-
// DataOp has only one region associated with it.
1511-
auto &region = cast<omp::DataOp>(op).getRegion();
1512-
builder.restoreIP(codeGenIP);
1513-
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
1514-
builder, moduleTranslation);
1515-
}
1588+
case BodyGenTy::NoPriv:
1589+
// If device info is available then region has already been generated
1590+
if (info.DevicePtrInfoMap.empty()) {
1591+
builder.restoreIP(codeGenIP);
1592+
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
1593+
builder, moduleTranslation);
1594+
}
1595+
break;
15161596
}
15171597
return builder.saveIP();
15181598
};
15191599

15201600
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
15211601
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
15221602
findAllocaInsertPoint(builder, moduleTranslation);
1523-
1524-
// TODO: Add support for DevicePointerInfo
1525-
llvm::OpenMPIRBuilder::TargetDataInfo info(
1526-
/*RequiresDevicePointerInfo=*/false,
1527-
/*SeparateBeginEndCalls=*/true);
15281603
if (isa<omp::DataOp>(op)) {
15291604
builder.restoreIP(ompBuilder->createTargetData(
15301605
ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
@@ -1693,7 +1768,7 @@ convertDeclareTargetAttr(Operation *op,
16931768
// lowering while removing functions at the current time.
16941769
if (!isDeviceCompilation)
16951770
return success();
1696-
1771+
16971772
omp::DeclareTargetDeviceType declareType =
16981773
declareTargetAttr.getDeviceType().getValue();
16991774

0 commit comments

Comments
 (0)