Skip to content

[SYCL][Fusion] Enable fusion of kernels with different ND-ranges #8209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl-fusion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(LLVM_SPIRV_INCLUDE_DIRS "${LLVM_MAIN_SRC_DIR}/../llvm-spirv/include")
if(WIN32)
message(WARNING "Kernel fusion not yet supported on Windows")
else(WIN32)
add_subdirectory(common)
add_subdirectory(jit-compiler)
add_subdirectory(passes)
add_subdirectory(test)
Expand Down
23 changes: 23 additions & 0 deletions sycl-fusion/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
add_llvm_library(sycl-fusion-common
lib/NDRangesHelper.cpp
)

target_include_directories(sycl-fusion-common
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/lib
)

if (BUILD_SHARED_LIBS)
if(NOT MSVC AND NOT APPLE)
# Manage symbol visibility through the linker to make sure no LLVM symbols
# are exported and confuse the drivers.
set(linker_script "${CMAKE_CURRENT_SOURCE_DIR}/ld-version-script.txt")
target_link_libraries(
sycl-fusion-common PRIVATE "-Wl,--version-script=${linker_script}")
set_target_properties(sycl-fusion-common
PROPERTIES
LINK_DEPENDS
${linker_script})
endif()
endif()
84 changes: 80 additions & 4 deletions sycl-fusion/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#define SYCL_FUSION_COMMON_KERNEL_H

#include <algorithm>
#include <array>
#include <cassert>
#include <string>
#include <vector>

Expand Down Expand Up @@ -97,6 +99,77 @@ struct SYCLArgumentDescriptor {
/// List of SYCL/OpenCL kernel attributes.
using AttributeList = std::vector<SYCLKernelAttribute>;

using Indices = std::array<size_t, 3>;

///
/// Class to model SYCL nd_range
class NDRange {
public:
constexpr static Indices AllZeros{0, 0, 0};

///
/// Return the product of each index in an indices array.
constexpr static size_t linearize(const jit_compiler::Indices &I) {
return I[0] * I[1] * I[2];
}

NDRange() : NDRange{1, {1, 1, 1}} {}

NDRange(int Dimensions, const Indices &GlobalSize,
const Indices &LocalSize = {1, 1, 1},
const Indices &Offset = {0, 0, 0})
: Dimensions{Dimensions},
GlobalSize{GlobalSize}, LocalSize{LocalSize}, Offset{Offset} {
#ifndef NDEBUG
const auto CheckDim = [Dimensions](const Indices &Range) {
return std::all_of(Range.begin() + Dimensions, Range.end(),
[](auto D) { return D == 1; });
};
const auto CheckOffsetDim = [Dimensions](const Indices &Offset) {
return std::all_of(Offset.begin() + Dimensions, Offset.end(),

[](auto D) { return D == 0; });
};
#endif // NDEBUG
assert(CheckDim(GlobalSize) &&
"Invalid global range for number of dimensions");
assert(
(CheckDim(LocalSize) || std::all_of(LocalSize.begin(), LocalSize.end(),
[](auto D) { return D == 0; })) &&
"Invalid local range for number of dimensions");
assert(CheckOffsetDim(Offset) && "Invalid offset for number of dimensions");
}

constexpr const Indices &getGlobalSize() const { return GlobalSize; }
constexpr const Indices &getLocalSize() const { return LocalSize; }
constexpr const Indices &getOffset() const { return Offset; }
constexpr int getDimensions() const { return Dimensions; }

bool hasSpecificLocalSize() const { return LocalSize != AllZeros; }

friend constexpr bool operator==(const NDRange &LHS, const NDRange &RHS) {
return LHS.Dimensions == RHS.Dimensions &&
LHS.GlobalSize == RHS.GlobalSize &&
(!LHS.hasSpecificLocalSize() || !RHS.hasSpecificLocalSize() ||
LHS.LocalSize == RHS.LocalSize) &&
LHS.Offset == RHS.Offset;
}

friend constexpr bool operator!=(const NDRange &LHS, const NDRange &RHS) {
return !(LHS == RHS);
}

private:
/** @brief The number of dimensions. */
int Dimensions;
/** @brief The local range. */
Indices GlobalSize;
/** @brief The local range. */
Indices LocalSize;
/** @brief The offet. */
Indices Offset;
};

/// Information about a kernel from DPC++.
struct SYCLKernelInfo {

Expand All @@ -106,18 +179,21 @@ struct SYCLKernelInfo {

AttributeList Attributes;

NDRange NDR;

SYCLKernelBinaryInfo BinaryInfo;

//// Explicit constructor for compatibility with LLVM YAML I/O.
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, BinaryInfo{} {}
SYCLKernelInfo() : Name{}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}

SYCLKernelInfo(const std::string &KernelName,
const SYCLArgumentDescriptor &ArgDesc,
const SYCLArgumentDescriptor &ArgDesc, const NDRange &NDR,
const SYCLKernelBinaryInfo &BinInfo)
: Name{KernelName}, Args{ArgDesc}, Attributes{}, BinaryInfo{BinInfo} {}
: Name{KernelName}, Args{ArgDesc}, Attributes{}, NDR{NDR}, BinaryInfo{
BinInfo} {}

explicit SYCLKernelInfo(const std::string &KernelName)
: Name{KernelName}, Args{}, Attributes{}, BinaryInfo{} {}
: Name{KernelName}, Args{}, Attributes{}, NDR{}, BinaryInfo{} {}
};

///
Expand Down
8 changes: 8 additions & 0 deletions sycl-fusion/common/ld-version-script.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
global:
/* Export everything from jit_compiler namespace */
_ZN12jit_compiler*;

local:
*;
};
99 changes: 99 additions & 0 deletions sycl-fusion/common/lib/NDRangesHelper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//==-------------------------- NDRangesHelper.cpp --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "NDRangesHelper.h"

