|
22 | 22 | #include "llvm/ADT/StringExtras.h"
|
23 | 23 | #include "llvm/ADT/StringRef.h"
|
24 | 24 | #include "llvm/ADT/TypeSwitch.h"
|
| 25 | +#include "llvm/Frontend/OpenMP/OMPConstants.h" |
25 | 26 | #include <cstddef>
|
26 | 27 |
|
27 | 28 | #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
|
@@ -552,6 +553,191 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
|
552 | 553 | return success();
|
553 | 554 | }
|
554 | 555 |
|
| 556 | +//===----------------------------------------------------------------------===// |
| 557 | +// Parser, printer and verifier for Target Data |
| 558 | +//===----------------------------------------------------------------------===// |
| 559 | +/// Parses a Map Clause. |
| 560 | +/// |
| 561 | +/// map-clause = `map (` ( `(` `always, `? `close, `? `present, `? ( `to` | |
| 562 | +/// `from` | `delete` ) ` -> ` symbol-ref ` : ` type(symbol-ref) `)` )+ `)` |
| 563 | +/// Eg: map((release -> %1 : !llvm.ptr<array<1024 x i32>>), (always, close, from |
| 564 | +/// -> %2 : !llvm.ptr<array<1024 x i32>>)) |
| 565 | +static ParseResult |
| 566 | +parseMapClause(OpAsmParser &parser, |
| 567 | + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &map_operands, |
| 568 | + SmallVectorImpl<Type> &map_operand_types, ArrayAttr &map_types) { |
| 569 | + StringRef mapTypeMod; |
| 570 | + OpAsmParser::UnresolvedOperand arg1; |
| 571 | + Type arg1Type; |
| 572 | + IntegerAttr arg2; |
| 573 | + SmallVector<IntegerAttr> mapTypesVec; |
| 574 | + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits; |
| 575 | + |
| 576 | + auto parseTypeAndMod = [&]() -> ParseResult { |
| 577 | + if (parser.parseKeyword(&mapTypeMod)) |
| 578 | + return failure(); |
| 579 | + |
| 580 | + if (mapTypeMod == "always") |
| 581 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; |
| 582 | + if (mapTypeMod == "close") |
| 583 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; |
| 584 | + if (mapTypeMod == "present") |
| 585 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; |
| 586 | + |
| 587 | + if (mapTypeMod == "to") |
| 588 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; |
| 589 | + if (mapTypeMod == "from") |
| 590 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
| 591 | + if (mapTypeMod == "tofrom") |
| 592 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | |
| 593 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
| 594 | + if (mapTypeMod == "delete") |
| 595 | + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; |
| 596 | + return success(); |
| 597 | + }; |
| 598 | + |
| 599 | + auto parseMap = [&]() -> ParseResult { |
| 600 | + mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; |
| 601 | + |
| 602 | + if (parser.parseLParen() || |
| 603 | + parser.parseCommaSeparatedList(parseTypeAndMod) || |
| 604 | + parser.parseArrow() || parser.parseOperand(arg1) || |
| 605 | + parser.parseColon() || parser.parseType(arg1Type) || |
| 606 | + parser.parseRParen()) |
| 607 | + return failure(); |
| 608 | + map_operands.push_back(arg1); |
| 609 | + map_operand_types.push_back(arg1Type); |
| 610 | + arg2 = parser.getBuilder().getIntegerAttr( |
| 611 | + parser.getBuilder().getI64Type(), |
| 612 | + static_cast< |
| 613 | + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| 614 | + mapTypeBits)); |
| 615 | + mapTypesVec.push_back(arg2); |
| 616 | + return success(); |
| 617 | + }; |
| 618 | + |
| 619 | + if (parser.parseCommaSeparatedList(parseMap)) |
| 620 | + return failure(); |
| 621 | + |
| 622 | + SmallVector<Attribute> mapTypesAttr(mapTypesVec.begin(), mapTypesVec.end()); |
| 623 | + map_types = ArrayAttr::get(parser.getContext(), mapTypesAttr); |
| 624 | + return success(); |
| 625 | +} |
| 626 | + |
| 627 | +static void printMapClause(OpAsmPrinter &p, Operation *op, |
| 628 | + OperandRange map_operands, |
| 629 | + TypeRange map_operand_types, ArrayAttr map_types) { |
| 630 | + |
| 631 | + // Helper function to get bitwise AND of `value` and 'flag' |
| 632 | + auto bitAnd = [](int64_t value, |
| 633 | + llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { |
| 634 | + return value & |
| 635 | + static_cast< |
| 636 | + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| 637 | + flag); |
| 638 | + }; |
| 639 | + |
| 640 | + assert(map_operands.size() == map_types.size()); |
| 641 | + |
| 642 | + for (unsigned i = 0, e = map_operands.size(); i < e; i++) { |
| 643 | + int64_t mapTypeBits = 0x00; |
| 644 | + Value mapOp = map_operands[i]; |
| 645 | + Attribute mapTypeOp = map_types[i]; |
| 646 | + |
| 647 | + assert(mapTypeOp.isa<mlir::IntegerAttr>()); |
| 648 | + mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt(); |
| 649 | + |
| 650 | + bool always = bitAnd(mapTypeBits, |
| 651 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); |
| 652 | + bool close = bitAnd(mapTypeBits, |
| 653 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); |
| 654 | + bool present = bitAnd( |
| 655 | + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT); |
| 656 | + |
| 657 | + bool to = |
| 658 | + bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); |
| 659 | + bool from = |
| 660 | + bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); |
| 661 | + bool del = bitAnd(mapTypeBits, |
| 662 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); |
| 663 | + |
| 664 | + std::string typeModStr, typeStr; |
| 665 | + llvm::raw_string_ostream typeMod(typeModStr), type(typeStr); |
| 666 | + |
| 667 | + if (always) |
| 668 | + typeMod << "always, "; |
| 669 | + if (close) |
| 670 | + typeMod << "close, "; |
| 671 | + if (present) |
| 672 | + typeMod << "present, "; |
| 673 | + |
| 674 | + if (to) |
| 675 | + type << "to"; |
| 676 | + if (from) |
| 677 | + type << "from"; |
| 678 | + if (del) |
| 679 | + type << "delete"; |
| 680 | + if (type.str().empty()) |
| 681 | + type << (isa<ExitDataOp>(op) ? "release" : "alloc"); |
| 682 | + |
| 683 | + p << '(' << typeMod.str() << type.str() << " -> " << mapOp << " : " |
| 684 | + << mapOp.getType() << ')'; |
| 685 | + if (i + 1 < e) |
| 686 | + p << ", "; |
| 687 | + } |
| 688 | +} |
| 689 | + |
| 690 | +static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands, |
| 691 | + ArrayAttr map_types) { |
| 692 | + // Helper function to get bitwise AND of `value` and 'flag' |
| 693 | + auto bitAnd = [](int64_t value, |
| 694 | + llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { |
| 695 | + return value & |
| 696 | + static_cast< |
| 697 | + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| 698 | + flag); |
| 699 | + }; |
| 700 | + if (map_operands.size() != map_types.size()) |
| 701 | + return failure(); |
| 702 | + |
| 703 | + for (const auto &mapTypeOp : map_types) { |
| 704 | + int64_t mapTypeBits = 0x00; |
| 705 | + |
| 706 | + if (!mapTypeOp.isa<mlir::IntegerAttr>()) |
| 707 | + return failure(); |
| 708 | + |
| 709 | + mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt(); |
| 710 | + |
| 711 | + bool to = |
| 712 | + bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); |
| 713 | + bool from = |
| 714 | + bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); |
| 715 | + bool del = bitAnd(mapTypeBits, |
| 716 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); |
| 717 | + |
| 718 | + if (isa<DataOp>(op) && del) |
| 719 | + return failure(); |
| 720 | + if (isa<EnterDataOp>(op) && (from || del)) |
| 721 | + return failure(); |
| 722 | + if (isa<ExitDataOp>(op) && to) |
| 723 | + return failure(); |
| 724 | + } |
| 725 | + |
| 726 | + return success(); |
| 727 | +} |
| 728 | + |
| 729 | +LogicalResult DataOp::verify() { |
| 730 | + return verifyMapClause(*this, getMapOperands(), getMapTypes()); |
| 731 | +} |
| 732 | + |
| 733 | +LogicalResult EnterDataOp::verify() { |
| 734 | + return verifyMapClause(*this, getMapOperands(), getMapTypes()); |
| 735 | +} |
| 736 | + |
| 737 | +LogicalResult ExitDataOp::verify() { |
| 738 | + return verifyMapClause(*this, getMapOperands(), getMapTypes()); |
| 739 | +} |
| 740 | + |
555 | 741 | //===----------------------------------------------------------------------===//
|
556 | 742 | // ParallelOp
|
557 | 743 | //===----------------------------------------------------------------------===//
|
|
0 commit comments