Skip to content

Commit ae8fa38

Browse files
authored
Add fast math flag translation for OpenCL std lib (#2762)
Such possibility was added in SPIR-V 1.6. This patch also introduces limited translation of nofpclass LLVM parameter attribute. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent c795db9 commit ae8fa38

File tree

4 files changed

+153
-4
lines changed

4 files changed

+153
-4
lines changed

lib/SPIRV/SPIRVBuiltinHelper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ Value *BuiltinCallMutator::doConversion() {
106106
NewCall->copyMetadata(*CI);
107107
NewCall->setAttributes(CallAttrs);
108108
NewCall->setTailCall(CI->isTailCall());
109+
if (isa<FPMathOperator>(CI))
110+
NewCall->setFastMathFlags(CI->getFastMathFlags());
111+
109112
if (CI->hasFnAttr("fpbuiltin-max-error")) {
110113
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
111114
NewCall->addFnAttr(Attr);

lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,8 +2507,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25072507
case OpExtInst: {
25082508
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
25092509
switch (ExtInst->getExtSetKind()) {
2510-
case SPIRVEIS_OpenCL:
2511-
return mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
2510+
case SPIRVEIS_OpenCL: {
2511+
auto *V = mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
2512+
applyFPFastMathModeDecorations(BV, static_cast<Instruction *>(V));
2513+
return V;
2514+
}
25122515
case SPIRVEIS_Debug:
25132516
case SPIRVEIS_OpenCL_DebugInfo_100:
25142517
case SPIRVEIS_NonSemantic_Shader_DebugInfo_100:

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,7 +3060,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
30603060
if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
30613061
Opcode == Instruction::FMul || Opcode == Instruction::FDiv ||
30623062
Opcode == Instruction::FRem ||
3063-
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp) &&
3063+
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
3064+
BV->isExtInst()) &&
30643065
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
30653066
FastMathFlags FMF = BVF->getFastMathFlags();
30663067
SPIRVWord M{0};
@@ -3087,8 +3088,52 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
30873088
}
30883089
}
30893090
}
3090-
if (M != 0)
3091+
// Handle nofpclass attribute. Nothing to do if fast math flag is already
3092+
// set.
3093+
if ((BV->isExtInst() &&
3094+
static_cast<SPIRVExtInst *>(BV)->getExtSetKind() ==
3095+
SPIRVEIS_OpenCL) &&
3096+
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6) &&
3097+
!(M & FPFastMathModeFastMask)) {
3098+
auto *F = cast<CallInst>(V)->getCalledFunction();
3099+
auto FAttrs = F->getAttributes();
3100+
AttributeSet RetAttrs = FAttrs.getRetAttrs();
3101+
if (RetAttrs.hasAttribute(Attribute::NoFPClass)) {
3102+
FPClassTest RetTest =
3103+
RetAttrs.getAttribute(Attribute::NoFPClass).getNoFPClass();
3104+
AttributeSet RetAttrs = FAttrs.getRetAttrs();
3105+
// Only Nan and Inf tests are representable in SPIR-V now.
3106+
bool ToAddNoNan = RetTest & fcNan;
3107+
bool ToAddNoInf = RetTest & fcInf;
3108+
if (ToAddNoNan || ToAddNoInf) {
3109+
const auto *FT = F->getFunctionType();
3110+
const size_t NumParams = FT->getNumParams();
3111+
for (size_t I = 0; I != NumParams; ++I) {
3112+
if (!FT->getParamType(I)->isFloatTy())
3113+
continue;
3114+
if (!F->hasParamAttribute(I, Attribute::NoFPClass)) {
3115+
ToAddNoNan = false;
3116+
ToAddNoInf = false;
3117+
break;
3118+
}
3119+
FPClassTest ArgTest =
3120+
FAttrs.getParamAttr(I, Attribute::NoFPClass).getNoFPClass();
3121+
ToAddNoNan = ToAddNoNan && static_cast<bool>(ArgTest & fcNan);
3122+
ToAddNoInf = ToAddNoInf && static_cast<bool>(ArgTest & fcInf);
3123+
}
3124+
}
3125+
if (ToAddNoNan)
3126+
M |= FPFastMathModeNotNaNMask;
3127+
if (ToAddNoInf)
3128+
M |= FPFastMathModeNotInfMask;
3129+
}
3130+
}
3131+
if (M != 0) {
30913132
BV->setFPFastMathMode(M);
3133+
if (Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
3134+
BV->isExtInst())
3135+
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_6);
3136+
}
30923137
}
30933138
}
30943139
if (Instruction *Inst = dyn_cast<Instruction>(V)) {
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -spirv-text %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV
3+
; RUN: llvm-spirv %t.bc -o %t.spv
4+
; RUN: spirv-val %t.spv
5+
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-OCL
6+
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-SPV
7+
8+
; RUN: llvm-spirv -spirv-text --spirv-max-version=1.5 %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV-NEG
9+
10+
; CHECK-SPIRV: Decorate [[#FPDec1:]] FPFastMathMode 3
11+
; CHECK-SPIRV: Decorate [[#FPDec2:]] FPFastMathMode 2
12+
; CHECK-SPIRV: Decorate [[#FPDec3:]] FPFastMathMode 3
13+
; CHECK-SPIRV: Decorate [[#FPDec4:]] FPFastMathMode 16
14+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec1]] [[#]] fmax [[#]] [[#]]
15+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec2]] [[#]] fmin [[#]] [[#]]
16+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec3]] [[#]] ldexp [[#]] [[#]]
17+
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec4]] [[#]] fmax [[#]] [[#]]
18+
19+
; CHECK-SPIRV-NEG-NOT: Decorate [[#]] FPFastMathMode [[#]]
20+
21+
; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])
22+
; CHECK-LLVM-OCL: call ninf spir_func float @_Z4fminff(float %[[#]], float %[[#]])
23+
; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z5ldexpfi(float %[[#]], i32 %[[#]])
24+
; CHECK-LLVM-OCL: call fast spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])
25+
26+
; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])
27+
; CHECK-LLVM-SPV: call ninf spir_func float @_Z16__spirv_ocl_fminff(float %[[#]], float %[[#]])
28+
; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z17__spirv_ocl_ldexpfi(float %[[#]], i32 %[[#]])
29+
; CHECK-LLVM-SPV: call fast spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])
30+
31+
; ModuleID = 'test.bc'
32+
source_filename = "test.cpp"
33+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
34+
target triple = "spir64-unknown-unknown"
35+
36+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
37+
38+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf), float noundef nofpclass(nan inf)) local_unnamed_addr
39+
40+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf), float noundef nofpclass(nan inf)) local_unnamed_addr
41+
42+
declare dso_local spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(nan inf), i32 noundef)
43+
44+
define weak_odr dso_local spir_kernel void @nofpclass_all(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
45+
entry:
46+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
47+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
48+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
49+
%cmp.i = icmp ult i64 %0, 2147483648
50+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
51+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
52+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
53+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
54+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
55+
ret void
56+
}
57+
58+
define weak_odr dso_local spir_kernel void @nofpclass_part(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
59+
entry:
60+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
61+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
62+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
63+
%cmp.i = icmp ult i64 %0, 2147483648
64+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
65+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
66+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
67+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf) %1, float noundef nofpclass(nan inf) %2)
68+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
69+
ret void
70+
}
71+
72+
define weak_odr dso_local spir_kernel void @nofpclass_int(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
73+
entry:
74+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
75+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
76+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
77+
%cmp.i = icmp ult i64 %0, 2147483648
78+
%arrayidx5.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_dat2, i64 %0
79+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
80+
%2 = load i32, ptr addrspace(1) %arrayidx5.i, align 4
81+
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(inf) %1, i32 noundef %2)
82+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
83+
ret void
84+
}
85+
86+
define weak_odr dso_local spir_kernel void @nofpclass_fast(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
87+
entry:
88+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
89+
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
90+
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
91+
%cmp.i = icmp ult i64 %0, 2147483648
92+
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
93+
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
94+
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
95+
%call.i.i = tail call fast spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
96+
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
97+
ret void
98+
}

0 commit comments

Comments
 (0)