Skip to content

[flang][OpenMP] Add frontend support for ompx_bare clause #111106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3474,6 +3474,16 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
Clause = ParseOpenMPOMPXAttributesClause(WrongDirective);
break;
case OMPC_ompx_bare:
if (DKind == llvm::omp::Directive::OMPD_target) {
// Flang splits the combined directives which requires OMPD_target to be
// marked as accepting the `ompx_bare` clause in `OMP.td`. Thus, we need
// to explicitly check whether this clause is applied to an `omp target`
// without `teams` and emit an error.
Diag(Tok, diag::err_omp_unexpected_clause)
<< getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
ErrorFound = true;
WrongDirective = true;
}
if (WrongDirective)
Diag(Tok, diag::note_ompx_bare_clause)
<< getOpenMPClauseName(CKind) << "target teams";
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//

bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
}

bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ClauseProcessor {
: converter(converter), semaCtx(semaCtx), clauses(clauses) {}

// 'Unique' clauses: They can appear at most once in the clause list.
bool processBare(mlir::omp::BareClauseOps &result) const;
bool processBind(mlir::omp::BindClauseOps &result) const;
bool
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@ static void genTargetClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDepend(clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms);
Expand Down Expand Up @@ -2860,6 +2861,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Nowait>(clause.u) &&
!std::holds_alternative<clause::NumTeams>(clause.u) &&
!std::holds_alternative<clause::NumThreads>(clause.u) &&
!std::holds_alternative<clause::OmpxBare>(clause.u) &&
!std::holds_alternative<clause::Priority>(clause.u) &&
!std::holds_alternative<clause::Private>(clause.u) &&
!std::holds_alternative<clause::ProcBind>(clause.u) &&
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ TYPE_PARSER(
parenthesized(scalarIntExpr))) ||
"NUM_THREADS" >> construct<OmpClause>(construct<OmpClause::NumThreads>(
parenthesized(scalarIntExpr))) ||
"OMPX_BARE" >> construct<OmpClause>(construct<OmpClause::OmpxBare>()) ||
"ORDER" >> construct<OmpClause>(construct<OmpClause::Order>(
parenthesized(Parser<OmpOrderClause>{}))) ||
"ORDERED" >> construct<OmpClause>(construct<OmpClause::Ordered>(
Expand Down
12 changes: 11 additions & 1 deletion flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2886,7 +2886,6 @@ CHECK_SIMPLE_CLAUSE(Align, OMPC_align)
CHECK_SIMPLE_CLAUSE(Compare, OMPC_compare)
CHECK_SIMPLE_CLAUSE(CancellationConstructType, OMPC_cancellation_construct_type)
CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
CHECK_SIMPLE_CLAUSE(Weak, OMPC_weak)

CHECK_REQ_SCALAR_INT_CLAUSE(NumTeams, OMPC_num_teams)
Expand Down Expand Up @@ -4411,6 +4410,17 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
}
}

void OmpStructureChecker::Enter(const parser::OmpClause::OmpxBare &x) {
// Don't call CheckAllowedClause, because it allows "ompx_bare" on
// a non-combined "target" directive (for reasons of splitting combined
// directives). In source code it's only allowed on "target teams".
if (GetContext().directive != llvm::omp::Directive::OMPD_target_teams) {
context_.Say(GetContext().clauseSource,
"%s clause is only allowed on combined TARGET TEAMS"_err_en_US,
parser::ToUpperCaseLetters(getClauseName(llvm::omp::OMPC_ompx_bare)));
}
}

llvm::StringRef OmpStructureChecker::getClauseName(llvm::omp::Clause clause) {
return llvm::omp::getOpenMPClauseName(clause);
}
Expand Down
10 changes: 10 additions & 0 deletions flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s

program test
integer :: tmp
!$omp target teams ompx_bare num_teams(42) thread_limit(43)
tmp = 1
!$omp end target teams
end program

