Skip to content

Commit 5e1bd3b

Browse files
vmustyaigcbot
authored andcommitted
Lower math intrinsics for bfloat in VC
.
1 parent 9dec1d0 commit 5e1bd3b

File tree

2 files changed

+270
-1
lines changed

2 files changed

+270
-1
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXBFloatLowering.cpp

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ class GenXBFloatLowering : public FunctionPass,
7272
void visitSIToFPInst(SIToFPInst &Inst);
7373
void visitUIToFPInst(UIToFPInst &Inst);
7474

75+
// lower intrinsic instructions
76+
void visitCallInst(CallInst &Inst);
77+
7578
private:
7679
void lowerCastToBFloat(CastInst &Inst);
7780
void lowerCastFromBFloat(CastInst &Inst);
@@ -127,7 +130,11 @@ void GenXBFloatLowering::visitBinaryOperator(BinaryOperator &Inst) {
127130
Instruction::BinaryOps Opcode = Inst.getOpcode();
128131
auto *Op0Conv = Builder.CreateFPExt(Src0, FloatTy);
129132
auto *Op1Conv = Builder.CreateFPExt(Src1, FloatTy);
130-
auto *InstUpdate = Builder.CreateBinOp(Opcode, Op0Conv, Op1Conv);
133+
134+
auto *InstUpdate =
135+
cast<Instruction>(Builder.CreateBinOp(Opcode, Op0Conv, Op1Conv));
136+
InstUpdate->setFastMathFlags(Inst.getFastMathFlags());
137+
131138
auto *Trunc = Builder.CreateFPTrunc(InstUpdate, Ty);
132139
Inst.replaceAllUsesWith(Trunc);
133140
Inst.eraseFromParent();
@@ -195,6 +202,70 @@ void GenXBFloatLowering::visitUIToFPInst(UIToFPInst &Inst) {
195202
lowerCastToBFloat(Inst);
196203
}
197204

205+
void GenXBFloatLowering::visitCallInst(CallInst &Inst) {
206+
auto IID = vc::getAnyIntrinsicID(&Inst);
207+
auto *Ty = Inst.getType();
208+
SmallVector<Type *, 2> Types;
209+
210+
switch (IID) {
211+
default:
212+
return;
213+
case GenXIntrinsic::genx_sat:
214+
break;
215+
case Intrinsic::cos:
216+
case Intrinsic::exp2:
217+
case Intrinsic::fabs:
218+
case Intrinsic::fma:
219+
case Intrinsic::fmuladd:
220+
case Intrinsic::log2:
221+
case Intrinsic::maximum:
222+
case Intrinsic::maxnum:
223+
case Intrinsic::minimum:
224+
case Intrinsic::minnum:
225+
case Intrinsic::pow:
226+
case Intrinsic::sin:
227+
case Intrinsic::sqrt:
228+
break;
229+
case Intrinsic::fptosi_sat:
230+
case Intrinsic::fptoui_sat:
231+
Types.push_back(Ty);
232+
Ty = Inst.getArgOperand(0)->getType();
233+
break;
234+
}
235+
236+
if (!Ty->getScalarType()->isBFloatTy())
237+
return;
238+
239+
LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on Intrinsic\n"
240+
<< Inst << "\n");
241+
242+
auto *ExtTy = getFloatTyFromBfloat(Ty);
243+
Types.push_back(ExtTy);
244+
245+
IRBuilder<> Builder(&Inst);
246+
if (isa<FPMathOperator>(Inst))
247+
Builder.setFastMathFlags(Inst.getFastMathFlags());
248+
249+
SmallVector<Value *, 4> Args;
250+
llvm::transform(Inst.args(), std::back_inserter(Args),
251+
[&Builder, ExtTy](Value *Arg) {
252+
auto *Ty = Arg->getType();
253+
if (!Ty->getScalarType()->isBFloatTy())
254+
return Arg;
255+
return Builder.CreateFPExt(Arg, ExtTy);
256+
});
257+
258+
auto *Func = vc::getAnyDeclaration(Inst.getModule(), IID, Types);
259+
Value *NewInst = Builder.CreateCall(Func, Args);
260+
261+
if (NewInst->getType()->getScalarType()->isFloatTy())
262+
NewInst = Builder.CreateFPTrunc(NewInst, Inst.getType());
263+
264+
Inst.replaceAllUsesWith(NewInst);
265+
Inst.eraseFromParent();
266+
Modify = true;
267+
}
268+
198269
void GenXBFloatLowering::lowerCastToBFloat(CastInst &Inst) {
199270
auto *ResTy = Inst.getType();
200271
if (!ResTy->getScalarType()->isBFloatTy())
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2023 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
; REQUIRES: llvm_12_or_greater
10+
; RUN: %opt %use_old_pass_manager% -GenXBFloatLowering -march=genx64 -mcpu=XeHPG -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
11+
12+
declare <8 x bfloat> @llvm.genx.sat.v8bf16(<8 x bfloat>)
13+
14+
declare <8 x bfloat> @llvm.cos.v8bf16(<8 x bfloat>)
15+
declare <8 x bfloat> @llvm.exp2.v8bf16(<8 x bfloat>)
16+
declare <8 x bfloat> @llvm.fabs.v8bf16(<8 x bfloat>)
17+
declare <8 x bfloat> @llvm.log2.v8bf16(<8 x bfloat>)
18+
declare <8 x bfloat> @llvm.sin.v8bf16(<8 x bfloat>)
19+
declare <8 x bfloat> @llvm.sqrt.v8bf16(<8 x bfloat>)
20+
21+
declare <8 x bfloat> @llvm.maximum.v8bf16(<8 x bfloat>, <8 x bfloat>)
22+
declare <8 x bfloat> @llvm.maxnum.v8bf16(<8 x bfloat>, <8 x bfloat>)
23+
declare <8 x bfloat> @llvm.minimum.v8bf16(<8 x bfloat>, <8 x bfloat>)
24+
declare <8 x bfloat> @llvm.minnum.v8bf16(<8 x bfloat>, <8 x bfloat>)
25+
declare <8 x bfloat> @llvm.pow.v8bf16(<8 x bfloat>, <8 x bfloat>)
26+
27+
declare <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat>, <8 x bfloat>, <8 x bfloat>)
28+
declare <8 x bfloat> @llvm.fmuladd.v8bf16(<8 x bfloat>, <8 x bfloat>, <8 x bfloat>)
29+
30+
declare <8 x i32> @llvm.fptosi.sat.v8i32.v8bf16(<8 x bfloat>)
31+
declare <8 x i32> @llvm.fptoui.sat.v8i32.v8bf16(<8 x bfloat>)
32+
33+
; CHECK-LABEL: test_sat
34+
define <8 x bfloat> @test_sat(<8 x bfloat> %src) {
35+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
36+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.genx.sat.v8f32(<8 x float> [[EXT]])
37+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
38+
; CHECK: ret <8 x bfloat> [[TRUNC]]
39+
%res = call <8 x bfloat> @llvm.genx.sat.v8bf16(<8 x bfloat> %src)
40+
ret <8 x bfloat> %res
41+
}
42+
43+
; CHECK-LABEL: test_cos
44+
define <8 x bfloat> @test_cos(<8 x bfloat> %src) {
45+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
46+
; CHECK: [[RES:%[^ ]+]] = call fast <8 x float> @llvm.cos.v8f32(<8 x float> [[EXT]])
47+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
48+
; CHECK: ret <8 x bfloat> [[TRUNC]]
49+
%res = call fast <8 x bfloat> @llvm.cos.v8bf16(<8 x bfloat> %src)
50+
ret <8 x bfloat> %res
51+
}
52+
53+
; CHECK-LABEL: test_exp2
54+
define <8 x bfloat> @test_exp2(<8 x bfloat> %src) {
55+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
56+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.exp2.v8f32(<8 x float> [[EXT]])
57+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
58+
; CHECK: ret <8 x bfloat> [[TRUNC]]
59+
%res = call <8 x bfloat> @llvm.exp2.v8bf16(<8 x bfloat> %src)
60+
ret <8 x bfloat> %res
61+
}
62+
63+
; CHECK-LABEL: test_fabs
64+
define <8 x bfloat> @test_fabs(<8 x bfloat> %src) {
65+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
66+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.fabs.v8f32(<8 x float> [[EXT]])
67+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
68+
; CHECK: ret <8 x bfloat> [[TRUNC]]
69+
%res = call <8 x bfloat> @llvm.fabs.v8bf16(<8 x bfloat> %src)
70+
ret <8 x bfloat> %res
71+
}
72+
73+
; CHECK-LABEL: test_log2
74+
define <8 x bfloat> @test_log2(<8 x bfloat> %src) {
75+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
76+
; CHECK: [[RES:%[^ ]+]] = call afn <8 x float> @llvm.log2.v8f32(<8 x float> [[EXT]])
77+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
78+
; CHECK: ret <8 x bfloat> [[TRUNC]]
79+
%res = call afn <8 x bfloat> @llvm.log2.v8bf16(<8 x bfloat> %src)
80+
ret <8 x bfloat> %res
81+
}
82+
83+
; CHECK-LABEL: test_sin
84+
define <8 x bfloat> @test_sin(<8 x bfloat> %src) {
85+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
86+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.sin.v8f32(<8 x float> [[EXT]])
87+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
88+
; CHECK: ret <8 x bfloat> [[TRUNC]]
89+
%res = call <8 x bfloat> @llvm.sin.v8bf16(<8 x bfloat> %src)
90+
ret <8 x bfloat> %res
91+
}
92+
93+
; CHECK-LABEL: test_sqrt
94+
define <8 x bfloat> @test_sqrt(<8 x bfloat> %src) {
95+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
96+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[EXT]])
97+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
98+
; CHECK: ret <8 x bfloat> [[TRUNC]]
99+
%res = call <8 x bfloat> @llvm.sqrt.v8bf16(<8 x bfloat> %src)
100+
ret <8 x bfloat> %res
101+
}
102+
103+
; CHECK-LABEL: test_maximum
104+
define <8 x bfloat> @test_maximum(<8 x bfloat> %a, <8 x bfloat> %b) {
105+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
106+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
107+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.maximum.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]])
108+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
109+
; CHECK: ret <8 x bfloat> [[TRUNC]]
110+
%res = call <8 x bfloat> @llvm.maximum.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
111+
ret <8 x bfloat> %res
112+
}
113+
114+
; CHECK-LABEL: test_maxnum
115+
define <8 x bfloat> @test_maxnum(<8 x bfloat> %a, <8 x bfloat> %b) {
116+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
117+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
118+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.maxnum.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]])
119+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
120+
; CHECK: ret <8 x bfloat> [[TRUNC]]
121+
%res = call <8 x bfloat> @llvm.maxnum.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
122+
ret <8 x bfloat> %res
123+
}
124+
125+
; CHECK-LABEL: test_minimum
126+
define <8 x bfloat> @test_minimum(<8 x bfloat> %a, <8 x bfloat> %b) {
127+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
128+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
129+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.minimum.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]])
130+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
131+
; CHECK: ret <8 x bfloat> [[TRUNC]]
132+
%res = call <8 x bfloat> @llvm.minimum.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
133+
ret <8 x bfloat> %res
134+
}
135+
136+
; CHECK-LABEL: test_minnum
137+
define <8 x bfloat> @test_minnum(<8 x bfloat> %a, <8 x bfloat> %b) {
138+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
139+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
140+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.minnum.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]])
141+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
142+
; CHECK: ret <8 x bfloat> [[TRUNC]]
143+
%res = call <8 x bfloat> @llvm.minnum.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
144+
ret <8 x bfloat> %res
145+
}
146+
147+
; CHECK-LABEL: test_pow
148+
define <8 x bfloat> @test_pow(<8 x bfloat> %a, <8 x bfloat> %b) {
149+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
150+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
151+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.pow.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]])
152+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
153+
; CHECK: ret <8 x bfloat> [[TRUNC]]
154+
%res = call <8 x bfloat> @llvm.pow.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
155+
ret <8 x bfloat> %res
156+
}
157+
158+
; CHECK-LABEL: test_fma
159+
define <8 x bfloat> @test_fma(<8 x bfloat> %a, <8 x bfloat> %b, <8 x bfloat> %c) {
160+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
161+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
162+
; CHECK: [[CEXT:%[^ ]+]] = fpext <8 x bfloat> %c to <8 x float>
163+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.fma.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]], <8 x float> [[CEXT]])
164+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
165+
; CHECK: ret <8 x bfloat> [[TRUNC]]
166+
%res = call <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b, <8 x bfloat> %c)
167+
ret <8 x bfloat> %res
168+
}
169+
170+
; CHECK-LABEL: test_fmuladd
171+
define <8 x bfloat> @test_fmuladd(<8 x bfloat> %a, <8 x bfloat> %b, <8 x bfloat> %c) {
172+
; CHECK: [[AEXT:%[^ ]+]] = fpext <8 x bfloat> %a to <8 x float>
173+
; CHECK: [[BEXT:%[^ ]+]] = fpext <8 x bfloat> %b to <8 x float>
174+
; CHECK: [[CEXT:%[^ ]+]] = fpext <8 x bfloat> %c to <8 x float>
175+
; CHECK: [[RES:%[^ ]+]] = call <8 x float> @llvm.fmuladd.v8f32(<8 x float> [[AEXT]], <8 x float> [[BEXT]], <8 x float> [[CEXT]])
176+
; CHECK: [[TRUNC:%[^ ]+]] = fptrunc <8 x float> [[RES]] to <8 x bfloat>
177+
; CHECK: ret <8 x bfloat> [[TRUNC]]
178+
%res = call <8 x bfloat> @llvm.fmuladd.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b, <8 x bfloat> %c)
179+
ret <8 x bfloat> %res
180+
}
181+
182+
; CHECK-LABEL: test_fptosi_sat
183+
define <8 x i32> @test_fptosi_sat(<8 x bfloat> %src) {
184+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
185+
; CHECK: [[RES:%[^ ]+]] = call <8 x i32> @llvm.fptosi.sat.v8i32.v8f32(<8 x float> [[EXT]])
186+
; CHECK: ret <8 x i32> [[RES]]
187+
%res = call <8 x i32> @llvm.fptosi.sat.v8i32.v8bf16(<8 x bfloat> %src)
188+
ret <8 x i32> %res
189+
}
190+
191+
; CHECK-LABEL: test_fptoui_sat
192+
define <8 x i32> @test_fptoui_sat(<8 x bfloat> %src) {
193+
; CHECK: [[EXT:%[^ ]+]] = fpext <8 x bfloat> %src to <8 x float>
194+
; CHECK: [[RES:%[^ ]+]] = call <8 x i32> @llvm.fptoui.sat.v8i32.v8f32(<8 x float> [[EXT]])
195+
; CHECK: ret <8 x i32> [[RES]]
196+
%res = call <8 x i32> @llvm.fptoui.sat.v8i32.v8bf16(<8 x bfloat> %src)
197+
ret <8 x i32> %res
198+
}

0 commit comments

Comments
 (0)