Skip to content

Commit 8bf8022

Browse files
[OpaquePointers] Demangle _Float* correctly for determining argument types. (#1849)
Itanium name mangling gives different type names for the _Float16, _Float32, etc. types than the more common half/float/double names. This adds support for people who use these types in SPIR-V kernels.
1 parent 32721e8 commit 8bf8022

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,37 @@ static StringRef stringify(const itanium_demangle::NameType *Node) {
653653
return StringRef(Str.begin(), Str.size());
654654
}
655655

656+
/// Convert a mangled name that represents a basic integer, floating-point,
657+
/// etc. type into the corresponding LLVM type.
658+
static Type *getPrimitiveType(LLVMContext &Ctx,
659+
const llvm::itanium_demangle::Node *N) {
660+
using namespace llvm::itanium_demangle;
661+
if (auto *Name = dyn_cast<NameType>(N)) {
662+
return parsePrimitiveType(Ctx, stringify(Name));
663+
}
664+
if (auto *BitInt = dyn_cast<BitIntType>(N)) {
665+
unsigned BitWidth = 0;
666+
BitInt->match([&](const Node *NodeSize, bool) {
667+
const StringRef SizeStr(stringify(cast<NameType>(NodeSize)));
668+
SizeStr.getAsInteger(10, BitWidth);
669+
});
670+
return Type::getIntNTy(Ctx, BitWidth);
671+
}
672+
if (auto *FP = dyn_cast<BinaryFPType>(N)) {
673+
StringRef SizeStr;
674+
FP->match([&](const Node *NodeDimension) {
675+
SizeStr = stringify(cast<NameType>(NodeDimension));
676+
});
677+
return StringSwitch<Type *>(SizeStr)
678+
.Case("16", Type::getHalfTy(Ctx))
679+
.Case("32", Type::getFloatTy(Ctx))
680+
.Case("64", Type::getDoubleTy(Ctx))
681+
.Case("128", Type::getFP128Ty(Ctx))
682+
.Default(nullptr);
683+
}
684+
return nullptr;
685+
}
686+
656687
template <typename FnType>
657688
static TypedPointerType *
658689
parseNode(Module *M, const llvm::itanium_demangle::Node *ParamType,
@@ -724,21 +755,15 @@ parseNode(Module *M, const llvm::itanium_demangle::Node *ParamType,
724755
} else {
725756
PointeeTy = parsePrimitiveType(M->getContext(), MangledStructName);
726757
}
727-
} else if (auto *BitInt = dyn_cast<BitIntType>(Pointee)) {
728-
unsigned BitWidth = 0;
729-
BitInt->match([&](const Node *NodeSize, bool) {
730-
const StringRef SizeStr(stringify(cast<NameType>(NodeSize)));
731-
SizeStr.getAsInteger(10, BitWidth);
732-
});
733-
PointeeTy = Type::getIntNTy(M->getContext(), BitWidth);
758+
} else if (auto *Ty = getPrimitiveType(M->getContext(), Pointee)) {
759+
PointeeTy = Ty;
734760
} else if (auto *Vec = dyn_cast<itanium_demangle::VectorType>(Pointee)) {
735761
unsigned ElemCount = 0;
736762
const StringRef ElemCountStr(
737763
stringify(cast<NameType>(Vec->getDimension())));
738764
ElemCountStr.getAsInteger(10, ElemCount);
739-
if (auto *Name = dyn_cast<NameType>(Vec->getBaseType())) {
740-
PointeeTy = parsePrimitiveType(M->getContext(), stringify(Name));
741-
PointeeTy = llvm::VectorType::get(PointeeTy, ElemCount, false);
765+
if (auto *Ty = getPrimitiveType(M->getContext(), Vec->getBaseType())) {
766+
PointeeTy = llvm::VectorType::get(Ty, ElemCount, false);
742767
}
743768
} else if (llvm::isa<itanium_demangle::PointerType>(Pointee)) {
744769
PointeeTy = parseNode(M, Pointee, GetStructType);

test/transcoding/float16.ll

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc -emit-opaque-pointers
6+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
7+
8+
source_filename = "math_builtin_float_half.cpp"
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spirv64-unknown-unknown"
11+
12+
; CHECK-SPIRV: 3 TypeFloat [[HALF:[0-9]+]] 16
13+
; CHECK-SPIRV: 4 TypePointer [[HALFPTR:[0-9]+]] 7 [[HALF]]
14+
; CHECK-SPIRV: 4 TypeVector [[HALFV2:[0-9]+]] [[HALF]] 2
15+
; CHECK-SPIRV: 4 TypePointer [[HALFV2PTR:[0-9]+]] 7 [[HALFV2]]
16+
; CHECK-SPIRV: 4 Constant [[HALF]] [[CONST:[0-9]+]] 14788
17+
; CHECK-SPIRV: 4 Variable [[HALFPTR]] [[ADDR:[0-9]+]] 7
18+
; CHECK-SPIRV: 4 Variable [[HALFV2PTR]] [[ADDR2:[0-9]+]] 7
19+
; CHECK-SPIRV: 7 ExtInst [[HALF]] [[#]] 1 fract [[CONST]] [[ADDR]]
20+
; CHECK-SPIRV: 7 ExtInst [[HALFV2]] [[#]] 1 fract [[#]] [[ADDR2]]
21+
22+
; CHECK-LLVM: %addr = alloca half
23+
; CHECK-LLVM: %addr2 = alloca <2 x half>
24+
; CHECK-LLVM: %res = call spir_func half @_Z5fractDhPDh(half 0xH39C4, ptr %addr)
25+
; CHECK-LLVM: %res2 = call spir_func <2 x half> @_Z5fractDv2_DhPS_(<2 x half> <half 0xH39C4, half 0xH0000>, ptr %addr2)
26+
27+
define spir_kernel void @test() {
28+
entry:
29+
%addr = alloca half
30+
%addr2 = alloca <2 x half>
31+
%res = call spir_func noundef half @_Z17__spirv_ocl_fractDF16_PU3AS0DF16_(half noundef 0xH39C4, ptr noundef %addr)
32+
%res2 = call spir_func noundef <2 x half> @_Z17__spirv_ocl_fractDv2_DF16_PU3AS0S_(<2 x half> noundef <half 0xH39C4, half 0xH0000>, ptr noundef %addr2)
33+
ret void
34+
}
35+
36+
declare spir_func noundef half @_Z17__spirv_ocl_fractDF16_PU3AS0DF16_(half noundef, ptr noundef) local_unnamed_addr
37+
38+
declare spir_func noundef <2 x half> @_Z17__spirv_ocl_fractDv2_DF16_PU3AS0S_(<2 x half> noundef, ptr noundef) local_unnamed_addr

0 commit comments

Comments
 (0)