Skip to content

Commit 325bc4e

Browse files
authored
[SYCL][Fusion] Enable fusion of kernels with different ND-ranges (#8209)
All kernels with the same (or unspecified) local size and offset can be fused. In order to make this work, some builtins that query index space information must be remapped, and the resulting ND-range of the fused kernel must be calculated. The ND-range of the fused kernel will have: 1. The same number of dimensions as the input ND-range with the highest number of dimensions; 2. The same local size as the shared local size (or unspecified); 3. The same offset as the shared offset; 4. The global size will be the **greatest** input global size as per the following ordering: i. Number of work items (enforces correctness); ii. Number of occurrences (fewer remappings needed); iii. Lexical order of the dimensions (introduces determinism). Builtins obtaining the local/global size/id, work-group id, number of work-groups or offset are remapped by introducing an alwaysinline function that can be reused throughout the fusion pass. More information can be found in the Builtins.cpp file, where the remapping logic is implemented. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 77ac85a commit 325bc4e

26 files changed

+1655
-91
lines changed

sycl-fusion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(LLVM_SPIRV_INCLUDE_DIRS "${LLVM_MAIN_SRC_DIR}/../llvm-spirv/include")
1212
if(WIN32)
1313
message(WARNING "Kernel fusion not yet supported on Windows")
1414
else(WIN32)
15+
add_subdirectory(common)
1516
add_subdirectory(jit-compiler)
1617
add_subdirectory(passes)
1718
add_subdirectory(test)

sycl-fusion/common/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
add_llvm_library(sycl-fusion-common
2+
lib/NDRangesHelper.cpp
3+
)
4+
5+
target_include_directories(sycl-fusion-common
6+
PUBLIC
7+
${CMAKE_CURRENT_SOURCE_DIR}/include
8+
${CMAKE_CURRENT_SOURCE_DIR}/lib
9+
)
10+
11+
if (BUILD_SHARED_LIBS)
12+
if(NOT MSVC AND NOT APPLE)
13+
# Manage symbol visibility through the linker to make sure no LLVM symbols
14+
# are exported and confuse the drivers.
15+
set(linker_script "${CMAKE_CURRENT_SOURCE_DIR}/ld-version-script.txt")
16+
target_link_libraries(
17+
sycl-fusion-common PRIVATE "-Wl,--version-script=${linker_script}")
18+
set_target_properties(sycl-fusion-common
19+
PROPERTIES
20+
LINK_DEPENDS
21+
${linker_script})
22+
endif()
23+
endif()

sycl-fusion/common/include/Kernel.h

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#define SYCL_FUSION_COMMON_KERNEL_H
1111

1212
#include <algorithm>
13+
#include <array>
14+
#include <cassert>
1315
#include <string>
1416
#include <vector>
1517

@@ -97,6 +99,77 @@ struct SYCLArgumentDescriptor {
9799
/// List of SYCL/OpenCL kernel attributes.
98100
using AttributeList = std::vector<SYCLKernelAttribute>;
99101

102+
/// Three-component index tuple used for global/local sizes and offsets.
using Indices = std::array<size_t, 3>;

///
/// Class to model SYCL nd_range
///
/// An ND-range stores a number of dimensions together with a global size, a
/// local size and an offset. A local size of all zeros means that no specific
/// local size was requested (see hasSpecificLocalSize()).
class NDRange {
public:
  /// Value used to represent an unspecified local size.
  constexpr static Indices AllZeros{0, 0, 0};

  ///
  /// Return the product of each index in an indices array.
  constexpr static size_t linearize(const Indices &I) {
    return I[0] * I[1] * I[2];
  }

  /// Construct a one-dimensional ND-range containing a single work-item.
  NDRange() : NDRange{1, {1, 1, 1}} {}

  ///
  /// Construct an ND-range.
  ///
  /// In builds with assertions enabled, the components at positions
  /// \p Dimensions and above are checked: they must be 1 for \p GlobalSize,
  /// 1 for \p LocalSize (unless \p LocalSize is all zeros, i.e. unspecified)
  /// and 0 for \p Offset.
  ///
  /// \param Dimensions Number of dimensions in use.
  /// \param GlobalSize Global size.
  /// \param LocalSize Local size, or all zeros if unspecified.
  /// \param Offset Offset.
  NDRange(int Dimensions, const Indices &GlobalSize,
          const Indices &LocalSize = {1, 1, 1},
          const Indices &Offset = {0, 0, 0})
      : Dimensions{Dimensions}, GlobalSize{GlobalSize}, LocalSize{LocalSize},
        Offset{Offset} {
    // Checked first: the checks below advance an iterator by Dimensions, which
    // would be undefined behavior outside of [0, 3].
    assert(Dimensions >= 0 && Dimensions <= 3 &&
           "Invalid number of dimensions");
    assert(unusedComponentsAre(GlobalSize, Dimensions, 1) &&
           "Invalid global range for number of dimensions");
    assert(
        (unusedComponentsAre(LocalSize, Dimensions, 1) ||
         LocalSize == AllZeros) &&
        "Invalid local range for number of dimensions");
    assert(unusedComponentsAre(Offset, Dimensions, 0) &&
           "Invalid offset for number of dimensions");
  }

  constexpr const Indices &getGlobalSize() const { return GlobalSize; }
  constexpr const Indices &getLocalSize() const { return LocalSize; }
  constexpr const Indices &getOffset() const { return Offset; }
  constexpr int getDimensions() const { return Dimensions; }

  /// Return whether a local size was specified, i.e. it is not all zeros.
  bool hasSpecificLocalSize() const { return LocalSize != AllZeros; }

  ///
  /// Compare two ND-ranges.
  ///
  /// The local sizes are only compared if both operands specify one, so an
  /// ND-range with an unspecified local size compares equal to an otherwise
  /// identical ND-range with any local size. As a consequence, this relation
  /// is not transitive.
  friend bool operator==(const NDRange &LHS, const NDRange &RHS) {
    if (LHS.Dimensions != RHS.Dimensions || LHS.GlobalSize != RHS.GlobalSize ||
        LHS.Offset != RHS.Offset) {
      return false;
    }
    return !LHS.hasSpecificLocalSize() || !RHS.hasSpecificLocalSize() ||
           LHS.LocalSize == RHS.LocalSize;
  }

  friend bool operator!=(const NDRange &LHS, const NDRange &RHS) {
    return !(LHS == RHS);
  }

private:
  /// Return whether every component of \p Range at position \p Dimensions or
  /// above is equal to \p Expected.
  static bool unusedComponentsAre(const Indices &Range, int Dimensions,
                                  size_t Expected) {
    return std::all_of(Range.begin() + Dimensions, Range.end(),
                       [Expected](size_t D) { return D == Expected; });
  }

  /** @brief The number of dimensions. */
  int Dimensions;
  /** @brief The global range. */
  Indices GlobalSize;
  /** @brief The local range. */
  Indices LocalSize;
  /** @brief The offset. */
  Indices Offset;
};
172+
100173
/// Information about a kernel from DPC++.
101174
struct SYCLKernelInfo {
102175

@@ -106,18 +179,21 @@ struct SYCLKernelInfo {
106179

107180
AttributeList Attributes;
108181

182+
NDRange NDR;
183+
109184
SYCLKernelBinaryInfo BinaryInfo;
110185

111186
//// Explicit constructor for compatibility with LLVM YAML I/O.
112-
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, BinaryInfo{} {}
187+
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}
113188

114189
SYCLKernelInfo(const std::string &KernelName,
115-
const SYCLArgumentDescriptor &ArgDesc,
190+
const SYCLArgumentDescriptor &ArgDesc, const NDRange &NDR,
116191
const SYCLKernelBinaryInfo &BinInfo)
117-
: Name{KernelName}, Args{ArgDesc}, Attributes{}, BinaryInfo{BinInfo} {}
192+
: Name{KernelName}, Args{ArgDesc}, Attributes{}, NDR{NDR}, BinaryInfo{
193+
BinInfo} {}
118194

119195
explicit SYCLKernelInfo(const std::string &KernelName)
120-
: Name{KernelName}, Args{}, Attributes{}, BinaryInfo{} {}
196+
: Name{KernelName}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}
121197
};
122198

123199
///
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
global:
3+
/* Export everything from jit_compiler namespace */
4+
_ZN12jit_compiler*;
5+
6+
local:
7+
*;
8+
};
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//==-------------------------- NDRangesHelper.cpp --------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "NDRangesHelper.h"
10+
11+
#include <map>
12+
13+
using namespace jit_compiler;
14+
using namespace llvm;
15+
16+
///
17+
/// Return the maximal global size using the following order:
18+
/// 1. A range is greater than another range if it contains more elements (see
19+
/// linearize());
20+
/// 2. Else if it appears more times in the input list of ranges;
21+
/// 3. Else if it is greater in lexicographical order.
22+
static Indices getMaximalGlobalSize(ArrayRef<NDRange> NDRanges) {
23+
size_t NumElements{0};
24+
std::map<Indices, size_t> FreqMap;
25+
for (const auto &ND : NDRanges) {
26+
const auto &GS = ND.getGlobalSize();
27+
const auto N = NDRange::linearize(GS);
28+
if (N < NumElements) {
29+
continue;
30+
}
31+
if (N > NumElements) {
32+
NumElements = N;
33+
FreqMap.clear();
34+
}
35+
++FreqMap[GS];
36+
}
37+
return std::max_element(FreqMap.begin(), FreqMap.end(),
38+
[](const auto &LHS, const auto &RHS) {
39+
const auto LHSN = LHS.second;
40+
const auto RHSN = RHS.second;
41+
if (LHSN < RHSN) {
42+
return true;
43+
}
44+
if (LHSN > RHSN) {
45+
return false;
46+
}
47+
return LHS.first < RHS.first;
48+
})
49+
->first;
50+
}
51+
52+
static bool compatibleRanges(const NDRange &LHS, const NDRange &RHS) {
53+
const auto Dimensions = std::max(LHS.getDimensions(), RHS.getDimensions());
54+
const auto EqualIndices = [Dimensions](const Indices &LHS,
55+
const Indices &RHS) {
56+
return std::equal(LHS.begin(), LHS.begin() + Dimensions, RHS.begin());
57+
};
58+
return (!LHS.hasSpecificLocalSize() || !RHS.hasSpecificLocalSize() ||
59+
EqualIndices(LHS.getLocalSize(), RHS.getLocalSize())) &&
60+
EqualIndices(LHS.getOffset(), RHS.getOffset());
61+
}
62+
63+
/// Build the ND-range resulting from combining the input list: the greatest
/// number of dimensions, the maximal global size, the first specified local
/// size (all zeros if there is none) and the offset of the first element.
NDRange jit_compiler::combineNDRanges(ArrayRef<NDRange> NDRanges) {
  assert(isValidCombination(NDRanges) && "Invalid ND-ranges combination");

  // Greatest number of dimensions found in the list.
  int MaxDimensions = NDRanges.front().getDimensions();
  for (const auto &ND : NDRanges) {
    MaxDimensions = std::max(MaxDimensions, ND.getDimensions());
  }

  // Use the first specified local size; leave it unspecified otherwise.
  const auto *WithLocalSize = findSpecifiedLocalSize(NDRanges);
  const Indices &LocalSize = WithLocalSize != NDRanges.end()
                                 ? WithLocalSize->getLocalSize()
                                 : NDRange::AllZeros;

  return NDRange{MaxDimensions, getMaximalGlobalSize(NDRanges), LocalSize,
                 NDRanges.front().getOffset()};
}
79+
80+
/// Return whether any ND-range in the list compares different to a reference
/// element: the first one with a specified local size or, if there is none,
/// the first element. An empty list is not heterogeneous.
bool jit_compiler::isHeterogeneousList(ArrayRef<NDRange> NDRanges) {
  // Without elements there is no reference to pick; calling front() on an
  // empty list would be undefined behavior.
  if (NDRanges.empty()) {
    return false;
  }
  const auto *FirstSpecifiedLocalSize = findSpecifiedLocalSize(NDRanges);
  const auto &ND = FirstSpecifiedLocalSize == NDRanges.end()
                       ? NDRanges.front()
                       : *FirstSpecifiedLocalSize;
  return llvm::any_of(NDRanges,
                      [&ND](const auto &Other) { return ND != Other; });
}
87+
88+
/// Return whether the list is not empty and every ND-range in it is compatible
/// with a reference element: the first one with a specified local size or, if
/// there is none, the first element.
bool jit_compiler::isValidCombination(llvm::ArrayRef<NDRange> NDRanges) {
  if (NDRanges.empty()) {
    return false;
  }

  const auto *Specified = findSpecifiedLocalSize(NDRanges);
  const NDRange &Reference =
      Specified != NDRanges.end() ? *Specified : NDRanges.front();

  for (const NDRange &Candidate : NDRanges) {
    if (!compatibleRanges(Reference, Candidate)) {
      return false;
    }
  }
  return true;
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//==--------- NDRangesHelper.h - Helpers to handle ND-ranges ---------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_FUSION_COMMON_NDRANGESHELPER_H
10+
#define SYCL_FUSION_COMMON_NDRANGESHELPER_H
11+
12+
#include "Kernel.h"
13+
14+
#include "llvm/ADT/ArrayRef.h"
15+
#include "llvm/ADT/STLExtras.h"
16+
17+
namespace jit_compiler {
18+
///
19+
/// Combine a list of ND-ranges, obtaining the resulting "fused" ND-range.
20+
NDRange combineNDRanges(llvm::ArrayRef<NDRange> NDRanges);
21+
22+
/// Return an iterator to the first ND-range in the list with a specified local
/// size, or the end iterator if there is none.
inline llvm::ArrayRef<NDRange>::const_iterator
findSpecifiedLocalSize(llvm::ArrayRef<NDRange> NDRanges) {
  const auto SpecifiesLocalSize = [](const NDRange &Candidate) {
    return Candidate.hasSpecificLocalSize();
  };
  return llvm::find_if(NDRanges, SpecifiesLocalSize);
}
27+
28+
///
29+
/// Returns whether the input list of ND-ranges is heterogeneous or not.
30+
bool isHeterogeneousList(llvm::ArrayRef<NDRange> NDRanges);
31+
32+
///
33+
/// Return whether a combination of ND-ranges is valid for fusion.
34+
bool isValidCombination(llvm::ArrayRef<NDRange> NDRanges);
35+
} // namespace jit_compiler
36+
37+
#endif // SYCL_FUSION_COMMON_NDRANGESHELPER_H

sycl-fusion/jit-compiler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_llvm_library(sycl-fusion
1212
Core
1313
Support
1414
Analysis
15+
IPO
1516
TransformUtils
1617
Passes
1718
Linker
@@ -33,6 +34,7 @@ find_package(Threads REQUIRED)
3334

3435
target_link_libraries(sycl-fusion
3536
PRIVATE
37+
sycl-fusion-common
3638
LLVMSPIRVLib
3739
SYCLKernelFusionPasses
3840
${CMAKE_THREAD_LIBS_INIT}

sycl-fusion/jit-compiler/include/Hashing.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
1010
#define SYCL_FUSION_JIT_COMPILER_HASHING_H
1111

12+
#include "Kernel.h"
1213
#include "Parameter.h"
1314

1415
#include "llvm/ADT/Hashing.h"
@@ -32,13 +33,23 @@ inline llvm::hash_code hash_value(const JITConstant &C) {
3233
inline llvm::hash_code hash_value(const ParameterIdentity &IP) {
3334
return llvm::hash_combine(IP.LHS, IP.RHS);
3435
}
36+
37+
/// Hash an ND-range.
///
/// The local size is deliberately left out: NDRange's operator== treats an
/// unspecified local size as equal to any other local size, so including it
/// here would let two ND-ranges that compare equal produce different hashes.
inline llvm::hash_code hash_value(const NDRange &ND) {
  return llvm::hash_combine(ND.getDimensions(), ND.getGlobalSize(),
                            ND.getOffset());
}
3541
} // namespace jit_compiler
3642

3743
namespace std {
3844
template <typename T> inline llvm::hash_code hash_value(const vector<T> &V) {
3945
return llvm::hash_combine_range(V.begin(), V.end());
4046
}
4147

48+
template <typename T, std::size_t N>
49+
inline llvm::hash_code hash_value(const array<T, N> &A) {
50+
return llvm::hash_combine_range(A.begin(), A.end());
51+
}
52+
4253
template <typename... T> struct hash<tuple<T...>> {
4354
size_t operator()(const tuple<T...> &Tuple) const noexcept {
4455
return llvm::hash_value(Tuple);

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ namespace jit_compiler {
2929

3030
using CacheKeyT =
3131
std::tuple<std::vector<std::string>, ParamIdentList, int,
32-
std::vector<ParameterInternalization>, std::vector<JITConstant>>;
32+
std::vector<ParameterInternalization>, std::vector<JITConstant>,
33+
// This field of the cache is optional because, if all of the
34+
// ranges are equal, we will perform no remapping, so that fused
35+
// kernels can be reused with different lists of equal nd-ranges.
36+
std::optional<std::vector<NDRange>>>;
3337

3438
///
3539
/// Wrapper around a SPIR-V binary.

0 commit comments

Comments
 (0)