Skip to content

Commit a4699a4

Browse files
committed
[MLIR][OpenMP] Added target data, exit data, and enter data operation definition for MLIR
This includes a basic implementation for the OpenMP 5.1 Target Data, Target Exit Data and Target Enter Data constructs operation. TODO: - Depend clause support for Target Enter and Exit Data. - Mapper and Iterator value support for Map Type Modifiers. - Verifier for the operations. Co-authored-by: abidmalikwaterloo <[email protected]> Co-authored-by: raghavendra <[email protected]> Differential Revision: https://reviews.llvm.org/D131915
1 parent 0c69cb2 commit a4699a4

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,158 @@ def FlushOp : OpenMP_Op<"flush"> {
817817
}
818818
}];
819819
}
820+
821+
//===---------------------------------------------------------------------===//
822+
// 2.14.2 target data Construct
823+
//===---------------------------------------------------------------------===//
824+
825+
def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{
826+
let summary = "target data construct";
827+
let description = [{
828+
Map variables to a device data environment for the extent of the region.
829+
830+
The omp target data directive maps variables to a device data
831+
environment, and defines the lexical scope of the data environment
832+
that is created. The omp target data directive can reduce data copies
833+
to and from the offloading device when multiple target regions are using
834+
the same data.
835+
836+
The optional $if_expr parameter specifies a boolean result of a
837+
conditional check. If this value is 1 or is not provided then the target
838+
region runs on a device, if it is 0 then the target region is executed
839+
on the host device.
840+
841+
The optional $device parameter specifies the device number for the target
842+
region.
843+
844+
The optional $use_device_ptr specifies the device pointers to the
845+
corresponding list items in the device data environment.
846+
847+
The optional $use_device_addr specifies the address of the objects in the
848+
device data enviornment.
849+
850+
The $map_operands specifies the locator-list operands of the map clause.
851+
852+
The $map_types specifies the types and modifiers for the map clause.
853+
854+
TODO: depend clause and map_type_modifier values iterator and mapper.
855+
}];
856+
857+
let arguments = (ins Optional<I1>:$if_expr,
858+
Optional<AnyInteger>:$device,
859+
Variadic<AnyType>:$use_device_ptr,
860+
Variadic<AnyType>:$use_device_addr,
861+
Variadic<AnyType>:$map_operands,
862+
I64ArrayAttr:$map_types);
863+
864+
let regions = (region AnyRegion:$region);
865+
866+
let assemblyFormat = [{
867+
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
868+
| `device` `(` $device `:` type($device) `)`
869+
| `use_device_ptr` `(` $use_device_ptr `:` type($use_device_ptr) `)`
870+
| `use_device_addr` `(` $use_device_addr `:` type($use_device_addr) `)`)
871+
`map` `(` custom<MapClause>($map_operands, type($map_operands), $map_types) `)`
872+
$region attr-dict
873+
}];
874+
875+
let hasVerifier = 1;
876+
}
877+
878+
//===---------------------------------------------------------------------===//
879+
// 2.14.3 target enter data Construct
880+
//===---------------------------------------------------------------------===//
881+
882+
def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
883+
[AttrSizedOperandSegments]>{
884+
let summary = "target enter data construct";
885+
let description = [{
886+
The target enter data directive specifies that variables are mapped to
887+
a device data environment. The target enter data directive is a
888+
stand-alone directive.
889+
890+
The optional $if_expr parameter specifies a boolean result of a
891+
conditional check. If this value is 1 or is not provided then the target
892+
region runs on a device, if it is 0 then the target region is executed on
893+
the host device.
894+
895+
The optional $device parameter specifies the device number for the
896+
target region.
897+
898+
The optional $nowait eliminates the implicit barrier so the parent task
899+
can make progress even if the target task is not yet completed.
900+
901+
The $map_operands specifies the locator-list operands of the map clause.
902+
903+
The $map_types specifies the types and modifiers for the map clause.
904+
905+
TODO: depend clause and map_type_modifier values iterator and mapper.
906+
}];
907+
908+
let arguments = (ins Optional<I1>:$if_expr,
909+
Optional<AnyInteger>:$device,
910+
UnitAttr:$nowait,
911+
Variadic<AnyType>:$map_operands,
912+
I64ArrayAttr:$map_types);
913+
914+
let assemblyFormat = [{
915+
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
916+
| `device` `(` $device `:` type($device) `)`
917+
| `nowait` $nowait)
918+
`map` `(` custom<MapClause>($map_operands, type($map_operands), $map_types) `)`
919+
attr-dict
920+
}];
921+
922+
let hasVerifier = 1;
923+
}
924+
925+
//===---------------------------------------------------------------------===//
926+
// 2.14.4 target exit data Construct
927+
//===---------------------------------------------------------------------===//
928+
929+
def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
930+
[AttrSizedOperandSegments]>{
931+
let summary = "target exit data construct";
932+
let description = [{
933+
The target exit data directive specifies that variables are mapped to a
934+
device data environment. The target exit data directive is
935+
a stand-alone directive.
936+
937+
The optional $if_expr parameter specifies a boolean result of a
938+
conditional check. If this value is 1 or is not provided then the target
939+
region runs on a device, if it is 0 then the target region is executed
940+
on the host device.
941+
942+
The optional $device parameter specifies the device number for the
943+
target region.
944+
945+
The optional $nowait eliminates the implicit barrier so the parent
946+
task can make progress even if the target task is not yet completed.
947+
948+
The $map_operands specifies the locator-list operands of the map clause.
949+
950+
The $map_types specifies the types and modifiers for the map clause.
951+
952+
TODO: depend clause and map_type_modifier values iterator and mapper.
953+
}];
954+
955+
let arguments = (ins Optional<I1>:$if_expr,
956+
Optional<AnyInteger>:$device,
957+
UnitAttr:$nowait,
958+
Variadic<AnyType>:$map_operands,
959+
I64ArrayAttr:$map_types);
960+
961+
let assemblyFormat = [{
962+
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
963+
| `device` `(` $device `:` type($device) `)`
964+
| `nowait` $nowait)
965+
`map` `(` custom<MapClause>($map_operands, type($map_operands), $map_types) `)`
966+
attr-dict
967+
}];
968+
969+
let hasVerifier = 1;
970+
}
971+
820972
//===----------------------------------------------------------------------===//
821973
// 2.14.5 target construct
822974
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/StringExtras.h"
2323
#include "llvm/ADT/StringRef.h"
2424
#include "llvm/ADT/TypeSwitch.h"
25+
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2526
#include <cstddef>
2627

