Skip to content

Commit eff2d2f

Browse files
committed
[SYCL][Fusion] Enable fusion of kernels with different ND-ranges
All kernels with the same (or unspecified) local size and offset can be fused. In order to make this work, some builtins getting index space information must be remapped, and the resulting ND-range of the fused kernel must be calculated. The ND-range of the fused kernel will have: 1. The same number of dimensions as the input ND-range with the highest number of dimensions; 2. The same local size as the shared local size (or unspecified); 3. The same offset as the shared offset; 4. A global size equal to the **greatest** input global size as per the following ordering: i. Number of work items (enforces correctness); ii. Number of occurrences (fewer remappings needed); iii. Lexical order of the dimensions (introduces determinism). Builtins obtaining the local/global size/id, work-group id, number of work-groups or offset are remapped by introducing an alwaysinline function that can be reused throughout the fusion pass. More information can be found in the Builtins.cpp file, where the remapping logic is implemented. Signed-off-by: Victor Perez <[email protected]>
1 parent 16590f9 commit eff2d2f

22 files changed

+1768
-92
lines changed

sycl-fusion/common/include/Kernel.h

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#define SYCL_FUSION_COMMON_KERNEL_H
1111

1212
#include <algorithm>
13+
#include <cassert>
14+
#include <map>
1315
#include <string>
1416
#include <vector>
1517

