Skip to content

[SYCL] Add extension and implement fp control kernel property #11591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ endif()
if (NOT TARGET LLVMGenXIntrinsics)
if (NOT DEFINED LLVMGenXIntrinsics_SOURCE_DIR)
set(LLVMGenXIntrinsics_GIT_REPO https://github.com/intel/vc-intrinsics.git)
# Author: Jinsong Ji <[email protected]>
# Date: Thu Aug 10 14:41:52 2023 +0000
# Guard removed typed pointer enum within version macro
set(LLVMGenXIntrinsics_GIT_TAG 17a53f4304463b8e7e639d57ef17479040a8a2ad)
# Author: Artur Gainullin <[email protected]>
# Date: Thu Nov 9 00:37:24 2023 +0000

# Replace old kernel with rewritten kernel in metadata only since LLVM 17
set(LLVMGenXIntrinsics_GIT_TAG a8403355ada112b72d1fc7db29fd04325eecee60)

message(STATUS "vc-intrinsics repo is missing. Will try to download it from ${LLVMGenXIntrinsics_GIT_REPO}")
include(FetchContent)
Expand Down
79 changes: 79 additions & 0 deletions llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,36 @@ const StringMap<Decor> SpirvDecorMap = {
};
#undef SYCL_COMPILE_TIME_PROPERTY

// Masks defined here must be in sync with the SYCL header with fp control
// kernel property.
enum FloatControl {
RTE = 1, // Round to nearest or even
RTP = 1 << 1, // Round towards +ve inf
RTN = 1 << 2, // Round towards -ve inf
RTZ = 1 << 3, // Round towards zero

DENORM_FTZ = 1 << 4, // Denorm mode flush to zero
DENORM_D_ALLOW = 1 << 5, // Denorm mode double allow
DENORM_F_ALLOW = 1 << 6, // Denorm mode float allow
DENORM_HF_ALLOW = 1 << 7 // Denorm mode half allow
};

enum FloatControlMask {
ROUND_MASK = (RTE | RTP | RTN | RTZ),
DENORM_MASK = (DENORM_D_ALLOW | DENORM_F_ALLOW | DENORM_HF_ALLOW)
};

// SPIRV execution modes for FP control.
// These opcodes are specified in SPIRV specification (SPV_KHR_float_controls
// and SPV_INTEL_float_controls2 extensions):
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.pdf
constexpr uint32_t SPIRV_ROUNDING_MODE_RTE = 4462; // RoundingModeRTE
constexpr uint32_t SPIRV_ROUNDING_MODE_RTZ = 4463; // RoundingModeRTZ
constexpr uint32_t SPIRV_ROUNDING_MODE_RTP_INTEL = 5620; // RoundingModeRTPINTEL
constexpr uint32_t SPIRV_ROUNDING_MODE_RTN_INTEL = 5621; // RoundingModeRTNINTEL
constexpr uint32_t SPIRV_DENORM_FLUSH_TO_ZERO = 4460; // DenormFlushToZero
constexpr uint32_t SPIRV_DENORM_PRESERVE = 4459; // DenormPreserve

/// Builds a metadata node for a SPIR-V decoration (decoration code is
/// \c uint32_t integers) with no value.
///
Expand Down Expand Up @@ -282,6 +312,55 @@ attributeToExecModeMetadata(const Attribute &Attr, Function &F) {
if (!AttrKindStr.startswith("sycl-"))
return std::nullopt;

auto AddFPControlMetadataForWidth = [&](int32_t SPIRVFPControl,
int32_t Width) {
auto NamedMD = M.getOrInsertNamedMetadata("spirv.ExecutionMode");
SmallVector<Metadata *, 4> ValueVec;
ValueVec.push_back(ConstantAsMetadata::get(&F));
ValueVec.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), SPIRVFPControl)));
ValueVec.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), Width)));
NamedMD->addOperand(MDNode::get(Ctx, ValueVec));
};

auto AddFPControlMetadata = [&](int32_t SPIRVFPControl) {
for (int32_t Width : {64, 32, 16}) {
AddFPControlMetadataForWidth(SPIRVFPControl, Width);
}
};

