Skip to content

Commit 291dc48

Browse files
committed
[Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct
Emit a special leaf constuct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leafs, and composition of constituent constructs into a single compound construct (is possible).
1 parent f725fac commit 291dc48

File tree

7 files changed

+258
-87
lines changed

7 files changed

+258
-87
lines changed

llvm/include/llvm/Frontend/OpenMP/OMP.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,11 @@
1515

1616
#include "llvm/Frontend/OpenMP/OMP.h.inc"
1717

18+
#include "llvm/ADT/ArrayRef.h"
19+
20+
namespace llvm::omp {
21+
ArrayRef<Directive> getLeafConstructs(Directive D);
22+
Directive getCompoundConstruct(ArrayRef<Directive> Parts);
23+
} // namespace llvm::omp
24+
1825
#endif // LLVM_FRONTEND_OPENMP_OMP_H

llvm/lib/Frontend/OpenMP/OMP.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,74 @@
88

99
#include "llvm/Frontend/OpenMP/OMP.h"
1010

11+
#include "llvm/ADT/ArrayRef.h"
12+
#include "llvm/ADT/STLExtras.h"
13+
#include "llvm/ADT/SmallVector.h"
1114
#include "llvm/ADT/StringRef.h"
1215
#include "llvm/ADT/StringSwitch.h"
1316
#include "llvm/Support/ErrorHandling.h"
1417

18+
#include <algorithm>
19+
#include <iterator>
20+
#include <type_traits>
21+
1522
using namespace llvm;
16-
using namespace omp;
23+
using namespace llvm::omp;
1724

1825
#define GEN_DIRECTIVES_IMPL
1926
#include "llvm/Frontend/OpenMP/OMP.inc"
27+
28+
namespace llvm::omp {
29+
ArrayRef<Directive> getLeafConstructs(Directive D) {
30+
auto Idx = static_cast<int>(D);
31+
if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
32+
return {};
33+
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
34+
return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
35+
}
36+
37+
Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
38+
if (Parts.empty())
39+
return OMPD_unknown;
40+
41+
// Parts don't have to be leafs, so expand them into leafs first.
42+
// Store the expanded leafs in the same format as rows in the leaf
43+
// table (generated by tablegen).
44+
SmallVector<Directive> RawLeafs(2);
45+
for (Directive P : Parts) {
46+
ArrayRef<Directive> Ls = getLeafConstructs(P);
47+
if (!Ls.empty())
48+
RawLeafs.append(Ls.begin(), Ls.end());
49+
else
50+
RawLeafs.push_back(P);
51+
}
52+
53+
auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
54+
if (GivenLeafs.size() == 1)
55+
return GivenLeafs.front();
56+
RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
57+
58+
auto Iter = llvm::lower_bound(
59+
LeafConstructTable,
60+
static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
61+
[](const auto *RowA, const auto *RowB) {
62+
const auto *BeginA = &RowA[2];
63+
const auto *EndA = BeginA + static_cast<int>(RowA[1]);
64+
const auto *BeginB = &RowB[2];
65+
const auto *EndB = BeginB + static_cast<int>(RowB[1]);
66+
if (BeginA == EndA && BeginB == EndB)
67+
return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
68+
return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
69+
});
70+
71+
if (Iter == std::end(LeafConstructTable))
72+
return OMPD_unknown;
73+
74+
// Verify that we got a match.
75+
Directive Found = (*Iter)[0];
76+
ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
77+
if (FoundLeafs == GivenLeafs)
78+
return Found;
79+
return OMPD_unknown;
80+
}
81+
} // namespace llvm::omp