2728
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
@@ -552,6 +553,191 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
552553
return success();
553554
}
554555

556+
//===----------------------------------------------------------------------===//
557+
// Parser, printer and verifier for Target Data
558+
//===----------------------------------------------------------------------===//
559+
/// Parses a Map Clause.
560+
///
561+
/// map-clause = `map (` ( `(` `always, `? `close, `? `present, `? ( `to` |
562+
/// `from` | `delete` ) ` -> ` symbol-ref ` : ` type(symbol-ref) `)` )+ `)`
563+
/// Eg: map((release -> %1 : !llvm.ptr<array<1024 x i32>>), (always, close, from
564+
/// -> %2 : !llvm.ptr<array<1024 x i32>>))
565+
static ParseResult
566+
parseMapClause(OpAsmParser &parser,
567+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &map_operands,
568+
SmallVectorImpl<Type> &map_operand_types, ArrayAttr &map_types) {
569+
StringRef mapTypeMod;
570+
OpAsmParser::UnresolvedOperand arg1;
571+
Type arg1Type;
572+
IntegerAttr arg2;
573+
SmallVector<IntegerAttr> mapTypesVec;
574+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits;
575+
576+
auto parseTypeAndMod = [&]() -> ParseResult {
577+
if (parser.parseKeyword(&mapTypeMod))
578+
return failure();
579+
580+
if (mapTypeMod == "always")
581+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
582+
if (mapTypeMod == "close")
583+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
584+
if (mapTypeMod == "present")
585+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
586+
587+
if (mapTypeMod == "to")
588+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
589+
if (mapTypeMod == "from")
590+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
591+
if (mapTypeMod == "tofrom")
592+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
593+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
594+
if (mapTypeMod == "delete")
595+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
596+
return success();
597+
};
598+
599+
auto parseMap = [&]() -> ParseResult {
600+
mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
601+
602+
if (parser.parseLParen() ||
603+
parser.parseCommaSeparatedList(parseTypeAndMod) ||
604+
parser.parseArrow() || parser.parseOperand(arg1) ||
605+
parser.parseColon() || parser.parseType(arg1Type) ||
606+
parser.parseRParen())
607+
return failure();
608+
map_operands.push_back(arg1);
609+
map_operand_types.push_back(arg1Type);
610+
arg2 = parser.getBuilder().getIntegerAttr(
611+
parser.getBuilder().getI64Type(),
612+
static_cast<
613+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
614+
mapTypeBits));
615+
mapTypesVec.push_back(arg2);
616+
return success();
617+
};
618+
619+
if (parser.parseCommaSeparatedList(parseMap))
620+
return failure();
621+
622+
SmallVector<Attribute> mapTypesAttr(mapTypesVec.begin(), mapTypesVec.end());
623+
map_types = ArrayAttr::get(parser.getContext(), mapTypesAttr);
624+
return success();
625+
}
626+
627+
static void printMapClause(OpAsmPrinter &p, Operation *op,
628+
OperandRange map_operands,
629+
TypeRange map_operand_types, ArrayAttr map_types) {
630+
631+
// Helper function to get bitwise AND of `value` and 'flag'
632+
auto bitAnd = [](int64_t value,
633+
llvm::omp::OpenMPOffloadMappingFlags flag) -> bool {
634+
return value &
635+
static_cast<
636+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
637+
flag);
638+
};
639+
640+
assert(map_operands.size() == map_types.size());
641+
642+
for (unsigned i = 0, e = map_operands.size(); i < e; i++) {
643+
int64_t mapTypeBits = 0x00;
644+
Value mapOp = map_operands[i];
645+
Attribute mapTypeOp = map_types[i];
646+
647+
assert(mapTypeOp.isa<mlir::IntegerAttr>());
648+
mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
649+
650+
bool always = bitAnd(mapTypeBits,
651+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
652+
bool close = bitAnd(mapTypeBits,
653+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
654+
bool present = bitAnd(
655+
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
656+
657+
bool to =
658+
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
659+
bool from =
660+
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
661+
bool del = bitAnd(mapTypeBits,
662+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
663+
664+
std::string typeModStr, typeStr;
665+
llvm::raw_string_ostream typeMod(typeModStr), type(typeStr);
666+
667+
if (always)
668+
typeMod << "always, ";
669+
if (close)
670+
typeMod << "close, ";
671+
if (present)
672+
typeMod << "present, ";
673+
674+
if (to)
675+
type << "to";
676+
if (from)
677+
type << "from";
678+
if (del)
679+
type << "delete";
680+
if (type.str().empty())
681+
type << (isa<ExitDataOp>(op) ? "release" : "alloc");
682+
683+
p << '(' << typeMod.str() << type.str() << " -> " << mapOp << " : "
684+
<< mapOp.getType() << ')';
685+
if (i + 1 < e)
686+
p << ", ";
687+
}
688+
}
689+
690+
static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
691+
ArrayAttr map_types) {
692+
// Helper function to get bitwise AND of `value` and 'flag'
693+
auto bitAnd = [](int64_t value,
694+
llvm::omp::OpenMPOffloadMappingFlags flag) -> bool {
695+
return value &
696+
static_cast<
697+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
698+
flag);
699+
};
700+
if (map_operands.size() != map_types.size())
701+
return failure();
702+
703+
for (const auto &mapTypeOp : map_types) {
704+
int64_t mapTypeBits = 0x00;
705+
706+
if (!mapTypeOp.isa<mlir::IntegerAttr>())
707+
return failure();
708+
709+
mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
710+
711+
bool to =
712+
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
713+
bool from =
714+
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
715+
bool del = bitAnd(mapTypeBits,
716+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
717+
718+
if (isa<DataOp>(op) && del)
719+
return failure();
720+
if (isa<EnterDataOp>(op) && (from || del))
721+
return failure();
722+
if (isa<ExitDataOp>(op) && to)
723+
return failure();
724+
}
725+
726+
return success();
727+
}
728+
729+
LogicalResult DataOp::verify() {
730+
return verifyMapClause(*this, getMapOperands(), getMapTypes());
731+
}
732+
733+
LogicalResult EnterDataOp::verify() {
734+
return verifyMapClause(*this, getMapOperands(), getMapTypes());
735+
}
736+
737+
LogicalResult ExitDataOp::verify() {
738+
return verifyMapClause(*this, getMapOperands(), getMapTypes());
739+
}
740+
555741
//===----------------------------------------------------------------------===//
556742
// ParallelOp
557743
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,26 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> ()
495495
return
496496
}
497497

498+
// CHECK-LABEL: omp_target_data
499+
func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
500+
// CHECK: omp.target_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) map((always, from -> %[[VAL_2:.*]] : memref<?xi32>))
501+
omp.target_data if(%if_cond : i1) device(%device : si32) map((always, from -> %map1 : memref<?xi32>)){}
502+
503+
// CHECK: omp.target_data use_device_ptr(%[[VAL_3:.*]] : memref<i32>) use_device_addr(%[[VAL_4:.*]] : memref<?xi32>) map((close, present, to -> %[[VAL_2:.*]] : memref<?xi32>))
504+
omp.target_data use_device_ptr(%device_ptr : memref<i32>) use_device_addr(%device_addr : memref<?xi32>) map((close, present, to -> %map1 : memref<?xi32>)){}
505+
506+
// CHECK: omp.target_data map((tofrom -> %[[VAL_2]] : memref<?xi32>), (alloc -> %[[VAL_5:.*]] : memref<?xi32>))
507+
omp.target_data map((tofrom -> %map1 : memref<?xi32>), (alloc -> %map2 : memref<?xi32>)){}
508+
509+
// CHECK: omp.target_enter_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((alloc -> %[[VAL_2]] : memref<?xi32>))
510+
omp.target_enter_data if(%if_cond : i1) device(%device : si32) nowait map((alloc -> %map1 : memref<?xi32>))
511+
512+
// CHECK: omp.target_exit_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((release -> %[[VAL_5]] : memref<?xi32>))
513+
omp.target_exit_data if(%if_cond : i1) device(%device : si32) nowait map((release -> %map2 : memref<?xi32>))
514+
515+
return
516+
}
517+
498518
// CHECK-LABEL: omp_target_pretty
499519
func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
500520
// CHECK: omp.target if({{.*}}) device({{.*}})

0 commit comments

Comments
 (0)