Skip to content

Commit 90064f0

Browse files
committed
Base commit, PR #78098
2 parents 24aa668 + 436ec9b commit 90064f0

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- Offload.h - LLVM Target Offload --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares LLVM target offload utility classes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TARGET_LLVM_OFFLOAD_H
14+
#define MLIR_TARGET_LLVM_OFFLOAD_H
15+
16+
#include "mlir/Support/LogicalResult.h"
17+
#include "llvm/ADT/StringRef.h"
18+
19+
namespace llvm {
20+
class Constant;
21+
class GlobalVariable;
22+
class Module;
23+
} // namespace llvm
24+
25+
namespace mlir {
26+
namespace LLVM {
27+
/// `OffloadHandler` is a utility class for creating LLVM offload entries. LLVM
28+
/// offload entries hold information on offload symbols; for example, for a GPU
29+
/// kernel, this includes its host address to identify the kernel and the kernel
30+
/// identifier in the binary. Arrays of offload entries can be used to register
31+
/// functions within the CUDA/HIP runtime. Libomptarget also uses these entries
32+
/// to register OMP target offload kernels and variables.
33+
class OffloadHandler {
34+
public:
35+
using OffloadEntryArray =
36+
std::pair<llvm::GlobalVariable *, llvm::GlobalVariable *>;
37+
OffloadHandler(llvm::Module &module) : module(module) {}
38+
39+
/// Returns the begin symbol name used in the entry array.
40+
static std::string getBeginSymbol(StringRef suffix);
41+
42+
/// Returns the end symbol name used in the entry array.
43+
static std::string getEndSymbol(StringRef suffix);
44+
45+
/// Returns the entry array if it exists or a pair of null pointers.
46+
OffloadEntryArray getEntryArray(StringRef suffix);
47+
48+
/// Emits an empty array of offloading entries.
49+
OffloadEntryArray emitEmptyEntryArray(StringRef suffix);
50+
51+
/// Inserts an offloading entry into an existing entry array. This method
52+
/// returns failure if the entry array hasn't been declared.
53+
LogicalResult insertOffloadEntry(StringRef suffix, llvm::Constant *entry);
54+
55+
protected:
56+
llvm::Module &module;
57+
};
58+
} // namespace LLVM
59+
} // namespace mlir
60+
61+
#endif // MLIR_TARGET_LLVM_OFFLOAD_H

