Skip to content

[SYCL][ESIMD] Fix compilation break occurring when bfloat16 constructor is used in a kernel #8892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 4, 2023
Merged
83 changes: 78 additions & 5 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,10 @@ class ESIMDIntrinDescTable {
{"slm_init", {"slm.init", {a(0)}}},
{"bf_cvt", {"bf.cvt", {a(0)}}},
{"tf32_cvt", {"tf32.cvt", {a(0)}}},
{"__devicelib_ConvertFToBF16INTEL",
{"__spirv_ConvertFToBF16INTEL", {a(0)}}},
{"__devicelib_ConvertBF16ToFINTEL",
{"__spirv_ConvertBF16ToFINTEL", {a(0)}}},
{"addc", {"addc", {l(0)}}},
{"subb", {"subb", {l(0)}}},
{"bfn", {"bfn", {a(0), a(1), a(2), t(0)}}}};
Expand Down Expand Up @@ -703,6 +707,28 @@ static const ESIMDIntrinDesc &getIntrinDesc(StringRef SrcSpelling) {
return It->second;
}

static bool isDevicelibFunction(StringRef FunctionName) {
return llvm::StringSwitch<bool>(FunctionName)
.Case("__devicelib_ConvertFToBF16INTEL", true)
.Case("__devicelib_ConvertBF16ToFINTEL", true)
.Default(false);
}

// Mangle deviceLib function to make it pass through the regular workflow
// These functions are defined as extern "C" which Demangler that is used
// fails to handle properly.
static std::string mangleDevicelibFunction(StringRef FunctionName) {
if (isDevicelibFunction(FunctionName)) {
if (FunctionName.startswith("__devicelib_ConvertFToBF16INTEL")) {
return (Twine("_Z31") + FunctionName + "RKf").str();
}
if (FunctionName.startswith("__devicelib_ConvertBF16ToFINTEL")) {
return (Twine("_Z31") + FunctionName + "RKt").str();
}
}
return FunctionName.str();
}

Type *parsePrimitiveTypeString(StringRef TyStr, LLVMContext &Ctx) {
return llvm::StringSwitch<Type *>(TyStr)
.Case("bool", IntegerType::getInt1Ty(Ctx))
Expand Down Expand Up @@ -1326,6 +1352,46 @@ static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
}
}

// Create a spirv function declaration
// This is used for lowering devicelib functions.
// The function
// 1. Generates spirv function definition
// 2. Converts passed by reference argument of devicelib function into passed by
// value argument of spirv functions
// 3. Assigns proper attributes to generated function
static Function *
createDeviceLibESIMDDeclaration(const ESIMDIntrinDesc &Desc,
SmallVector<Value *, 16> &GenXArgs,
CallInst &CI) {
SmallVector<Type *, 16> ArgTypes;
IRBuilder<> Bld(&CI);
for (unsigned i = 0; i < GenXArgs.size(); ++i) {
Type *NTy = llvm::StringSwitch<Type *>(Desc.GenXSpelling)
.Case("__spirv_ConvertFToBF16INTEL",
Type::getFloatTy(CI.getContext()))
.Case("__spirv_ConvertBF16ToFINTEL",
Type::getInt16Ty(CI.getContext()))
.Default(nullptr);

auto LI = Bld.CreateLoad(NTy, GenXArgs[i]);
GenXArgs[i] = LI;
ArgTypes.push_back(NTy);
}
auto *FType = FunctionType::get(CI.getType(), ArgTypes, false);
Function *F = CI.getModule()->getFunction(Desc.GenXSpelling);
if (!F) {
F = Function::Create(FType, GlobalVariable::ExternalLinkage,
Desc.GenXSpelling, CI.getModule());
F->addFnAttr(Attribute::NoUnwind);
F->addFnAttr(Attribute::Convergent);
F->setDSOLocal(true);

F->setCallingConv(CallingConv::SPIR_FUNC);
}

return F;
}

// Create a simple function declaration
// This is used for testing purposes, when it is impossible to query
// vc-intrinsics
Expand Down Expand Up @@ -1403,7 +1469,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
using Demangler = id::ManglingParser<SimpleAllocator>;
Function *F = CI.getCalledFunction();
llvm::esimd::assert_and_diag(F, "function to translate is invalid");
StringRef MnglName = F->getName();
std::string MnglNameStr = mangleDevicelibFunction(F->getName());
StringRef MnglName = MnglNameStr;

Demangler Parser(MnglName.begin(), MnglName.end());
id::Node *AST = Parser.parse();

