Skip to content

Commit 2fec998

Browse files
committed
[Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct
Emit a special leaf constuct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leafs, and composition of constituent constructs into a single compound construct (is possible).
1 parent f725fac commit 2fec998

File tree

5 files changed

+235
-71
lines changed

5 files changed

+235
-71
lines changed

llvm/include/llvm/Frontend/OpenMP/OMP.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,11 @@
1515

1616
#include "llvm/Frontend/OpenMP/OMP.h.inc"
1717

18+
#include "llvm/ADT/ArrayRef.h"
19+
20+
namespace llvm::omp {
21+
ArrayRef<Directive> getLeafConstructs(Directive D);
22+
Directive getCompoundConstruct(ArrayRef<Directive> Parts);
23+
} // namespace llvm::omp
24+
1825
#endif // LLVM_FRONTEND_OPENMP_OMP_H

llvm/lib/Frontend/OpenMP/OMP.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,74 @@
88

99
#include "llvm/Frontend/OpenMP/OMP.h"
1010

11+
#include "llvm/ADT/ArrayRef.h"
12+
#include "llvm/ADT/STLExtras.h"
13+
#include "llvm/ADT/SmallVector.h"
1114
#include "llvm/ADT/StringRef.h"
1215
#include "llvm/ADT/StringSwitch.h"
1316
#include "llvm/Support/ErrorHandling.h"
1417

18+
#include <algorithm>
19+
#include <iterator>
20+
#include <type_traits>
21+
1522
using namespace llvm;
16-
using namespace omp;
23+
using namespace llvm::omp;
1724

1825
#define GEN_DIRECTIVES_IMPL
1926
#include "llvm/Frontend/OpenMP/OMP.inc"
27+
28+
namespace llvm::omp {
29+
ArrayRef<Directive> getLeafConstructs(Directive D) {
30+
auto Idx = static_cast<int>(D);
31+
if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
32+
return {};
33+
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
34+
return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
35+
}
36+
37+
Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
38+
if (Parts.empty())
39+
return OMPD_unknown;
40+
41+
// Parts don't have to be leafs, so expand them into leafs first.
42+
// Store the expanded leafs in the same format as rows in the leaf
43+
// table (generated by tablegen).
44+
SmallVector<Directive> RawLeafs(2);
45+
for (Directive P : Parts) {
46+
ArrayRef<Directive> Ls = getLeafConstructs(P);
47+
if (!Ls.empty())
48+
RawLeafs.append(Ls.begin(), Ls.end());
49+
else
50+
RawLeafs.push_back(P);
51+
}
52+
53+
auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
54+
if (GivenLeafs.size() == 1)
55+
return GivenLeafs.front();
56+
RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
57+
58+
auto Iter = llvm::lower_bound(
59+
LeafConstructTable,
60+
static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
61+
[](const auto *RowA, const auto *RowB) {
62+
const auto *BeginA = &RowA[2];
63+
const auto *EndA = BeginA + static_cast<int>(RowA[1]);
64+
const auto *BeginB = &RowB[2];
65+
const auto *EndB = BeginB + static_cast<int>(RowB[1]);
66+
if (BeginA == EndA && BeginB == EndB)
67+
return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
68+
return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
69+
});
70+
71+
if (Iter == std::end(LeafConstructTable))
72+
return OMPD_unknown;
73+
74+
// Verify that we got a match.
75+
Directive Found = (*Iter)[0];
76+
ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
77+
if (FoundLeafs == GivenLeafs)
78+
return Found;
79+
return OMPD_unknown;
80+
}
81+
} // namespace llvm::omp

llvm/unittests/Frontend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
1414
OpenMPContextTest.cpp
1515
OpenMPIRBuilderTest.cpp
1616
OpenMPParsingTest.cpp
17+
OpenMPComposeTest.cpp
1718

1819
DEPENDS
1920
acc_gen
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/ADT/ArrayRef.h"
10+
#include "llvm/Frontend/OpenMP/OMP.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace llvm;
14+
using namespace llvm::omp;
15+
16+
TEST(Composition, GetLeafConstructs) {
17+
ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
18+
ASSERT_EQ(L1, (ArrayRef<Directive>{}));
19+
ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
20+
ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
21+
ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
22+
ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
23+
}
24+
25+
TEST(Composition, GetCompoundConstruct) {
26+
Directive C1 = getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
27+
ASSERT_EQ(C1, OMPD_target_teams_distribute);
28+
Directive C2 = getCompoundConstruct({OMPD_target});
29+
ASSERT_EQ(C2, OMPD_target);
30+
Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
31+
ASSERT_EQ(C3, OMPD_unknown);
32+
Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
33+
ASSERT_EQ(C4, OMPD_target_teams_distribute);
34+
Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
35+
ASSERT_EQ(C5, OMPD_target_teams_distribute);
36+
Directive C6 = getCompoundConstruct({});
37+
ASSERT_EQ(C6, OMPD_unknown);
38+
Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
39+
ASSERT_EQ(C7, OMPD_parallel_for_simd);
40+
}