if (AttrKindStr == "sycl-floating-point-control") {
uint32_t FPControl = getAttributeAsInteger<uint32_t>(Attr);
auto IsFPModeSet = [FPControl](FloatControl Flag) -> bool {
return (FPControl & Flag) == Flag;
};

if (IsFPModeSet(RTE))
AddFPControlMetadata(SPIRV_ROUNDING_MODE_RTE);

if (IsFPModeSet(RTP))
AddFPControlMetadata(SPIRV_ROUNDING_MODE_RTP_INTEL);

if (IsFPModeSet(RTN))
AddFPControlMetadata(SPIRV_ROUNDING_MODE_RTN_INTEL);

if (IsFPModeSet(RTZ))
AddFPControlMetadata(SPIRV_ROUNDING_MODE_RTZ);

if (IsFPModeSet(DENORM_FTZ))
AddFPControlMetadata(SPIRV_DENORM_FLUSH_TO_ZERO);

if (IsFPModeSet(DENORM_HF_ALLOW))
AddFPControlMetadataForWidth(SPIRV_DENORM_PRESERVE, 16);

if (IsFPModeSet(DENORM_F_ALLOW))
AddFPControlMetadataForWidth(SPIRV_DENORM_PRESERVE, 32);

if (IsFPModeSet(DENORM_D_ALLOW))
AddFPControlMetadataForWidth(SPIRV_DENORM_PRESERVE, 64);
}

if (AttrKindStr == "sycl-work-group-size" ||
AttrKindStr == "sycl-work-group-size-hint") {
// Split values in the comma-separated list integers.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
; RUN: opt -passes=compile-time-properties %s -S | FileCheck %s


define spir_kernel void @"Kernel0"() #0 {
entry:
ret void
}

define spir_kernel void @"Kernel1"() #1 {
entry:
ret void
}

define spir_kernel void @"Kernel2"() #2 {
entry:
ret void
}

define spir_kernel void @"Kernel3"() #3 {
entry:
ret void
}

define spir_kernel void @"Kernel4"() #4 {
entry:
ret void
}

define spir_kernel void @"Kernel5"() #5 {
entry:
ret void
}

define spir_kernel void @"Kernel6"() #6 {
entry:
ret void
}

define spir_kernel void @"Kernel7"() #7 {
entry:
ret void
}

define spir_kernel void @"Kernel8"() #8 {
entry:
ret void
}

define spir_kernel void @"Kernel9"() #9 {
entry:
ret void
}

define spir_kernel void @"Kernel10"() #10 {
entry:
ret void
}

; SPIRV execution modes for FP control. | BitMask
; ROUNDING_MODE_RTE = 4462; | 00000001
; ROUNDING_MODE_RTP_INTEL = 5620; | 00000010
; ROUNDING_MODE_RTN_INTEL = 5621; | 00000100
; ROUNDING_MODE_RTZ = 4463; | 00001000
; DEMORM_FLUSH_TO_ZERO = 4460; | 00010000
; DENORM_PRESERVE (double) = 4459; | 00100000
; DENORM_PRESERVE (float) = 4459; | 01000000
; DENORM_PRESERVE (half) = 4459; | 10000000

; rte + ftz (Default)
; CHECK: !0 = !{ptr @Kernel0, i32 [[RTE:4462]], i32 64}
; CHECK: !1 = !{ptr @Kernel0, i32 [[RTE]], i32 32}
; CHECK: !2 = !{ptr @Kernel0, i32 [[RTE]], i32 16}
; CHECK: !3 = !{ptr @Kernel0, i32 [[FTZ:4460]], i32 64}
; CHECK: !4 = !{ptr @Kernel0, i32 [[FTZ]], i32 32}
; CHECK: !5 = !{ptr @Kernel0, i32 [[FTZ]], i32 16}
attributes #0 = { "sycl-floating-point-control"="17" }

; rtp + ftz
; CHECK: !6 = !{ptr @Kernel1, i32 [[RTP:5620]], i32 64}
; CHECK: !7 = !{ptr @Kernel1, i32 [[RTP]], i32 32}
; CHECK: !8 = !{ptr @Kernel1, i32 [[RTP]], i32 16}
; CHECK: !9 = !{ptr @Kernel1, i32 [[FTZ]], i32 64}
; CHECK: !10 = !{ptr @Kernel1, i32 [[FTZ]], i32 32}
; CHECK: !11 = !{ptr @Kernel1, i32 [[FTZ]], i32 16}
attributes #1 = { "sycl-floating-point-control"="18" }