@@ -97,6 +99,112 @@ struct SYCLArgumentDescriptor {
9799
/// List of SYCL/OpenCL kernel attributes.
98100
using AttributeList = std::vector<SYCLKernelAttribute>;
99101

102+
/// Three per-dimension values (a size, an id or an offset).
using Indices = std::array<size_t, 3>;

///
/// Class to model SYCL nd_range
class NDRange {
public:
  /// All-zero indices. A local size equal to this value is treated as
  /// "local size unspecified" by the comparison helpers below.
  constexpr static Indices AllZeros{0, 0, 0};

  ///
  /// Returns whether or not two ranges can be fused.
  ///
  /// In order to be fusable, two ranges must:
  /// 1. Have the same local size, unless at least one of them leaves the
  ///    local size unspecified (all zeros);
  /// 2. Have the same offset.
  ///
  /// Local sizes and offsets are compared over the first N entries only,
  /// where N is the larger of the two numbers of dimensions.
  static bool compatibleRanges(const NDRange &LHS, const NDRange &RHS) {
    const auto Dimensions = std::max(LHS.getDimensions(), RHS.getDimensions());
    const auto EqualIndices = [Dimensions](const Indices &A,
                                           const Indices &B) {
      return std::equal(A.begin(), A.begin() + Dimensions, B.begin());
    };
    const bool LocalSizesAgree =
        EqualIndices(LHS.getLocalSize(), RHS.getLocalSize()) ||
        LHS.getLocalSize() == AllZeros || RHS.getLocalSize() == AllZeros;
    return LocalSizesAgree && EqualIndices(LHS.getOffset(), RHS.getOffset());
  }

  ///
  /// Return whether a combination of ND-ranges is valid for fusion.
  ///
  /// An empty combination is never valid. Otherwise, the first ND-range with
  /// a specified local size (or the first ND-range, if none specifies one) is
  /// used as the reference, and every ND-range in [Begin, End) must be
  /// compatible with it as per compatibleRanges().
  template <typename InputIt>
  static bool isValidCombination(InputIt Begin, InputIt End) {
    if (Begin == End) {
      return false;
    }
    const auto HasSpecifiedLocalSize = [](const auto &ND) {
      return ND.getLocalSize() != NDRange::AllZeros;
    };
    const auto FirstSpecLocal = std::find_if(Begin, End, HasSpecifiedLocalSize);
    const auto &Reference = FirstSpecLocal == End ? *Begin : *FirstSpecLocal;
    return std::all_of(Begin, End, [&Reference](const auto &Other) {
      return NDRange::compatibleRanges(Reference, Other);
    });
  }

  ///
  /// Return the product of each index in an indices array.
  constexpr static size_t linearize(const Indices &I) {
    return I[0] * I[1] * I[2];
  }

  /// Build a one-dimensional ND-range with global size {1, 1, 1} and the
  /// default local size and offset of the main constructor.
  NDRange() : NDRange{1, {1, 1, 1}} {}

  /// Build an ND-range.
  ///
  /// In builds with assertions enabled, the arguments are validated:
  /// entries of \p GlobalSize beyond \p Dimensions must be 1, entries of
  /// \p LocalSize beyond \p Dimensions must be 1 (unless \p LocalSize is all
  /// zeros) and entries of \p Offset beyond \p Dimensions must be 0.
  ///
  /// \param Dimensions number of dimensions.
  /// \param GlobalSize global size.
  /// \param LocalSize local size; all zeros means unspecified.
  /// \param Offset offset.
  NDRange(int Dimensions, const Indices &GlobalSize,
          const Indices &LocalSize = {1, 1, 1},
          const Indices &Offset = {0, 0, 0})
      : Dimensions{Dimensions},
        GlobalSize{GlobalSize}, LocalSize{LocalSize}, Offset{Offset} {
    // The checks below (and compatibleRanges()) compute `begin() +
    // Dimensions`, which would point outside the three-entry arrays for any
    // value outside [0, 3].
    assert(Dimensions >= 0 && Dimensions <= 3 &&
           "Invalid number of dimensions");
#ifndef NDEBUG
    const auto CheckDim = [Dimensions](const Indices &Range) {
      return std::all_of(Range.begin() + Dimensions, Range.end(),
                         [](auto D) { return D == 1; });
    };
    const auto CheckOffsetDim = [Dimensions](const Indices &Offset) {
      return std::all_of(Offset.begin() + Dimensions, Offset.end(),
                         [](auto D) { return D == 0; });
    };
#endif // NDEBUG
    assert(CheckDim(GlobalSize) &&
           "Invalid global range for number of dimensions");
    assert(
        (CheckDim(LocalSize) || std::all_of(LocalSize.begin(), LocalSize.end(),
                                            [](auto D) { return D == 0; })) &&
        "Invalid local range for number of dimensions");
    assert(CheckOffsetDim(Offset) && "Invalid offset for number of dimensions");
  }

  constexpr const Indices &getGlobalSize() const { return GlobalSize; }
  constexpr const Indices &getLocalSize() const { return LocalSize; }
  constexpr const Indices &getOffset() const { return Offset; }
  constexpr int getDimensions() const { return Dimensions; }

  /// Two ND-ranges compare equal if they have the same number of dimensions,
  /// global size and offset, and their local sizes are either equal or at
  /// least one of them is unspecified (all zeros).
  friend constexpr bool operator==(const NDRange &LHS, const NDRange &RHS) {
    return LHS.Dimensions == RHS.Dimensions &&
           LHS.GlobalSize == RHS.GlobalSize &&
           (LHS.LocalSize == AllZeros || RHS.LocalSize == AllZeros ||
            LHS.LocalSize == RHS.LocalSize) &&
           LHS.Offset == RHS.Offset;
  }

  friend constexpr bool operator!=(const NDRange &LHS, const NDRange &RHS) {
    return !(LHS == RHS);
  }

private:
  /** @brief The number of dimensions. */
  int Dimensions;
  /** @brief The global range. */
  Indices GlobalSize;
  /** @brief The local range. */
  Indices LocalSize;
  /** @brief The offset. */
  Indices Offset;
};
207+
100208
/// Information about a kernel from DPC++.
101209
struct SYCLKernelInfo {
102210

@@ -106,18 +214,21 @@ struct SYCLKernelInfo {
106214

107215
AttributeList Attributes;
108216

217+
NDRange NDR;
218+
109219
SYCLKernelBinaryInfo BinaryInfo;
110220

111221
//// Explicit constructor for compatibility with LLVM YAML I/O.
112-
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, BinaryInfo{} {}
222+
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}
113223

114224
SYCLKernelInfo(const std::string &KernelName,
115-
const SYCLArgumentDescriptor &ArgDesc,
225+
const SYCLArgumentDescriptor &ArgDesc, const NDRange &NDR,
116226
const SYCLKernelBinaryInfo &BinInfo)
117-
: Name{KernelName}, Args{ArgDesc}, Attributes{}, BinaryInfo{BinInfo} {}
227+
: Name{KernelName}, Args{ArgDesc}, Attributes{}, NDR{NDR}, BinaryInfo{
228+
BinInfo} {}
118229

119230
explicit SYCLKernelInfo(const std::string &KernelName)
120-
: Name{KernelName}, Args{}, Attributes{}, BinaryInfo{} {}
231+
: Name{KernelName}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}
121232
};
122233