mlir/lib/Target/LLVM/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_library(MLIRTargetLLVM
22
ModuleToObject.cpp
3+
Offload.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVM
@@ -16,6 +17,7 @@ add_mlir_library(MLIRTargetLLVM
1617
Passes
1718
Support
1819
Target
20+
FrontendOffloading
1921
LINK_LIBS PUBLIC
2022
MLIRExecutionEngineUtils
2123
MLIRTargetLLVMIRExport

mlir/lib/Target/LLVM/Offload.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//===- Offload.cpp - LLVM Target Offload ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines LLVM target offload utility classes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Target/LLVM/Offload.h"
14+
#include "llvm/Frontend/Offloading/Utility.h"
15+
#include "llvm/IR/Constants.h"
16+
#include "llvm/IR/Module.h"
17+
18+
using namespace mlir;
19+
using namespace mlir::LLVM;
20+
21+
std::string OffloadHandler::getBeginSymbol(StringRef suffix) {
22+
return ("__begin_offload_" + suffix).str();
23+
}
24+
25+
std::string OffloadHandler::getEndSymbol(StringRef suffix) {
26+
return ("__end_offload_" + suffix).str();
27+
}
28+
29+
namespace {
30+
/// Returns the type of the entry array.
31+
llvm::ArrayType *getEntryArrayType(llvm::Module &module, size_t numElems) {
32+
return llvm::ArrayType::get(llvm::offloading::getEntryTy(module), numElems);
33+
}
34+
35+
/// Creates the initializer of the entry array.
36+
llvm::Constant *getEntryArrayBegin(llvm::Module &module,
37+
ArrayRef<llvm::Constant *> entries) {
38+
// If there are no entries return a constant zero initializer.
39+
llvm::ArrayType *arrayTy = getEntryArrayType(module, entries.size());
40+
return entries.empty() ? llvm::ConstantAggregateZero::get(arrayTy)
41+
: llvm::ConstantArray::get(arrayTy, entries);
42+
}
43+
44+
/// Computes the end position of the entry array.
45+
llvm::Constant *getEntryArrayEnd(llvm::Module &module,
46+
llvm::GlobalVariable *begin, size_t numElems) {
47+
llvm::Type *intTy = module.getDataLayout().getIntPtrType(module.getContext());
48+
return llvm::ConstantExpr::getGetElementPtr(
49+
llvm::offloading::getEntryTy(module), begin,
50+
ArrayRef<llvm::Constant *>({llvm::ConstantInt::get(intTy, numElems)}),
51+
true);
52+
}
53+
} // namespace
54+
55+
OffloadHandler::OffloadEntryArray
56+
OffloadHandler::getEntryArray(StringRef suffix) {
57+
llvm::GlobalVariable *beginGV =
58+
module.getGlobalVariable(getBeginSymbol(suffix), true);
59+
llvm::GlobalVariable *endGV =
60+
module.getGlobalVariable(getEndSymbol(suffix), true);
61+
return {beginGV, endGV};
62+
}
63+
64+
OffloadHandler::OffloadEntryArray
65+
OffloadHandler::emitEmptyEntryArray(StringRef suffix) {
66+
llvm::ArrayType *arrayTy = getEntryArrayType(module, 0);
67+
auto *beginGV = new llvm::GlobalVariable(
68+
module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
69+
getEntryArrayBegin(module, {}), getBeginSymbol(suffix));
70+
auto *endGV = new llvm::GlobalVariable(
71+
module, llvm::PointerType::get(module.getContext(), 0),
72+
/*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
73+
getEntryArrayEnd(module, beginGV, 0), getEndSymbol(suffix));
74+
return {beginGV, endGV};
75+
}
76+
77+
LogicalResult OffloadHandler::insertOffloadEntry(StringRef suffix,
78+
llvm::Constant *entry) {
79+
// Get the begin and end symbols to the entry array.
80+
std::string beginSymId = getBeginSymbol(suffix);
81+
llvm::GlobalVariable *beginGV = module.getGlobalVariable(beginSymId, true);
82+
llvm::GlobalVariable *endGV =
83+
module.getGlobalVariable(getEndSymbol(suffix), true);
84+
// Fail if the symbols are missing.
85+
if (!beginGV || !endGV)
86+
return failure();
87+
// Create the entry initializer.
88+
assert(beginGV->getInitializer() && "entry array initializer is missing.");
89+
// Add existing entries into the new entry array.
90+
SmallVector<llvm::Constant *> entries;
91+
if (auto beginInit = dyn_cast_or_null<llvm::ConstantAggregate>(
92+
beginGV->getInitializer())) {
93+
for (unsigned i = 0; i < beginInit->getNumOperands(); ++i)
94+
entries.push_back(beginInit->getOperand(i));
95+
}
96+
// Add the new entry.
97+
entries.push_back(entry);
98+
// Create a global holding the new updated set of entries.
99+
auto *arrayTy = llvm::ArrayType::get(llvm::offloading::getEntryTy(module),
100+
entries.size());
101+
auto *entryArr = new llvm::GlobalVariable(
102+
module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
103+
getEntryArrayBegin(module, entries), beginSymId, endGV);
104+
// Replace the old entry array variable withe new one.
105+
beginGV->replaceAllUsesWith(entryArr);
106+
beginGV->eraseFromParent();
107+
entryArr->setName(beginSymId);
108+
// Update the end symbol.
109+
endGV->setInitializer(getEntryArrayEnd(module, entryArr, entries.size()));
110+
return success();
111+
}

mlir/unittests/Target/LLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_unittest(MLIRTargetLLVMTests
2+
Offload.cpp
23
SerializeNVVMTarget.cpp
34
SerializeROCDLTarget.cpp
45
SerializeToLLVMBitcode.cpp
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- Offload.cpp ----------------------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Target/LLVM/Offload.h"
10+
#include "llvm/Frontend/Offloading/Utility.h"
11+
#include "llvm/IR/Constants.h"
12+
#include "llvm/IR/Module.h"
13+
14+
#include "gmock/gmock.h"
15+
16+
using namespace llvm;
17+
18+
TEST(MLIRTarget, OffloadAPI) {
19+
using OffloadEntryArray = mlir::LLVM::OffloadHandler::OffloadEntryArray;
20+
LLVMContext llvmContext;
21+
Module llvmModule("offload", llvmContext);
22+
mlir::LLVM::OffloadHandler handler(llvmModule);
23+
StringRef suffix = ".mlir";
24+
// Check there's no entry array with `.mlir` suffix.
25+
OffloadEntryArray entryArray = handler.getEntryArray(suffix);
26+
EXPECT_EQ(entryArray, OffloadEntryArray());
27+
// Emit the entry array.
28+
handler.emitEmptyEntryArray(suffix);
29+
// Check there's an entry array with `.mlir` suffix.
30+
entryArray = handler.getEntryArray(suffix);
31+
ASSERT_NE(entryArray.first, nullptr);
32+
ASSERT_NE(entryArray.second, nullptr);
33+
// Check the array contains no entries.
34+
auto *zeroInitializer = dyn_cast_or_null<ConstantAggregateZero>(
35+
entryArray.first->getInitializer());
36+
ASSERT_NE(zeroInitializer, nullptr);
37+
// Insert an empty entries.
38+
auto emptyEntry =
39+
ConstantAggregateZero::get(offloading::getEntryTy(llvmModule));
40+
ASSERT_TRUE(succeeded(handler.insertOffloadEntry(suffix, emptyEntry)));
41+
// Check there's an entry in the entry array with `.mlir` suffix.
42+
entryArray = handler.getEntryArray(suffix);
43+
ASSERT_NE(entryArray.first, nullptr);
44+
Constant *arrayInitializer = entryArray.first->getInitializer();
45+
ASSERT_NE(arrayInitializer, nullptr);
46+
auto *arrayTy = dyn_cast_or_null<ArrayType>(arrayInitializer->getType());
47+
ASSERT_NE(arrayTy, nullptr);
48+
EXPECT_EQ(arrayTy->getNumElements(), 1u);
49+
}

0 commit comments

Comments
 (0)