Expand All @@ -1416,7 +1484,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
auto *FE = static_cast<id::FunctionEncoding *>(AST);
id::StringView BaseNameV = FE->getName()->getBaseName();

auto PrefLen = StringRef(ESIMD_INTRIN_PREF1).size();
auto PrefLen = isDevicelibFunction(F->getName())
? 0
: StringRef(ESIMD_INTRIN_PREF1).size();
StringRef BaseName(BaseNameV.begin() + PrefLen, BaseNameV.size() - PrefLen);
const auto &Desc = getIntrinDesc(BaseName);
if (!Desc.isValid()) // TODO remove this once all intrinsics are supported
Expand All @@ -1429,7 +1499,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
Function *NewFDecl = nullptr;
bool DoesFunctionReturnStructure =
isStructureReturningFunction(Desc.GenXSpelling);
if (Desc.GenXSpelling.rfind("test.src.", 0) == 0) {
if (isDevicelibFunction(F->getName())) {
NewFDecl = createDeviceLibESIMDDeclaration(Desc, GenXArgs, CI);
} else if (Desc.GenXSpelling.rfind("test.src.", 0) == 0) {
// Special case for testing purposes
NewFDecl = createTestESIMDDeclaration(Desc, GenXArgs, CI);
} else {
Expand Down Expand Up @@ -1724,7 +1796,7 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,

// See if the Name represents an ESIMD intrinsic and demangle only if it
// does.
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
if (!Name.consume_front(ESIMD_INTRIN_PREF0) && !isDevicelibFunction(Name))
continue;
// now skip the digits
Name = Name.drop_while([](char C) { return std::isdigit(C); });
Expand Down Expand Up @@ -1771,7 +1843,8 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
assert(!Name.startswith("__sycl_set_kernel_properties") &&
"__sycl_set_kernel_properties must have been lowered");

if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
if (Name.empty() ||
(!Name.startswith(ESIMD_INTRIN_PREF1) && !isDevicelibFunction(Name)))
continue;
// this is ESIMD intrinsic - record for later translation
ESIMDIntrCalls.push_back(CI);
Expand Down
61 changes: 61 additions & 0 deletions sycl/test-e2e/ESIMD/regression/bfloat16Constructor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// REQUIRES: gpu
// UNSUPPORTED: gpu-intel-gen9 || cuda || hip
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out
// XFAIL: gpu && !esimd_emulator
//==- bfloat16Constructor.cpp - Test to verify use of bfloat16 constructor -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// This is basic test to verify use of bfloat16 constructor in kernel.
// TODO: Enable the test once the GPU RT supporting the functionality reaches
// the CI

#include <CL/sycl.hpp>
#include <ext/intel/esimd.hpp>
#include <iostream>

using namespace sycl;

int main() {
constexpr unsigned Size = 32;
constexpr unsigned VL = 32;
constexpr unsigned GroupSize = 1;

queue q;
auto dev = q.get_device();
std::cout << "Running on " << dev.get_info<info::device::name>() << "\n";
auto *C = malloc_shared<float>(Size * sizeof(float), dev, q.get_context());

for (auto i = 0; i != Size; i++) {
C[i] = 7;
}

nd_range<1> Range(range<1>(Size / VL), range<1>(GroupSize));

auto e = q.submit([&](handler &cgh) {
cgh.parallel_for<class Test>(Range, [=](nd_item<1> i) SYCL_ESIMD_KERNEL {
using bf16 = sycl::ext::oneapi::bfloat16;
using namespace __ESIMD_NS;
using namespace __ESIMD_ENS;
simd<bf16, 32> data_bf16 = bf16(0);
simd<float, 32> data = data_bf16;
lsc_block_store<float, 32>(C, data);
});
});
e.wait();
bool Pass = true;
for (auto i = 0; i != Size; i++) {
if (C[i] != 0) {
Pass = false;
}
}

free(C, q);
std::cout << (Pass ? "Test Passed\n" : "Test FAILED\n");
return 0;
}
4 changes: 2 additions & 2 deletions sycl/test/esimd/fp16_converts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void bf16_scalar() {
// The actual support in GPU RT is on the way though.
float F32_scalar = 1;
bfloat16 BF16_scalar = F32_scalar;
// CHECK: call spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(float {{[^)]+}})
// CHECK: call i16 @__spirv_ConvertFToBF16INTEL(float {{[^)]+}})
float F32_scalar_conv = BF16_scalar;
// CHECK: call spir_func float @__devicelib_ConvertBF16ToFINTEL(i16 {{[^)]+}})
// CHECK: call float @__spirv_ConvertBF16ToFINTEL(i16 {{[^)]+}})
}