#include <map>

using namespace jit_compiler;
using namespace llvm;

///
/// Return the maximal global size using the following order:
/// 1. A range is greater than another range if it contains more elements (see
/// linearize());
/// 2. Else if it appears more times in the input list of ranges;
/// 3. Else if it is greater in lexicographical order.
static Indices getMaximalGlobalSize(ArrayRef<NDRange> NDRanges) {
size_t NumElements{0};
std::map<Indices, size_t> FreqMap;
for (const auto &ND : NDRanges) {
const auto &GS = ND.getGlobalSize();
const auto N = NDRange::linearize(GS);
if (N < NumElements) {
continue;
}
if (N > NumElements) {
NumElements = N;
FreqMap.clear();
}
++FreqMap[GS];
}
return std::max_element(FreqMap.begin(), FreqMap.end(),
[](const auto &LHS, const auto &RHS) {
const auto LHSN = LHS.second;
const auto RHSN = RHS.second;
if (LHSN < RHSN) {
return true;
}
if (LHSN > RHSN) {
return false;
}
return LHS.first < RHS.first;
})
->first;
}

static bool compatibleRanges(const NDRange &LHS, const NDRange &RHS) {
const auto Dimensions = std::max(LHS.getDimensions(), RHS.getDimensions());
const auto EqualIndices = [Dimensions](const Indices &LHS,
const Indices &RHS) {
return std::equal(LHS.begin(), LHS.begin() + Dimensions, RHS.begin());
};
return (!LHS.hasSpecificLocalSize() || !RHS.hasSpecificLocalSize() ||
EqualIndices(LHS.getLocalSize(), RHS.getLocalSize())) &&
EqualIndices(LHS.getOffset(), RHS.getOffset());
}

NDRange jit_compiler::combineNDRanges(ArrayRef<NDRange> NDRanges) {
assert(isValidCombination(NDRanges) && "Invalid ND-ranges combination");
const auto Dimensions =
std::max_element(NDRanges.begin(), NDRanges.end(),
[](const auto &LHS, const auto &RHS) {
return LHS.getDimensions() < RHS.getDimensions();
})
->getDimensions();
const auto GlobalSize = getMaximalGlobalSize(NDRanges);
const auto *End = NDRanges.end();
const auto *LocalSizeIter = findSpecifiedLocalSize(NDRanges);
const auto &LocalSize =
LocalSizeIter == End ? NDRange::AllZeros : LocalSizeIter->getLocalSize();
const auto &Front = NDRanges.front();
return {Dimensions, GlobalSize, LocalSize, Front.getOffset()};
}

bool jit_compiler::isHeterogeneousList(ArrayRef<NDRange> NDRanges) {
const auto *FirstSpecifiedLocalSize = findSpecifiedLocalSize(NDRanges);
const auto &ND = FirstSpecifiedLocalSize == NDRanges.end()
? NDRanges.front()
: *FirstSpecifiedLocalSize;
return any_of(NDRanges, [&ND](const auto &Other) { return ND != Other; });
}