llvm/utils/TableGen/DirectiveEmitter.cpp

Lines changed: 124 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include "llvm/TableGen/Record.h"
2121
#include "llvm/TableGen/TableGenBackend.h"
2222

23+
#include <numeric>
24+
#include <vector>
25+
2326
using namespace llvm;
2427

2528
namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
3942
};
4043
} // namespace
4144

42-
// Generate enum class
45+
// Generate enum class. Entries are emitted in the order in which they appear
46+
// in the `Records` vector.
4347
static void GenerateEnumClass(const std::vector<Record *> &Records,
4448
raw_ostream &OS, StringRef Enum, StringRef Prefix,
4549
const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
175179
return HasDuplicateClausesInDirectives(getDirectives());
176180
}
177181

182+
// Count the maximum number of leaf constituents per construct.
183+
static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
184+
size_t MaxCount = 0;
185+
for (Record *R : DirLang.getDirectives()) {
186+
size_t Count = Directive{R}.getLeafConstructs().size();
187+
MaxCount = std::max(MaxCount, Count);
188+
}
189+
return MaxCount;
190+
}
191+
178192
// Generate the declaration section for the enumeration in the directive
179193
// language
180194
static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
189203
if (DirLang.hasEnableBitmaskEnumInNamespace())
190204
OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
191205

206+
OS << "#include <cstddef>\n"; // for size_t
192207
OS << "\n";
193208
OS << "namespace llvm {\n";
194209
OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
244259
OS << "bool isAllowedClauseForDirective(Directive D, "
245260
<< "Clause C, unsigned Version);\n";
246261
OS << "\n";
247-
OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
262+
OS << "constexpr std::size_t getMaxLeafCount() { return "
263+
<< GetMaxLeafCount(DirLang) << "; }\n";
248264
OS << "Association getDirectiveAssociation(Directive D);\n";
249265
if (EnumHelperFuncs.length() > 0) {
250266
OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
396412
}
397413
}
398414

415+
static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
416+
const Record *Rec) {
417+
Directive Dir{Rec};
418+
return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" +
419+
DirLang.getDirectivePrefix() + Dir.getFormattedName())
420+
.str();
421+
}
422+
423+
static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
424+
return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive")
425+
.str();
426+
}
427+
399428
// Generate the isAllowedClauseForDirective function implementation.
400429
static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
401430
raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
450479
OS << "}\n"; // End of function isAllowedClauseForDirective
451480
}
452481

453-
// Generate the getLeafConstructs function implementation.
454-
static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
455-
raw_ostream &OS) {
456-
auto getQualifiedName = [&](StringRef Formatted) -> std::string {
457-
return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
458-
"::Directive::" + DirLang.getDirectivePrefix() + Formatted)
459-
.str();
460-
};
461-
462-
// For each list of leaves, generate a static local object, then
463-
// return a reference to that object for a given directive, e.g.
482+
static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
483+
StringRef TableName) {
484+
// The leaf constructs are emitted in a form of a 2D table, where each
485+
// row corresponds to a directive (and there is a row for each directive).
464486
//
465-
// static ListTy leafConstructs_A_B = { A, B };
466-
// static ListTy leafConstructs_C_D_E = { C, D, E };
467-
// switch (Dir) {
468-
// case A_B:
469-
// return leafConstructs_A_B;
470-
// case C_D_E:
471-
// return leafConstructs_C_D_E;
472-
// }
473-
474-
// Map from a record that defines a directive to the name of the
475-
// local object with the list of its leaves.
476-
DenseMap<Record *, std::string> ListNames;
477-
478-
std::string DirectiveTypeName =
479-
std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
480-
481-
OS << '\n';
482-
483-
// ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
484-
OS << "llvm::ArrayRef<" << DirectiveTypeName
485-
<< "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
486-
<< DirectiveTypeName << " Dir) ";
487-
OS << "{\n";
488-
489-
// Generate the locals.
490-
for (Record *R : DirLang.getDirectives()) {
491-
Directive Dir{R};
487+
// Each row consists of
488+
// - the id of the directive itself,
489+
// - number of leaf constructs that will follow (0 for leafs),
490+
// - ids of the leaf constructs (none if the directive is itself a leaf).
491+
// The total number of these entries is at most MaxLeafCount+2. If this
492+
// number is less than that, it is padded to occupy exactly MaxLeafCount+2
493+
// entries in memory.
494+
//
495+
// The rows are stored in the table in the lexicographical order. This
496+
// is intended to enable binary search when mapping a sequence of leafs
497+
// back to the compound directive.
498+
// The consequence of that is that in order to find a row corresponding
499+
// to the given directive, we'd need to scan the first element of each
500+
// row. To avoid this, an auxiliary ordering table is created, such that
501+
// row for Dir_A = table[auxiliary[Dir_A]].
502+
503+
std::vector<Record *> Directives = DirLang.getDirectives();
504+
DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive
505+
506+
for (auto [Idx, Rec] : llvm::enumerate(Directives))
507+
DirId.insert(std::make_pair(Rec, Idx));
508+
509+
using LeafList = std::vector<int>;
510+
int MaxLeafCount = GetMaxLeafCount(DirLang);
511+
512+
// The initial leaf table, rows order is same as directive order.
513+
std::vector<LeafList> LeafTable(Directives.size());
514+
for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
515+
Directive Dir{Rec};
516+
std::vector<Record *> Leaves = Dir.getLeafConstructs();
517+
518+
auto &List = LeafTable[Idx];
519+
List.resize(MaxLeafCount + 2);
520+
List[0] = Idx; // The id of the directive itself.
521+
List[1] = Leaves.size(); // The number of leaves to follow.
522+
523+
for (int I = 0; I != MaxLeafCount; ++I)
524+
List[I + 2] =
525+
static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1;
526+
}
492527

493-
std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
494-
if (LeafConstructs.empty())
495-
continue;
528+
// Avoid sorting the vector<vector> array, instead sort an index array.
529+
// It will also be useful later to create the auxiliary indexing array.
530+
std::vector<int> Ordering(Directives.size());
531+
std::iota(Ordering.begin(), Ordering.end(), 0);
532+
533+
llvm::sort(Ordering, [&](int A, int B) {
534+
auto &LeavesA = LeafTable[A];
535+
auto &LeavesB = LeafTable[B];
536+
if (LeavesA[1] == 0 && LeavesB[1] == 0)
537+
return LeavesA[0] < LeavesB[0];
538+
return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
539+
&LeavesB[2], &LeavesB[2] + LeavesB[1]);
540+
});
496541