; rtn + ftz
; CHECK: !12 = !{ptr @Kernel2, i32 [[RTN:5621]], i32 64}
; CHECK: !13 = !{ptr @Kernel2, i32 [[RTN]], i32 32}
; CHECK: !14 = !{ptr @Kernel2, i32 [[RTN]], i32 16}
; CHECK: !15 = !{ptr @Kernel2, i32 [[FTZ]], i32 64}
; CHECK: !16 = !{ptr @Kernel2, i32 [[FTZ]], i32 32}
; CHECK: !17 = !{ptr @Kernel2, i32 [[FTZ]], i32 16}
attributes #2 = { "sycl-floating-point-control"="20" }

; rtz + ftz
; CHECK: !18 = !{ptr @Kernel3, i32 [[RTZ:4463]], i32 64}
; CHECK: !19 = !{ptr @Kernel3, i32 [[RTZ]], i32 32}
; CHECK: !20 = !{ptr @Kernel3, i32 [[RTZ]], i32 16}
; CHECK: !21 = !{ptr @Kernel3, i32 [[FTZ]], i32 64}
; CHECK: !22 = !{ptr @Kernel3, i32 [[FTZ]], i32 32}
; CHECK: !23 = !{ptr @Kernel3, i32 [[FTZ]], i32 16}
attributes #3 = { "sycl-floating-point-control"="24" }

; rte + denorm_preserve(double)
; CHECK: !24 = !{ptr @Kernel4, i32 [[RTE]], i32 64}
; CHECK: !25 = !{ptr @Kernel4, i32 [[RTE]], i32 32}
; CHECK: !26 = !{ptr @Kernel4, i32 [[RTE]], i32 16}
; CHECK: !27 = !{ptr @Kernel4, i32 [[DENORM_PRESERVE:4459]], i32 64}
attributes #4 = { "sycl-floating-point-control"="33" }

; rte + denorm_preserve(float)
; CHECK: !28 = !{ptr @Kernel5, i32 [[RTE]], i32 64}
; CHECK: !29 = !{ptr @Kernel5, i32 [[RTE]], i32 32}
; CHECK: !30 = !{ptr @Kernel5, i32 [[RTE]], i32 16}
; CHECK: !31 = !{ptr @Kernel5, i32 [[DENORM_PRESERVE]], i32 32}
attributes #5 = { "sycl-floating-point-control"="65" }

; rte + denorm_preserve(half)
; CHECK: !32 = !{ptr @Kernel6, i32 [[RTE]], i32 64}
; CHECK: !33 = !{ptr @Kernel6, i32 [[RTE]], i32 32}
; CHECK: !34 = !{ptr @Kernel6, i32 [[RTE]], i32 16}
; CHECK: !35 = !{ptr @Kernel6, i32 [[DENORM_PRESERVE]], i32 16}
attributes #6 = { "sycl-floating-point-control"="129" }

; rte + denorm_allow
; CHECK: !36 = !{ptr @Kernel7, i32 [[RTE]], i32 64}
; CHECK: !37 = !{ptr @Kernel7, i32 [[RTE]], i32 32}
; CHECK: !38 = !{ptr @Kernel7, i32 [[RTE]], i32 16}
; CHECK: !39 = !{ptr @Kernel7, i32 [[DENORM_PRESERVE]], i32 16}
; CHECK: !40 = !{ptr @Kernel7, i32 [[DENORM_PRESERVE]], i32 32}
; CHECK: !41 = !{ptr @Kernel7, i32 [[DENORM_PRESERVE]], i32 64}
attributes #7 = { "sycl-floating-point-control"="225" }

; rtz + denorm_preserve(double)
; CHECK: !42 = !{ptr @Kernel8, i32 [[RTZ]], i32 64}
; CHECK: !43 = !{ptr @Kernel8, i32 [[RTZ]], i32 32}
; CHECK: !44 = !{ptr @Kernel8, i32 [[RTZ]], i32 16}
; CHECK: !45 = !{ptr @Kernel8, i32 [[DENORM_PRESERVE]], i32 64}
attributes #8 = { "sycl-floating-point-control"="40" }