bool jit_compiler::isValidCombination(llvm::ArrayRef<NDRange> NDRanges) {
if (NDRanges.empty()) {
return false;
}
const auto *FirstSpecifiedLocalSize = findSpecifiedLocalSize(NDRanges);
const auto &ND = FirstSpecifiedLocalSize == NDRanges.end()
? NDRanges.front()
: *FirstSpecifiedLocalSize;
return llvm::all_of(NDRanges, [&ND](const auto &Other) {
return compatibleRanges(ND, Other);
});
}
37 changes: 37 additions & 0 deletions sycl-fusion/common/lib/NDRangesHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//==--------- NDRangesHelper.h - Helpers to handle ND-ranges ---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_FUSION_COMMON_NDRANGESHELPER_H
#define SYCL_FUSION_COMMON_NDRANGESHELPER_H

#include "Kernel.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

namespace jit_compiler {
///
/// Combine a list of ND-ranges, obtaining the resulting "fused" ND-range.
NDRange combineNDRanges(llvm::ArrayRef<NDRange> NDRanges);

inline llvm::ArrayRef<NDRange>::const_iterator
findSpecifiedLocalSize(llvm::ArrayRef<NDRange> NDRanges) {
return llvm::find_if(
NDRanges, [](const auto &ND) { return ND.hasSpecificLocalSize(); });
}

///
/// Returns whether the input list of ND-ranges is heterogeneous or not.
bool isHeterogeneousList(llvm::ArrayRef<NDRange> NDRanges);

///
/// Return whether a combination of ND-ranges is valid for fusion.
bool isValidCombination(llvm::ArrayRef<NDRange> NDRanges);
} // namespace jit_compiler

#endif // SYCL_FUSION_COMMON_NDRANGESHELPER_H
2 changes: 2 additions & 0 deletions sycl-fusion/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_llvm_library(sycl-fusion
Core
Support
Analysis
IPO
TransformUtils
Passes
Linker
Expand All @@ -33,6 +34,7 @@ find_package(Threads REQUIRED)

target_link_libraries(sycl-fusion
PRIVATE
sycl-fusion-common
LLVMSPIRVLib
SYCLKernelFusionPasses
${CMAKE_THREAD_LIBS_INIT}
Expand Down
11 changes: 11 additions & 0 deletions sycl-fusion/jit-compiler/include/Hashing.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
#define SYCL_FUSION_JIT_COMPILER_HASHING_H

#include "Kernel.h"
#include "Parameter.h"

#include "llvm/ADT/Hashing.h"
Expand All @@ -32,13 +33,23 @@ inline llvm::hash_code hash_value(const JITConstant &C) {
inline llvm::hash_code hash_value(const ParameterIdentity &IP) {
return llvm::hash_combine(IP.LHS, IP.RHS);
}

inline llvm::hash_code hash_value(const NDRange &ND) {
return llvm::hash_combine(ND.getDimensions(), ND.getGlobalSize(),
ND.getLocalSize(), ND.getOffset());
}
} // namespace jit_compiler

namespace std {
template <typename T> inline llvm::hash_code hash_value(const vector<T> &V) {
return llvm::hash_combine_range(V.begin(), V.end());
}

template <typename T, std::size_t N>
inline llvm::hash_code hash_value(const array<T, N> &A) {
return llvm::hash_combine_range(A.begin(), A.end());
}

template <typename... T> struct hash<tuple<T...>> {
size_t operator()(const tuple<T...> &Tuple) const noexcept {
return llvm::hash_value(Tuple);
Expand Down
6 changes: 5 additions & 1 deletion sycl-fusion/jit-compiler/include/JITContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ namespace jit_compiler {

using CacheKeyT =
std::tuple<std::vector<std::string>, ParamIdentList, int,
std::vector<ParameterInternalization>, std::vector<JITConstant>>;
std::vector<ParameterInternalization>, std::vector<JITConstant>,
// This field of the cache is optional because, if all of the
// ranges are equal, we will perform no remapping, so that fused
// kernels can be reused with different lists of equal nd-ranges.
std::optional<std::vector<NDRange>>>;

///
/// Wrapper around a SPIR-V binary.
Expand Down
Loading