llvm/test/TableGen/directive1.td

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> {
5252
// CHECK-EMPTY:
5353
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
5454
// CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h"
55+
// CHECK-NEXT: #include <cstddef>
5556
// CHECK-EMPTY:
5657
// CHECK-NEXT: namespace llvm {
5758
// CHECK-NEXT: class StringRef;
@@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> {
112113
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
113114
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
114115
// CHECK-EMPTY:
115-
// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
116+
// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
116117
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
117118
// CHECK-NEXT: AKind getAKind(StringRef);
118119
// CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind);
@@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> {
359360
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
360361
// IMPL-NEXT: }
361362
// IMPL-EMPTY:
362-
// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
363-
// IMPL-NEXT: switch (Dir) {
364-
// IMPL-NEXT: default:
365-
// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
366-
// IMPL-NEXT: } // switch (Dir)
367-
// IMPL-NEXT: }
368-
// IMPL-EMPTY:
369363
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
370364
// IMPL-NEXT: switch (Dir) {
371365
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> {
374368
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
375369
// IMPL-NEXT: }
376370
// IMPL-EMPTY:
371+
// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
372+
// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
373+
// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
374+
// IMPL-NEXT: };
375+
// IMPL-EMPTY:
376+
// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
377+
// IMPL-NEXT: 0,
378+
// IMPL-NEXT: };
379+
// IMPL-EMPTY:
377380
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL

llvm/test/TableGen/directive2.td

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> {
4545
// CHECK-NEXT: #define LLVM_Tdl_INC
4646
// CHECK-EMPTY:
4747
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
48+
// CHECK-NEXT: #include <cstddef>
4849
// CHECK-EMPTY:
4950
// CHECK-NEXT: namespace llvm {
5051
// CHECK-NEXT: class StringRef;
@@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> {
8889
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
8990
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
9091
// CHECK-EMPTY:
91-
// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
92+
// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
9293
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
9394
// CHECK-NEXT: } // namespace tdl
9495
// CHECK-NEXT: } // namespace llvm
@@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> {
290291
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
291292
// IMPL-NEXT: }
292293
// IMPL-EMPTY:
293-
// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
294-
// IMPL-NEXT: switch (Dir) {
295-
// IMPL-NEXT: default:
296-
// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
297-
// IMPL-NEXT: } // switch (Dir)
298-
// IMPL-NEXT: }
299-
// IMPL-EMPTY:
300294
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
301295
// IMPL-NEXT: switch (Dir) {
302296
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> {
305299
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
306300
// IMPL-NEXT: }
307301
// IMPL-EMPTY:
302+
// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
303+
// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
304+
// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
305+
// IMPL-NEXT: };
306+
// IMPL-EMPTY:
307+
// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
308+
// IMPL-NEXT: 0,
309+
// IMPL-NEXT: };
310+
// IMPL-EMPTY:
308311
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL

llvm/unittests/Frontend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
1414
OpenMPContextTest.cpp
1515
OpenMPIRBuilderTest.cpp
1616
OpenMPParsingTest.cpp
17+
OpenMPComposeTest.cpp
1718

1819
DEPENDS
1920
acc_gen
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/ADT/ArrayRef.h"
10+
#include "llvm/Frontend/OpenMP/OMP.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace llvm;
14+
using namespace llvm::omp;
15+
16+
TEST(Composition, GetLeafConstructs) {
17+
ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
18+
ASSERT_EQ(L1, (ArrayRef<Directive>{}));
19+
ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
20+
ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
21+
ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
22+
ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
23+
}
24+
25+
TEST(Composition, GetCompoundConstruct) {
26+
Directive C1 =
27+
getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
28+
ASSERT_EQ(C1, OMPD_target_teams_distribute);
29+
Directive C2 = getCompoundConstruct({OMPD_target});
30+
ASSERT_EQ(C2, OMPD_target);
31+
Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
32+
ASSERT_EQ(C3, OMPD_unknown);
33+
Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
34+
ASSERT_EQ(C4, OMPD_target_teams_distribute);
35+
Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
36+
ASSERT_EQ(C5, OMPD_target_teams_distribute);
37+
Directive C6 = getCompoundConstruct({});
38+
ASSERT_EQ(C6, OMPD_unknown);
39+
Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
40+
ASSERT_EQ(C7, OMPD_parallel_for_simd);
41+
}

0 commit comments

Comments
 (0)