! CHECK: omp.target ompx_bare
30 changes: 30 additions & 0 deletions flang/test/Semantics/OpenMP/ompx-bare.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
!RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=51

subroutine test1
!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
!$omp target ompx_bare
!$omp end target
end

subroutine test2
!$omp target
!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
!$omp teams ompx_bare
!$omp end teams
!$omp end target
end

subroutine test3
integer i
!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
!$omp target teams distribute ompx_bare
do i = 0, 10
end do
!$omp end target teams distribute
end

subroutine test4
!No errors
!$omp target teams ompx_bare
!$omp end target teams
end
9 changes: 9 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ struct ConstructDecompositionT {
bool
applyClause(const tomp::clause::OmpxAttributeT<TypeTy, IdTy, ExprTy> &clause,
const ClauseTy *);
bool applyClause(const tomp::clause::OmpxBareT<TypeTy, IdTy, ExprTy> &clause,
const ClauseTy *);

uint32_t version;
llvm::omp::Directive construct;
Expand Down Expand Up @@ -1103,6 +1105,13 @@ bool ConstructDecompositionT<C, H>::applyClause(
return applyToOutermost(node);
}

template <typename C, typename H>
bool ConstructDecompositionT<C, H>::applyClause(
const tomp::clause::OmpxBareT<TypeTy, IdTy, ExprTy> &clause,
const ClauseTy *node) {
return applyToOutermost(node);
}

template <typename C, typename H>
bool ConstructDecompositionT<C, H>::applyClause(
const tomp::clause::OmpxAttributeT<TypeTy, IdTy, ExprTy> &clause,
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMP.td
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def OMP_Target : Directive<"target"> {
VersionedClause<OMPC_Device>,
VersionedClause<OMPC_If>,
VersionedClause<OMPC_NoWait>,
VersionedClause<OMPC_OMPX_Bare>,
VersionedClause<OMPC_OMPX_DynCGroupMem>,
VersionedClause<OMPC_ThreadLimit, 51>,
];
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ class OpenMP_AllocateClauseSkip<

def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;

//===----------------------------------------------------------------------===//
// LLVM OpenMP extension `ompx_bare` clause
//===----------------------------------------------------------------------===//

class OpenMP_BareClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
UnitAttr:$bare
);

let optAssemblyFormat = [{
`ompx_bare` $bare
}];

let description = [{
`ompx_bare` allows `omp target teams` to be executed on a GPU with an
explicit number of teams and threads. This clause also allows the teams and
threads sizes to have up to 3 dimensions.
}];
}

def OpenMP_BareClause : OpenMP_BareClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [16.1, 16.2] `cancel-directive-name` clause set
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 5 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1215,10 +1215,11 @@ def TargetOp : OpenMP_Op<"target", traits = [
OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
OpenMP_NowaitClause, OpenMP_PrivateClause, OpenMP_ThreadLimitClause
OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_IfClause,
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
OpenMP_PrivateClause, OpenMP_ThreadLimitClause,
], singleRegion = true> {
let summary = "target construct";
let description = [{
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1709,13 +1709,13 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
// inReductionByref, inReductionSyms.
TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
/*private_maps=*/nullptr);
clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
clauses.ifExpr, /*in_reduction_vars=*/{},
/*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr,
clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
clauses.threadLimit, /*private_maps=*/nullptr);
}

LogicalResult TargetOp::verify() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
result = todo("allocate");
};
auto checkBare = [&todo](auto op, LogicalResult &result) {
if (op.getBare())
result = todo("ompx_bare");
};
auto checkDepend = [&todo](auto op, LogicalResult &result) {
if (!op.getDependVars().empty() || op.getDependKinds())
result = todo("depend");
Expand Down Expand Up @@ -283,6 +287,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
[&](auto op) { checkDepend(op, result); })
.Case([&](omp::TargetOp op) {
checkAllocate(op, result);
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
checkIf(op, result);
Expand Down
Loading