; rtp + denorm_preserve(float)
; CHECK: !46 = !{ptr @Kernel9, i32 [[RTP]], i32 64}
; CHECK: !47 = !{ptr @Kernel9, i32 [[RTP]], i32 32}
; CHECK: !48 = !{ptr @Kernel9, i32 [[RTP]], i32 16}
; CHECK: !49 = !{ptr @Kernel9, i32 [[DENORM_PRESERVE]], i32 32}
attributes #9 = { "sycl-floating-point-control"="66" }

; rtz + denorm_allow
; CHECK: !50 = !{ptr @Kernel10, i32 [[RTZ]], i32 64}
; CHECK: !51 = !{ptr @Kernel10, i32 [[RTZ]], i32 32}
; CHECK: !52 = !{ptr @Kernel10, i32 [[RTZ]], i32 16}
; CHECK: !53 = !{ptr @Kernel10, i32 [[DENORM_PRESERVE]], i32 16}
; CHECK: !54 = !{ptr @Kernel10, i32 [[DENORM_PRESERVE]], i32 32}
; CHECK: !55 = !{ptr @Kernel10, i32 [[DENORM_PRESERVE]], i32 64}
attributes #10 = { "sycl-floating-point-control"="232" }
55 changes: 55 additions & 0 deletions llvm/tools/sycl-post-link/ModuleSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,52 @@ void collectFunctionsAndGlobalVariablesToExtract(
}
}

// Check "spirv.ExecutionMode" named metadata in the module and remove nodes
// that reference kernels that have dead prototypes or don't reference any
// kernel at all (nullptr). Dead prototypes are removed as well.
void processSubModuleNamedMetadata(Module *M) {
auto ExecutionModeMD = M->getNamedMetadata("spirv.ExecutionMode");
if (!ExecutionModeMD)
return;

bool ContainsNodesToRemove = false;
std::vector<MDNode *> ValueVec;
for (auto Op : ExecutionModeMD->operands()) {
assert(Op->getNumOperands() > 0);
if (!Op->getOperand(0)) {
ContainsNodesToRemove = true;
continue;
}

// If the first operand is not nullptr then it has to be a kernel
// function.
Value *Val = cast<ValueAsMetadata>(Op->getOperand(0))->getValue();
Function *F = cast<Function>(Val);
// If kernel function is just a prototype and unused then we can remove it
// and later remove corresponding spirv.ExecutionMode metadata node.
if (F->isDeclaration() && F->use_empty()) {
F->eraseFromParent();
ContainsNodesToRemove = true;
continue;
}

// Rememver nodes which we need to keep in the module.
ValueVec.push_back(Op);
}
if (!ContainsNodesToRemove)
return;

if (ValueVec.empty()) {
// If all nodes need to be removed then just remove named metadata
// completely.
ExecutionModeMD->eraseFromParent();
} else {
ExecutionModeMD->clearOperands();
for (auto MD : ValueVec)
ExecutionModeMD->addOperand(MD);
}
}

ModuleDesc extractSubModule(const ModuleDesc &MD,
const SetVector<const GlobalValue *> GVs,
EntryPointGroup &&ModuleEntryPoints) {
Expand Down Expand Up @@ -577,6 +623,15 @@ void ModuleDesc::cleanup() {
MPM.addPass(StripDeadDebugInfoPass()); // Remove dead debug info.
MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls.
MPM.run(*M, MAM);

// Original module may have named metadata (spirv.ExecutionMode) referencing
// kernels in the module. Some of the Metadata nodes may reference kernels
// which are not included into the extracted submodule, in such case
// CloneModule either leaves that metadata nodes as is but they will reference
// dead prototype of the kernel or operand will be replace with nullptr. So
// process all nodes in the named metadata and remove nodes which are
// referencing kernels which are not included into submodule.
processSubModuleNamedMetadata(M.get());
}

bool ModuleDesc::isSpecConstantDefault() const {
Expand Down
Loading