123234
///

sycl-fusion/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_llvm_library(sycl-fusion
55
lib/translation/SPIRVLLVMTranslation.cpp
66
lib/fusion/FusionPipeline.cpp
77
lib/fusion/FusionHelper.cpp
8+
lib/fusion/NDRangesHelper.cpp
89
lib/fusion/ModuleHelper.cpp
910
lib/helper/ConfigHelper.cpp
1011

sycl-fusion/jit-compiler/include/Hashing.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
1010
#define SYCL_FUSION_JIT_COMPILER_HASHING_H
1111

12+
#include "Kernel.h"
1213
#include "Parameter.h"
1314

1415
#include "llvm/ADT/Hashing.h"
@@ -32,13 +33,23 @@ inline llvm::hash_code hash_value(const JITConstant &C) {
3233
inline llvm::hash_code hash_value(const ParameterIdentity &IP) {
3334
return llvm::hash_combine(IP.LHS, IP.RHS);
3435
}
36+
37+
/// Hash an ND-range consistently with NDRange's operator==.
///
/// operator== treats an unspecified (all-zero) local size as equal to any
/// other local size, so two ND-ranges that compare equal may carry different
/// local sizes. The local size is therefore left out of the hash: objects
/// that compare equal must produce the same hash code when used as (part of)
/// a key in a hashed container.
inline llvm::hash_code hash_value(const NDRange &ND) {
  return llvm::hash_combine(ND.getDimensions(), ND.getGlobalSize(),
                            ND.getOffset());
}
3541
} // namespace jit_compiler
3642

3743
namespace std {
3844
/// Hash a vector by hashing the range of its elements, from first to last.
template <typename T> inline llvm::hash_code hash_value(const vector<T> &V) {
  const auto First = V.begin();
  const auto Last = V.end();
  return llvm::hash_combine_range(First, Last);
}
4147

48+
template <typename T, std::size_t N>
49+
inline llvm::hash_code hash_value(const array<T, N> &A) {
50+
return llvm::hash_combine_range(A.begin(), A.end());
51+
}
52+
4253
template <typename... T> struct hash<tuple<T...>> {
4354
size_t operator()(const tuple<T...> &Tuple) const noexcept {
4455
return llvm::hash_value(Tuple);

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ namespace jit_compiler {
2929

3030
using CacheKeyT =
3131
std::tuple<std::vector<std::string>, ParamIdentList, int,
32-
std::vector<ParameterInternalization>, std::vector<JITConstant>>;
32+
std::vector<ParameterInternalization>, std::vector<JITConstant>,
33+
std::optional<std::vector<NDRange>>>;
3334

3435
///
3536
/// Wrapper around a SPIR-V binary.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//==--------- NDRangesHelper.h - Helpers to handle ND-ranges ---------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_FUSION_JIT_COMPILER_NDRANGESHELPER_H
10+
#define SYCL_FUSION_JIT_COMPILER_NDRANGESHELPER_H
11+
12+
#include "Kernel.h"
13+
14+
#include <algorithm>
15+
16+
#include "llvm/ADT/ArrayRef.h"
17+
18+
namespace jit_compiler {
19+
///
20+
/// Combine a list of ND-ranges, obtaining the resulting "fused" ND-range.
21+
NDRange combineNDRanges(llvm::ArrayRef<NDRange> NDRanges);
22+
23+
///
24+
/// Returns whether the input list of ND-ranges is heterogeneous or not.
25+
inline bool isHeterogeneousList(llvm::ArrayRef<NDRange> NDRanges) {
26+
const auto *Begin = NDRanges.begin();
27+
const auto *End = NDRanges.end();
28+
const auto *FirstSpecLocal =
29+
std::find_if(Begin, End, [&AllZeros = NDRange::AllZeros](const auto &ND) {
30+
return ND.getLocalSize() != AllZeros;
31+
});
32+
return std::any_of(Begin, End,
33+
[&ND = FirstSpecLocal == End ? *Begin : *FirstSpecLocal](
34+
const auto &Other) { return ND != Other; });
35+
}
36+
} // namespace jit_compiler
37+
38+
#endif // SYCL_FUSION_JIT_COMPILER_NDRANGESHELPER_H

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ static FusionResult errorToFusionResult(llvm::Error &&Err,
3636
return FusionResult{ErrMsg.str()};
3737
}
3838

39+
static std::vector<jit_compiler::NDRange>
40+
gatherNDRanges(llvm::ArrayRef<SYCLKernelInfo> KernelInformation) {
41+
std::vector<jit_compiler::NDRange> NDRanges;
42+
NDRanges.reserve(KernelInformation.size());
43+
std::transform(KernelInformation.begin(), KernelInformation.end(),
44+
std::back_inserter(NDRanges),
45+
[](const auto &I) { return I.NDR; });
46+
return NDRanges;
47+
}
48+
3949
FusionResult KernelFusion::fuseKernels(
4050
JITContext &JITCtx, Config &&JITConfig,
4151
const std::vector<SYCLKernelInfo> &KernelInformation,
@@ -44,14 +54,21 @@ FusionResult KernelFusion::fuseKernels(
4454
int BarriersFlags,
4555
const std::vector<jit_compiler::ParameterInternalization> &Internalization,
4656
const std::vector<jit_compiler::JITConstant> &Constants) {
57+
const auto NDRanges = gatherNDRanges(KernelInformation);
4758

4859
// Initialize the configuration helper to make the options for this invocation
4960
// available (on a per-thread basis).
5061
ConfigHelper::setConfig(std::move(JITConfig));
5162

5263
bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
53-
CacheKeyT CacheKey{KernelsToFuse, Identities, BarriersFlags, Internalization,
54-
Constants};
64+
CacheKeyT CacheKey{KernelsToFuse,
65+
Identities,
66+
BarriersFlags,
67+
Internalization,
68+
Constants,
69+
jit_compiler::isHeterogeneousList(NDRanges)
70+
? std::optional<std::vector<NDRange>>{NDRanges}
71+
: std::optional<std::vector<NDRange>>{}};
5572
if (CachingEnabled) {
5673
std::optional<SYCLKernelInfo> CachedKernel = JITCtx.getCacheEntry(CacheKey);
5774
if (CachedKernel) {
@@ -82,8 +99,9 @@ FusionResult KernelFusion::fuseKernels(
8299

83100
// Add information about the kernel that should be fused as metadata into the
84101
// LLVM module.
85-
FusedFunction FusedKernel{FusedKernelName, KernelsToFuse,
86-
std::move(Identities), Internalization, Constants};
102+
FusedFunction FusedKernel{
103+
FusedKernelName, KernelsToFuse, std::move(Identities),
104+
Internalization, Constants, NDRanges};
87105
FusedFunctionList FusedKernelList;
88106
FusedKernelList.push_back(FusedKernel);
89107
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
@@ -116,6 +134,8 @@ FusionResult KernelFusion::fuseKernels(
116134
}
117135
jit_compiler::SPIRVBinary *SPIRVBin = *BinaryOrError;
118136

137+
FusedKernelInfo.NDR = FusedKernel.FusedNDRange;
138+
119139
// Update the KernelInfo for the fused kernel with the address and size of the
120140
// SPIR-V binary resulting from translation.
121141
SYCLKernelBinaryInfo &FusedBinaryInfo = FusedKernelInfo.BinaryInfo;

sycl-fusion/jit-compiler/lib/fusion/FusionHelper.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "ModuleHelper.h"
1212
#include "helper/ErrorHandling.h"
13+
#include "kernel-fusion/SYCLKernelFusion.h"
1314
#include "llvm/IR/Constants.h"
1415

1516
using namespace llvm;
@@ -81,6 +82,48 @@ Expected<std::unique_ptr<Module>> helper::FusionHelper::addFusedKernel(
8182
// The metadata can be identified by this fixed string providing a kind.
8283
F->setMetadata(MetadataKind, MDList);
8384

85+
// Attach ND-ranges related information. User of this API must pass the
86+
// following information for each kernel, as well as for the fused kernel:
87+
// 1. Number of dimensions;
88+
// 2. Global size;
89+
// 3. Local size;
90+
// 4. Offset
91+
{
92+
const auto MDFromND = [&LLVMCtx](const auto &ND) {
93+
auto MDFromIndices = [&LLVMCtx](const auto &Ind) -> Metadata * {
94+
std::array<Metadata *, Ind.size()> MD{nullptr};
95+
std::transform(
96+
Ind.begin(), Ind.end(), MD.begin(),
97+
[&LLVMCtx](auto I) { return getConstantIntMD(LLVMCtx, I); });
98+
return MDNode::get(LLVMCtx, MD);
99+
};
100+
std::array<Metadata *, 4> MD;
101+
MD[0] = getConstantIntMD(LLVMCtx, ND.getDimensions());
102+
MD[1] = MDFromIndices(ND.getGlobalSize());
103+
MD[2] = MDFromIndices(ND.getLocalSize());
104+
MD[3] = MDFromIndices(ND.getOffset());
105+
return MDNode::get(LLVMCtx, MD);
106+
};
107+
108+
// Attach ND-range of the fused kernel
109+
{
110+
assert(!F->hasMetadata(SYCLKernelFusion::NDRangeMDKey));
111+
F->setMetadata(SYCLKernelFusion::NDRangeMDKey,
112+
MDFromND(FF.FusedNDRange));
113+
}
114+
115+
// Attach ND-ranges of each kernel to be fused
116+
{
117+
const auto SrcNDRanges = FF.NDRanges;
118+
SmallVector<Metadata *> Nodes;
119+
std::transform(SrcNDRanges.begin(), SrcNDRanges.end(),
120+
std::back_inserter(Nodes), MDFromND);
121+
assert(!F->hasMetadata(SYCLKernelFusion::NDRangesMDKey));
122+
F->setMetadata(SYCLKernelFusion::NDRangesMDKey,
123+
MDNode::get(LLVMCtx, Nodes));
124+
}
125+
}
126+
84127
// The user of this API may be able to determine that
85128
// the same value is used for multiple input functions in the fused kernel,
86129
// e.g. when using the output of one kernel as the input to another kernel.

sycl-fusion/jit-compiler/lib/fusion/FusionHelper.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef SYCL_FUSION_JIT_COMPILER_FUSION_FUSIONHELPER_H
1010
#define SYCL_FUSION_JIT_COMPILER_FUSION_FUSIONHELPER_H
1111

12+
#include "Kernel.h"
13+
#include "NDRangesHelper.h"
1214
#include "Parameter.h"
1315
#include "llvm/IR/Module.h"
1416
#include <llvm/ADT/ArrayRef.h>
@@ -25,12 +27,27 @@ class FusionHelper {
2527
/// Representation of a fused kernel named FusedName and fusing
2628
/// all the kernels listed in FusedKernels.
2729
struct FusedFunction {
30+
FusedFunction(const std::string &FusedName,
31+
const std::vector<std::string> &FusedKernels,
32+
const jit_compiler::ParamIdentList &ParameterIdentities,
33+
llvm::ArrayRef<jit_compiler::ParameterInternalization>
34+
ParameterInternalization,
35+
llvm::ArrayRef<jit_compiler::JITConstant> Constants,
36+
llvm::ArrayRef<jit_compiler::NDRange> NDRanges)
37+
: FusedName{FusedName}, FusedKernels{FusedKernels},
38+
ParameterIdentities{ParameterIdentities},
39+
ParameterInternalization{ParameterInternalization},
40+
NDRanges{NDRanges}, FusedNDRange{
41+
jit_compiler::combineNDRanges(NDRanges)} {}
42+
2843
std::string FusedName;
2944
std::vector<std::string> FusedKernels;
3045
jit_compiler::ParamIdentList ParameterIdentities;
3146
llvm::ArrayRef<jit_compiler::ParameterInternalization>
3247
ParameterInternalization;
3348
llvm::ArrayRef<jit_compiler::JITConstant> Constants;
49+
llvm::ArrayRef<jit_compiler::NDRange> NDRanges;
50+
jit_compiler::NDRange FusedNDRange;
3451
};
3552

3653
///

0 commit comments

Comments
 (0)