@@ -1367,44 +1367,92 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
1367
1367
return success ();
1368
1368
}
1369
1369
1370
- int64_t getSizeInBytes (DataLayout &DL, const mlir:: Type &type) {
1370
+ int64_t getSizeInBytes (DataLayout &DL, const Type &type) {
1371
1371
if (isa<LLVM::LLVMPointerType>(type))
1372
1372
return DL.getTypeSize (cast<LLVM::LLVMPointerType>(type).getElementType ());
1373
1373
1374
1374
return 0 ;
1375
1375
}
1376
1376
1377
+ // Generate all map related information and fill the combinedInfo.
1377
1378
static void genMapInfos (llvm::IRBuilderBase &builder,
1378
1379
LLVM::ModuleTranslation &moduleTranslation,
1379
1380
DataLayout &DL,
1380
1381
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
1381
1382
const SmallVector<Value> &mapOperands,
1382
- const ArrayAttr &mapTypes) {
1383
- // Get map clause information.
1383
+ const ArrayAttr &mapTypes,
1384
+ const SmallVector<Value> &devPtrOperands = {},
1385
+ const SmallVector<Value> &devAddrOperands = {}) {
1384
1386
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1387
+
1388
+ auto fail = [&combinedInfo]() -> void {
1389
+ combinedInfo.BasePointers .clear ();
1390
+ combinedInfo.Pointers .clear ();
1391
+ combinedInfo.DevicePointers .clear ();
1392
+ combinedInfo.Sizes .clear ();
1393
+ combinedInfo.Types .clear ();
1394
+ combinedInfo.Names .clear ();
1395
+ };
1396
+
1397
+ auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
1398
+ index = 0 ;
1399
+ for (auto basePtr : combinedInfo.BasePointers ) {
1400
+ if (basePtr == val)
1401
+ return true ;
1402
+ index++;
1403
+ }
1404
+ return false ;
1405
+ };
1406
+
1385
1407
unsigned index = 0 ;
1386
1408
for (const auto &mapOp : mapOperands) {
1387
- if (!mapOp.getType ().isa <LLVM::LLVMPointerType>()) {
1388
- // TODO: Only LLVMPointerTypes are handled.
1389
- combinedInfo.BasePointers .clear ();
1390
- combinedInfo.Pointers .clear ();
1391
- combinedInfo.Sizes .clear ();
1392
- combinedInfo.Types .clear ();
1393
- combinedInfo.Names .clear ();
1394
- return ;
1395
- }
1409
+ // TODO: Only LLVMPointerTypes are handled.
1410
+ if (!mapOp.getType ().isa <LLVM::LLVMPointerType>())
1411
+ return fail ();
1396
1412
1397
1413
llvm::Value *mapOpValue = moduleTranslation.lookupValue (mapOp);
1398
1414
combinedInfo.BasePointers .emplace_back (mapOpValue);
1399
1415
combinedInfo.Pointers .emplace_back (mapOpValue);
1416
+ combinedInfo.DevicePointers .emplace_back (
1417
+ llvm::OpenMPIRBuilder::DeviceInfoTy::None);
1400
1418
combinedInfo.Names .emplace_back (
1401
- mlir:: LLVM::createMappingInformation (mapOp.getLoc (), *ompBuilder));
1419
+ LLVM::createMappingInformation (mapOp.getLoc (), *ompBuilder));
1402
1420
combinedInfo.Types .emplace_back (llvm::omp::OpenMPOffloadMappingFlags (
1403
- mapTypes[index].dyn_cast <mlir:: IntegerAttr>().getInt ()));
1421
+ mapTypes[index].dyn_cast <IntegerAttr>().getInt ()));
1404
1422
combinedInfo.Sizes .emplace_back (
1405
1423
builder.getInt64 (getSizeInBytes (DL, mapOp.getType ())));
1406
1424
index++;
1407
1425
}
1426
+
1427
+ auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
1428
+ for (const auto &devOp : devOperands) {
1429
+ // TODO: Only LLVMPointerTypes are handled.
1430
+ if (!devOp.getType ().template isa <LLVM::LLVMPointerType>())
1431
+ return fail ();
1432
+
1433
+ llvm::Value *mapOpValue = moduleTranslation.lookupValue (devOp);
1434
+
1435
+ // Check if map info is already present for this entry.
1436
+ unsigned infoIndex;
1437
+ if (findMapInfo (mapOpValue, infoIndex)) {
1438
+ combinedInfo.Types [infoIndex] |=
1439
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1440
+ combinedInfo.DevicePointers [infoIndex] = devOpType;
1441
+ } else {
1442
+ combinedInfo.BasePointers .emplace_back (mapOpValue);
1443
+ combinedInfo.Pointers .emplace_back (mapOpValue);
1444
+ combinedInfo.DevicePointers .emplace_back (devOpType);
1445
+ combinedInfo.Names .emplace_back (
1446
+ LLVM::createMappingInformation (devOp.getLoc (), *ompBuilder));
1447
+ combinedInfo.Types .emplace_back (
1448
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
1449
+ combinedInfo.Sizes .emplace_back (builder.getInt64 (0 ));
1450
+ }
1451
+ }
1452
+ };
1453
+
1454
+ addDevInfos (devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
1455
+ addDevInfos (devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
1408
1456
}
1409
1457
1410
1458
static LogicalResult
@@ -1413,6 +1461,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
1413
1461
llvm::Value *ifCond = nullptr ;
1414
1462
int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
1415
1463
SmallVector<Value> mapOperands;
1464
+ SmallVector<Value> useDevPtrOperands;
1465
+ SmallVector<Value> useDevAddrOperands;
1416
1466
ArrayAttr mapTypes;
1417
1467
llvm::omp::RuntimeFunction RTLFn;
1418
1468
DataLayout DL = DataLayout (op->getParentOfType <ModuleOp>());
@@ -1422,23 +1472,20 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
1422
1472
LogicalResult result =
1423
1473
llvm::TypeSwitch<Operation *, LogicalResult>(op)
1424
1474
.Case ([&](omp::DataOp dataOp) {
1425
- if (!dataOp.getUseDeviceAddr ().empty () ||
1426
- !dataOp.getUseDevicePtr ().empty ())
1427
- return failure ();
1428
-
1429
1475
if (auto ifExprVar = dataOp.getIfExpr ())
1430
1476
ifCond = moduleTranslation.lookupValue (ifExprVar);
1431
1477
1432
1478
if (auto devId = dataOp.getDevice ())
1433
- if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1434
- devId.getDefiningOp ()))
1435
- if (auto intAttr =
1436
- dyn_cast<mlir::IntegerAttr>(constOp.getValue ()))
1479
+ if (auto constOp =
1480
+ dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp ()))
1481
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
1437
1482
deviceID = intAttr.getInt ();
1438
1483
1439
1484
mapOperands = dataOp.getMapOperands ();
1440
1485
if (dataOp.getMapTypes ())
1441
1486
mapTypes = dataOp.getMapTypes ().value ();
1487
+ useDevPtrOperands = dataOp.getUseDevicePtr ();
1488
+ useDevAddrOperands = dataOp.getUseDeviceAddr ();
1442
1489
return success ();
1443
1490
})
1444
1491
.Case ([&](omp::EnterDataOp enterDataOp) {
@@ -1449,10 +1496,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
1449
1496
ifCond = moduleTranslation.lookupValue (ifExprVar);
1450
1497
1451
1498
if (auto devId = enterDataOp.getDevice ())
1452
- if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1453
- devId.getDefiningOp ()))
1454
- if (auto intAttr =
1455
- dyn_cast<mlir::IntegerAttr>(constOp.getValue ()))
1499
+ if (auto constOp =
1500
+ dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp ()))
1501
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
1456
1502
deviceID = intAttr.getInt ();
1457
1503
RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
1458
1504
mapOperands = enterDataOp.getMapOperands ();
@@ -1467,10 +1513,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
1467
1513
ifCond = moduleTranslation.lookupValue (ifExprVar);
1468
1514
1469
1515
if (auto devId = exitDataOp.getDevice ())
1470
- if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
1471
- devId.getDefiningOp ()))
1472
- if (auto intAttr =
1473
- dyn_cast<mlir::IntegerAttr>(constOp.getValue ()))
1516
+ if (auto constOp =
1517
+ dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp ()))
1518
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
1474
1519
deviceID = intAttr.getInt ();
1475
1520
1476
1521
RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
@@ -1493,38 +1538,68 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
1493
1538
auto genMapInfoCB =
1494
1539
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
1495
1540
builder.restoreIP (codeGenIP);
1496
- genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapOperands,
1497
- mapTypes);
1541
+ if (auto DataOp = dyn_cast<omp::DataOp>(op)) {
1542
+ genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapOperands,
1543
+ mapTypes, useDevPtrOperands, useDevAddrOperands);
1544
+ } else {
1545
+ genMapInfos (builder, moduleTranslation, DL, combinedInfo, mapOperands,
1546
+ mapTypes);
1547
+ }
1498
1548
return combinedInfo;
1499
1549
};
1500
1550
1501
- LogicalResult bodyGenStatus = success ();
1551
+ llvm::OpenMPIRBuilder::TargetDataInfo info (/* RequiresDevicePointerInfo=*/ true ,
1552
+ /* SeparateBeginEndCalls=*/ true );
1553
+
1502
1554
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
1555
+ LogicalResult bodyGenStatus = success ();
1503
1556
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
1557
+ assert (isa<omp::DataOp>(op) && " BodyGen requested for non DataOp" );
1558
+ Region ®ion = cast<omp::DataOp>(op).getRegion ();
1504
1559
switch (bodyGenType) {
1505
1560
case BodyGenTy::Priv:
1561
+ // Check if any device ptr/addr info is available
1562
+ if (!info.DevicePtrInfoMap .empty ()) {
1563
+ builder.restoreIP (codeGenIP);
1564
+ unsigned argIndex = 0 ;
1565
+ for (auto &devPtrOp : useDevPtrOperands) {
1566
+ llvm::Value *mapOpValue = moduleTranslation.lookupValue (devPtrOp);
1567
+ const auto &arg = region.front ().getArgument (argIndex);
1568
+ moduleTranslation.mapValue (arg,
1569
+ info.DevicePtrInfoMap [mapOpValue].second );
1570
+ argIndex++;
1571
+ }
1572
+
1573
+ for (auto &devAddrOp : useDevAddrOperands) {
1574
+ llvm::Value *mapOpValue = moduleTranslation.lookupValue (devAddrOp);
1575
+ const auto &arg = region.front ().getArgument (argIndex);
1576
+ auto *LI = builder.CreateLoad (
1577
+ builder.getPtrTy (), info.DevicePtrInfoMap [mapOpValue].second );
1578
+ moduleTranslation.mapValue (arg, LI);
1579
+ argIndex++;
1580
+ }
1581
+
1582
+ bodyGenStatus = inlineConvertOmpRegions (region, " omp.data.region" ,
1583
+ builder, moduleTranslation);
1584
+ }
1506
1585
break ;
1507
1586
case BodyGenTy::DupNoPriv:
1508
1587
break ;
1509
- case BodyGenTy::NoPriv: {
1510
- // DataOp has only one region associated with it.
1511
- auto ®ion = cast<omp::DataOp>(op).getRegion ();
1512
- builder.restoreIP (codeGenIP);
1513
- bodyGenStatus = inlineConvertOmpRegions (region, " omp.data.region" ,
1514
- builder, moduleTranslation);
1515
- }
1588
+ case BodyGenTy::NoPriv:
1589
+ // If device info is available then region has already been generated
1590
+ if (info.DevicePtrInfoMap .empty ()) {
1591
+ builder.restoreIP (codeGenIP);
1592
+ bodyGenStatus = inlineConvertOmpRegions (region, " omp.data.region" ,
1593
+ builder, moduleTranslation);
1594
+ }
1595
+ break ;
1516
1596
}
1517
1597
return builder.saveIP ();
1518
1598
};
1519
1599
1520
1600
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
1521
1601
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1522
1602
findAllocaInsertPoint (builder, moduleTranslation);
1523
-
1524
- // TODO: Add support for DevicePointerInfo
1525
- llvm::OpenMPIRBuilder::TargetDataInfo info (
1526
- /* RequiresDevicePointerInfo=*/ false ,
1527
- /* SeparateBeginEndCalls=*/ true );
1528
1603
if (isa<omp::DataOp>(op)) {
1529
1604
builder.restoreIP (ompBuilder->createTargetData (
1530
1605
ompLoc, allocaIP, builder.saveIP (), builder.getInt64 (deviceID), ifCond,
@@ -1693,7 +1768,7 @@ convertDeclareTargetAttr(Operation *op,
1693
1768
// lowering while removing functions at the current time.
1694
1769
if (!isDeviceCompilation)
1695
1770
return success ();
1696
-
1771
+
1697
1772
omp::DeclareTargetDeviceType declareType =
1698
1773
declareTargetAttr.getDeviceType ().getValue ();
1699
1774
0 commit comments