[flang][OpenMP] Fix construct privatization in default clause #72510

Merged
7 changes: 5 additions & 2 deletions flang/include/flang/Lower/AbstractConverter.h
@@ -134,9 +134,12 @@ class AbstractConverter {
virtual bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) = 0;

/// Collect the set of symbols with \p flag in \p eval
/// region if \p collectSymbols is true. Likewise, collect the
/// region if \p collectSymbols is true. Otherwise, collect the
/// set of the host symbols with \p flag of the associated symbols in \p eval
/// region if collectHostAssociatedSymbols is true.
/// region if collectHostAssociatedSymbols is true. This allows gathering
/// host association details of symbols particularly in nested directives
/// irrespective of \p flag, and can be useful where host
/// association details are needed in a flag-agnostic manner.
virtual void collectSymbolSet(
pft::Evaluation &eval,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet,
2 changes: 1 addition & 1 deletion flang/lib/Lower/Bridge.cpp
@@ -810,7 +810,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool collectSymbol) {
if (collectSymbol && oriSymbol.test(flag))
symbolSet.insert(&oriSymbol);
if (checkHostAssociatedSymbols)
else if (checkHostAssociatedSymbols)
if (const auto *details{
oriSymbol
.detailsIf<Fortran::semantics::HostAssocDetails>()})
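
A minimal, self-contained sketch of the `if` / `else if` change in this hunk, written against toy types rather than the flang API. `ToySymbol`, `ToySymbolSet`, and `collectOne` are hypothetical names. The hunk is cut off before the body of the host-association branch, so the sketch assumes that branch inserts the host symbol when the host carries the flag.

```cpp
#include <set>
#include <string>

// Hypothetical stand-in for Fortran::semantics::Symbol: a name, whether the
// flag of interest is set, and an optional host symbol (host association).
struct ToySymbol {
  std::string name;
  bool hasFlag = false;
  const ToySymbol *host = nullptr;
};

using ToySymbolSet = std::set<const ToySymbol *>;

// Mirrors the shape of the patched logic: the host-association branch is
// only considered when the symbol itself was not collected.
void collectOne(const ToySymbol &sym, ToySymbolSet &out, bool collectSymbols,
                bool collectHostAssociatedSymbols) {
  if (collectSymbols && sym.hasFlag)
    out.insert(&sym);
  // Assumption: the real branch inspects HostAssocDetails; here a plain
  // pointer to the host plays that role.
  else if (collectHostAssociatedSymbols && sym.host && sym.host->hasFlag)
    out.insert(sym.host);
}
```

In this model, with both options enabled, a flagged symbol is collected itself and its host is not added as well. With `collectSymbols` disabled, only the host is considered.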
38 changes: 27 additions & 11 deletions flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -302,21 +302,38 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
}
}

void DataSharingProcessor::collectSymbolsInNestedRegions(
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::Symbol::Flag flag,
llvm::SetVector<const Fortran::semantics::Symbol *>
&symbolsInNestedRegions) {
for (Fortran::lower::pft::Evaluation &nestedEval :
eval.getNestedEvaluations()) {
if (nestedEval.hasNestedEvaluations()) {
if (nestedEval.isConstruct())
// Recursively look for OpenMP constructs within `nestedEval`'s region
Contributor

Do you mean look for non-OpenMP constructs?

Contributor Author

Ideally, we want to capture symbols within nested OpenMP constructs in symbolsInNestedRegions.

In this case, the current evaluation is a non-OpenMP construct (like do). This recursive call is intended to find any OpenMP constructs nested within this outer non-OpenMP construct.
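
The traversal described in this reply can be sketched with a toy tree, independent of the flang PFT classes. `ToyEval` and `collectNestedRegions` are hypothetical names, and the sketch assumes that any nested evaluation with children that is not a (non-OpenMP) construct is an OpenMP directive region. Instead of collecting symbols, it records the label of each region whose symbols would be collected.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for pft::Evaluation.
struct ToyEval {
  std::string label;
  bool isConstruct = false; // true for a non-OpenMP construct such as `do`
  std::vector<ToyEval> nested;
  bool hasNestedEvaluations() const { return !nested.empty(); }
};

// Records the labels of the nested regions whose symbols would go into
// symbolsInNestedRegions.
void collectNestedRegions(const ToyEval &eval, std::vector<std::string> &out) {
  for (const ToyEval &n : eval.nested) {
    if (!n.hasNestedEvaluations())
      continue; // plain statement: nothing to collect here
    if (n.isConstruct)
      collectNestedRegions(n, out); // e.g. a `do` loop: keep descending
    else
      out.push_back(n.label); // assumed to be a nested OpenMP region
  }
}
```

For a tree shaped like the `nested_constructs` test (an outer parallel region containing two nested `do` loops and an inner parallel region), only the inner parallel region is recorded. The `do` loops are descended into but contribute nothing themselves.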

collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions);
else
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/false);
Contributor

Why is collectHostAssociatedSymbols false here? Is it because a previous invocation of collectSymbolSet has already collected it?

Contributor Author

Yes, that is correct.

}
}
}

// Collect symbols to be default privatized in two steps.
// In step 1, collect all symbols in `eval` that match `flag` into
// `defaultSymbols`. In step 2, for nested constructs (if any), if and only if
// the nested construct is an OpenMP construct, collect those nested
Contributor

Do you mean if it is not an OpenMP construct?

Contributor Author

For non-OpenMP constructs, we skip collecting any symbols in symbolsInNestedRegions.

// symbols skipping host associated symbols into `symbolsInNestedRegions`.
// Later, in current context, all symbols in the set
// `defaultSymbols` - `symbolsInNestedRegions` will be privatized.
void DataSharingProcessor::collectSymbols(
Fortran::semantics::Symbol::Flag flag) {
converter.collectSymbolSet(eval, defaultSymbols, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/true);
Comment on lines 333 to 335
Contributor

Can this be folded into collectSymbolsInNestedRegions? I see the value of collectHostAssociatedSymbols is different.

Contributor Author

Actually, that might not work. defaultSymbols collects both the symbols and their host associations (if any), while collectSymbolsInNestedRegions skips host associations. This prevents any duplicate privatization.
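
As a rough illustration of the "`defaultSymbols` - `symbolsInNestedRegions`" rule from the code comment, the sketch below computes that difference over toy symbol pointers. `MiniSymbol`, `MiniSymbolList`, and `symbolsToPrivatize` are hypothetical names. The sketch covers only the nested-region exclusion, not the other conditions checked in `defaultPrivatize`, and it uses an insertion-ordered `std::vector` where the patch uses `llvm::SetVector`.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a semantics symbol; identity is the address.
struct MiniSymbol {
  std::string name;
};

using MiniSymbolList = std::vector<const MiniSymbol *>;

// Returns the symbols the current construct would privatize: everything in
// defaultSymbols that no nested OpenMP region has already claimed.
MiniSymbolList symbolsToPrivatize(const MiniSymbolList &defaultSymbols,
                                  const MiniSymbolList &symbolsInNestedRegions) {
  MiniSymbolList result;
  for (const MiniSymbol *sym : defaultSymbols) {
    const bool claimedByNested =
        std::find(symbolsInNestedRegions.begin(), symbolsInNestedRegions.end(),
                  sym) != symbolsInNestedRegions.end();
    if (!claimedByNested)
      result.push_back(sym);
  }
  return result;
}
```

Because the comparison is by address, a nested region's own `y` and the enclosing region's `y` are distinct entries. Only the entry claimed by the nested region is dropped from the enclosing construct's list.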

for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
if (e.hasNestedEvaluations())
converter.collectSymbolSet(e, symbolsInNestedRegions, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/false);
else
converter.collectSymbolSet(e, symbolsInParentRegions, flag,
/*collectSymbols=*/false,
/*collectHostAssociatedSymbols=*/true);
}
collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
}

void DataSharingProcessor::collectDefaultSymbols() {
@@ -367,7 +384,6 @@ void DataSharingProcessor::defaultPrivatize(
!sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
!Fortran::semantics::IsImpliedDoIndex(sym->GetUltimate()) &&
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
doPrivatize(sym, clauseOps, privateSyms);
}
6 changes: 5 additions & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -40,7 +40,6 @@ class DataSharingProcessor {
llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
llvm::DenseMap<const Fortran::semantics::Symbol *, mlir::omp::PrivateClauseOp>
symToPrivatizer;
Fortran::lower::AbstractConverter &converter;
@@ -52,6 +51,11 @@ class DataSharingProcessor {

bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
void collectSymbolsInNestedRegions(
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::Symbol::Flag flag,
llvm::SetVector<const Fortran::semantics::Symbol *>
&symbolsInNestedRegions);
void collectOmpObjectListSymbol(
const omp::ObjectList &objects,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -226,8 +226,6 @@ subroutine nested_default_clause_tests
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -242,12 +240,14 @@ subroutine nested_default_clause_tests
!CHECK: omp.terminator
!CHECK: }
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
55 changes: 51 additions & 4 deletions flang/test/Lower/OpenMP/default-clause.f90
@@ -149,7 +149,7 @@ program default_clause_lowering
end program default_clause_lowering

subroutine nested_default_clause_tests
integer :: x, y, z, w, k, a
integer :: x, y, z, w, k
!CHECK: %[[K:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFnested_default_clause_testsEk"}
!CHECK: %[[K_DECL:.*]]:2 = hlfir.declare %[[K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[W:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFnested_default_clause_testsEw"}
@@ -221,13 +221,12 @@ subroutine nested_default_clause_tests


!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -242,12 +241,14 @@ subroutine nested_default_clause_tests
!CHECK: omp.terminator
!CHECK: }
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
@@ -415,3 +416,49 @@ subroutine threadprivate_with_default
end do
!$omp end parallel do
end subroutine

subroutine nested_constructs
!CHECK: %[[I:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFnested_constructsEi"}
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[I]] {{.*}}
!CHECK: %[[J:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFnested_constructsEj"}
!CHECK: %[[J_DECL:.*]]:2 = hlfir.declare %[[J]] {{.*}}
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {{.*}}
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_constructsEz"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {{.*}}

integer :: y, z
!CHECK: omp.parallel {
!CHECK: %[[INNER_J:.*]] = fir.alloca i32 {bindc_name = "j", pinned}
!CHECK: %[[INNER_J_DECL:.*]]:2 = hlfir.declare %[[INNER_J]] {{.*}}
!CHECK: %[[INNER_I:.*]] = fir.alloca i32 {bindc_name = "i", pinned}
!CHECK: %[[INNER_I_DECL:.*]]:2 = hlfir.declare %[[INNER_I]] {{.*}}
!CHECK: %[[INNER_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[INNER_Y_DECL:.*]]:2 = hlfir.declare %[[INNER_Y]] {{.*}}
!CHECK: %[[TEMP:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: hlfir.assign %[[TEMP]] to %[[INNER_Y_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
!CHECK: %[[INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_constructsEz"}
!CHECK: %[[INNER_Z_DECL:.*]]:2 = hlfir.declare %[[INNER_Z]] {{.*}}
!$omp parallel default(private) firstprivate(y)
!CHECK: {{.*}} = fir.do_loop {{.*}} {
do i = 1, 10
!CHECK: %[[CONST_1:.*]] = arith.constant 1 : i32
!CHECK: hlfir.assign %[[CONST_1]] to %[[INNER_Y_DECL]]#0 : i32, !fir.ref<i32>
y = 1
!CHECK: {{.*}} = fir.do_loop {{.*}} {
do j = 1, 10
!CHECK: %[[CONST_20:.*]] = arith.constant 20 : i32
!CHECK: hlfir.assign %[[CONST_20]] to %[[INNER_Z_DECL]]#0 : i32, !fir.ref<i32>
z = 20
!CHECK: omp.parallel {
!CHECK: %[[NESTED_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[NESTED_Y_DECL:.*]]:2 = hlfir.declare %[[NESTED_Y]] {{.*}}
!CHECK: %[[CONST_2:.*]] = arith.constant 2 : i32
!CHECK: hlfir.assign %[[CONST_2]] to %[[NESTED_Y_DECL]]#0 : i32, !fir.ref<i32>
!$omp parallel default(private)
y = 2
!$omp end parallel
end do
end do
!$omp end parallel
end subroutine