497-
std::string ListName = "leafConstructs_" + Dir.getFormattedName();
498-
OS << " static const " << DirectiveTypeName << ' ' << ListName
499-
<< "[] = {\n";
500-
for (Record *L : LeafConstructs) {
501-
Directive LeafDir{L};
502-
OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
542+
// Emit the table
543+
544+
// The directives are emitted into a scoped enum, for which the underlying
545+
// type is `int` (by default). The code above uses `int` to store directive
546+
// ids, so make sure that we catch it when something changes in the
547+
// underlying type.
548+
std::string DirectiveType = GetDirectiveType(DirLang);
549+
OS << "static_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
550+
551+
OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
552+
<< "[][" << MaxLeafCount + 2 << "] = {\n";
553+
for (size_t I = 0, E = Directives.size(); I != E; ++I) {
554+
auto &Leaves = LeafTable[Ordering[I]];
555+
OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
556+
OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
557+
for (size_t I = 2, E = Leaves.size(); I != E; ++I) {
558+
int Idx = Leaves[I];
559+
if (Idx >= 0)
560+
OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ',';
561+
else
562+
OS << " static_cast<" << DirectiveType << ">(-1),";
503563
}
504-
OS << " };\n";
505-
ListNames.insert(std::make_pair(R, std::move(ListName)));
506-
}
507-
508-
if (!ListNames.empty())
509564
OS << '\n';
510-
OS << " switch (Dir) {\n";
511-
for (Record *R : DirLang.getDirectives()) {
512-
auto F = ListNames.find(R);
513-
if (F == ListNames.end())
514-
continue;
515-
516-
Directive Dir{R};
517-
OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
518-
OS << " return " << F->second << ";\n";
519565
}
520-
OS << " default:\n";
521-
OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n";
522-
OS << " } // switch (Dir)\n";
523-
OS << "}\n";
566+
OS << "};\n\n";
567+
568+
// Emit the auxiliary index table: it's the inverse of the `Ordering`
569+
// table above.
570+
OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
571+
OS << " ";
572+
std::vector<int> Reverse(Ordering.size());
573+
for (int I = 0, E = Ordering.size(); I != E; ++I)
574+
Reverse[Ordering[I]] = I;
575+
for (int Idx : Reverse)
576+
OS << ' ' << Idx << ',';
577+
OS << "\n};\n";
524578
}
525579

526580
static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
11051159
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
11061160
GenerateIsAllowedClause(DirLang, OS);
11071161

1108-
// getLeafConstructs(Directive D)
1109-
GenerateGetLeafConstructs(DirLang, OS);
1110-
11111162
// getDirectiveAssociation(Directive D)
11121163
GenerateGetDirectiveAssociation(DirLang, OS);
1164+
1165+
// Leaf table for getLeafConstructs, etc.
1166+
EmitLeafTable(DirLang, OS, "LeafConstructTable");
11131167
}
11141168

11151169
// Generate the implemenation section for the enumeration in the directive

0 commit comments

Comments
 (0)