Skip to content

Commit bd2ac52

Browse files
MrSidimsjsji
authored andcommitted
Add fast math flag translation for OpenCL std lib (#2762)
Such possibility was added in SPIR-V 1.6. This patch also introduces limited translation of nofpclass LLVM parameter attribute. Signed-off-by: Sidorov, Dmitry <[email protected]> Original commit: KhronosGroup/SPIRV-LLVM-Translator@ae8fa3825a699b2
1 parent 9b50b27 commit bd2ac52

File tree

4 files changed

+153
-4
lines changed

4 files changed

+153
-4
lines changed

llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ Value *BuiltinCallMutator::doConversion() {
106106
NewCall->copyMetadata(*CI);
107107
NewCall->setAttributes(CallAttrs);
108108
NewCall->setTailCall(CI->isTailCall());
109+
if (isa<FPMathOperator>(CI))
110+
NewCall->setFastMathFlags(CI->getFastMathFlags());
111+
109112
if (CI->hasFnAttr("fpbuiltin-max-error")) {
110113
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
111114
NewCall->addFnAttr(Attr);

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,8 +2510,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25102510
case OpExtInst: {
25112511
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
25122512
switch (ExtInst->getExtSetKind()) {
2513-
case SPIRVEIS_OpenCL:
2514-
return mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
2513+
case SPIRVEIS_OpenCL: {
2514+
auto *V = mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
2515+
applyFPFastMathModeDecorations(BV, static_cast<Instruction *>(V));
2516+
return V;
2517+
}
25152518
case SPIRVEIS_Debug:
25162519
case SPIRVEIS_OpenCL_DebugInfo_100:
25172520
case SPIRVEIS_NonSemantic_Shader_DebugInfo_100:

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,7 +3068,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
30683068
if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
30693069
Opcode == Instruction::FMul || Opcode == Instruction::FDiv ||
30703070
Opcode == Instruction::FRem ||
3071-
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp) &&
3071+
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
3072+
BV->isExtInst()) &&
30723073
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
30733074
FastMathFlags FMF = BVF->getFastMathFlags();
30743075
SPIRVWord M{0};
@@ -3095,8 +3096,52 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
30953096
}
30963097
}
30973098
}
3098-
if (M != 0)
3099+
// Handle nofpclass attribute. Nothing to do if fast math flag is already
3100+
// set.
3101+
if ((BV->isExtInst() &&
3102+
static_cast<SPIRVExtInst *>(BV)->getExtSetKind() ==
3103+
SPIRVEIS_OpenCL) &&
3104+
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6) &&
3105+
!(M & FPFastMathModeFastMask)) {
3106+
auto *F = cast<CallInst>(V)->getCalledFunction();
3107+
auto FAttrs = F->getAttributes();
3108+
AttributeSet RetAttrs = FAttrs.getRetAttrs();
3109+
if (RetAttrs.hasAttribute(Attribute::NoFPClass)) {
3110+
FPClassTest RetTest =
3111+
RetAttrs.getAttribute(Attribute::NoFPClass).getNoFPClass();
3112+
AttributeSet RetAttrs = FAttrs.getRetAttrs();
3113+
// Only Nan and Inf tests are representable in SPIR-V now.
3114+
bool ToAddNoNan = RetTest & fcNan;
3115+
bool ToAddNoInf = RetTest & fcInf;
3116+
if (ToAddNoNan || ToAddNoInf) {
3117+
const auto *FT = F->getFunctionType();
3118+
const size_t NumParams = FT->getNumParams();
3119+
for (size_t I = 0; I != NumParams; ++I) {
3120+
if (!FT->getParamType(I)->isFloatTy())
3121+
continue;
3122+
if (!F->hasParamAttribute(I, Attribute::NoFPClass)) {
3123+
ToAddNoNan = false;
3124+
ToAddNoInf = false;
3125+
break;
3126+
}
3127+
FPClassTest ArgTest =
3128+
FAttrs.getParamAttr(I, Attribute::NoFPClass).getNoFPClass();
3129+
ToAddNoNan = ToAddNoNan && static_cast<bool>(ArgTest & fcNan);
3130+
ToAddNoInf = ToAddNoInf && static_cast<bool>(ArgTest & fcInf);
3131+
}
3132+
}
3133+
if (ToAddNoNan)
3134+
M |= FPFastMathModeNotNaNMask;
3135+
if (ToAddNoInf)
3136+
M |= FPFastMathModeNotInfMask;
3137+
}
3138+
}
3139+
if (M != 0) {
30993140
BV->setFPFastMathMode(M);
3141+
if (Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
3142+
BV->isExtInst())
3143+
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_6);
3144+
}
31003145
}
31013146
}
31023147
if (Instruction *Inst = dyn_cast<Instruction>(V)) {
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -spirv-text %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV
3+
; RUN: llvm-spirv %t.bc -o %t.spv
4+
; RUN: spirv-val %t.spv
5+
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-OCL
6+
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-SPV
7+
8+
; RUN: llvm-spirv -spirv-text --spirv-max-version=1.5 %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV-NEG
9+
10+
; CHECK-SPIRV: Decorate [[#FPDec1:]] FPFastMathMode 3
11+
; CHECK-SPIRV: Decorate [[#FPDec2:]] FPFastMathMode 2
12+
; CHECK-SPIRV: Decorate [[#FPDec3:]] FPFastMathMode 3
13+
; CHECK-SPIRV: Decorate [[#FPDec4:]] FPFastMathMode 16
14+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec1]] [[#]] fmax [[#]] [[#]]
15+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec2]] [[#]] fmin [[#]] [[#]]
16+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec3]] [[#]] ldexp [[#]] [[#]]
17+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec4]] [[#]] fmax [[#]] [[#]]
18+
19+
; CHECK-SPIRV-NEG-NOT: Decorate [[#]] FPFastMathMode [[#]]
20+
21+
; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])
22+
; CHECK-LLVM-OCL: call ninf spir_func float @_Z4fminff(float %[[#]], float %[[#]])
23+
; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z5ldexpfi(float %[[#]], i32 %[[#]])
24+
; CHECK-LLVM-OCL: call fast spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])
25+
26+
; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])
27+
; CHECK-LLVM-SPV: call ninf spir_func float @_Z16__spirv_ocl_fminff(float %[[#]], float %[[#]])
28+
; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z17__spirv_ocl_ldexpfi(float %[[#]], i32 %[[#]])
29+
; CHECK-LLVM-SPV: call fast spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])
30+
31+
; ModuleID = 'test.bc'
32+
source_filename = "test.cpp"
33+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
34+
target triple = "spir64-unknown-unknown"
35+
36+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
37+
38+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf), float noundef nofpclass(nan inf)) local_unnamed_addr
39+
40+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf), float noundef nofpclass(nan inf)) local_unnamed_addr
41+
42+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(nan inf), i32 noundef)
43+
44+
define weak_odr dso_local spir_kernel void @nofpclass_all(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
45+
entry:
46+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
47+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
48+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
49+
%cmp.i = icmp ult i64 %0, 2147483648
50+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
51+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
52+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
53+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
54+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
55+
ret void
56+
}
57+
58+
define weak_odr dso_local spir_kernel void @nofpclass_part(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
59+
entry:
60+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
61+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
62+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
63+
%cmp.i = icmp ult i64 %0, 2147483648
64+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
65+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
66+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
67+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf) %1, float noundef nofpclass(nan inf) %2)
68+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
69+
ret void
70+
}
71+
72+
define weak_odr dso_local spir_kernel void @nofpclass_int(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
73+
entry:
74+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
75+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
76+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
77+
%cmp.i = icmp ult i64 %0, 2147483648
78+
%arrayidx5.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_dat2, i64 %0
79+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
80+
%2 = load i32, ptr addrspace(1) %arrayidx5.i, align 4
81+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(inf) %1, i32 noundef %2)
82+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
83+
ret void
84+
}
85+
86+
define weak_odr dso_local spir_kernel void @nofpclass_fast(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
87+
entry:
88+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
89+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
90+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
91+
%cmp.i = icmp ult i64 %0, 2147483648
92+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
93+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
94+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
95+
%call.i.i = tail call fast spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
96+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
97+
ret void
98+
}

0 commit comments

Comments
 (0)