Skip to content

Commit f8d1823

Browse files
PawelJurekigcbot
authored andcommitted
Change the implementation of arithmetic __builtin_bf16_* validation functions
When float type is used on arguments or on return type, extend the bfloat sources to float. This way we can generate mix-mode instructions when available.
1 parent 95e0cc1 commit f8d1823

File tree

3 files changed

+60
-50
lines changed

3 files changed

+60
-50
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/BfloatFuncs/BfloatFuncsResolution.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ void BfloatFuncsResolution::visitCallInst(CallInst &CI) {
9494
.StartsWith("__builtin_bf16_isless",
9595
[&]() { handleCompare(CI, CmpInst::Predicate::FCMP_OLT); })
9696
.StartsWith("__builtin_bf16_isnotequal",
97-
[&]() { handleCompare(CI, CmpInst::Predicate::FCMP_ONE); })
97+
[&]() { handleCompare(CI, CmpInst::Predicate::FCMP_UNE); })
9898
.StartsWith("__builtin_bf16_isunordered",
9999
[&]() { handleCompare(CI, CmpInst::Predicate::FCMP_UNO); })
100100
.StartsWith("__builtin_bf16_select",
@@ -263,26 +263,15 @@ void BfloatFuncsResolution::handleArithmetic(llvm::CallInst& CI,
263263
bool IsResFloat = CI.getType()->getScalarType()->isFloatTy();
264264

265265

266-
if (IsResFloat && (FloatOperandIndex != -1)) {
267-
// 1. If we have float on destination, and one float source extend the
268-
// short sources to float.
266+
if (IsResFloat || (FloatOperandIndex != -1)) {
267+
// 1. If we have float on destination, or float source extend the
268+
// short sources to float. Let vISA handle the mix mode propagation.
269269
for (size_t i = 0; i < Operands.size(); ++i) {
270270
if (i == FloatOperandIndex)
271271
continue;
272272
auto Op = bitcastToBfloat(Operands[i]);
273-
Operands[i] = m_builder->CreateFPExt(Op, CI.getType());
274-
}
275-
} else if (!IsResFloat && (FloatOperandIndex != -1)) {
276-
// 2. If we have short on destination, truncate the float source to
277-
// bfloat.
278-
for (size_t i = 0; i < Operands.size(); ++i) {
279-
if (i == FloatOperandIndex) {
280-
Operands[i] = m_builder->CreateFPTrunc(
281-
Operands[i], getTypeBasedOnType(Operands[i]->getType(),
282-
m_builder->getBFloatTy()));
283-
} else {
284-
Operands[i] = bitcastToBfloat(Operands[i]);
285-
}
273+
Operands[i] = m_builder->CreateFPExt(
274+
Op, getTypeBasedOnType(Op->getType(), m_builder->getFloatTy()));
286275
}
287276
} else if (FloatOperandIndex == -1) {
288277
// 3. If we have only shorts on source, just
@@ -303,16 +292,22 @@ void BfloatFuncsResolution::handleArithmetic(llvm::CallInst& CI,
303292
Res = m_builder->CreateFAdd(Res, Operands[2]);
304293
} else {
305294
IGC_ASSERT_MESSAGE(0, "Unsupported number of operands.");
295+
return;
306296
}
307297

308-
if (Res->getType()->getScalarType()->isBFloatTy()) {
298+
if (Res && Res->getType()->getScalarType()->isBFloatTy()) {
309299
if (IsResFloat) {
310-
Res = m_builder->CreateFPExt(Res, CI.getType());
300+
IGC_ASSERT_MESSAGE(0, "Not expected path");
311301
} else {
312302
Res = m_builder->CreateBitCast(Res, CI.getType());
313303
}
314304
} else {
315-
IGC_ASSERT(Res->getType()->getScalarType()->isFloatTy());
305+
IGC_ASSERT(Res && Res->getType()->getScalarType()->isFloatTy());
306+
if (!CI.getType()->getScalarType()->isFloatTy()) {
307+
Res = m_builder->CreateFPTrunc(
308+
Res, getTypeBasedOnType(Res->getType(), m_builder->getBFloatTy()));
309+
Res = m_builder->CreateBitCast(Res, CI.getType());
310+
}
316311
}
317312

318313
CI.replaceAllUsesWith(Res);

IGC/Compiler/tests/BfloatFuncsResolution/arithmetic.ll

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ entry:
105105
; CHECK: define spir_kernel void @test_add_v2
106106
define spir_kernel void @test_add_v2(i16 addrspace(1)* %out1, float %v1_1, i16 zeroext %v2_1) #1 {
107107
entry:
108-
; CHECK: %[[SRC0BF:.*]] = fptrunc float %v1_1 to bfloat
109108
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
110-
; CHECK: %[[RES:.*]] = fadd bfloat %[[SRC0BF]], %[[SRC1BF]]
109+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
110+
; CHECK: %[[RESF:.*]] = fadd float %v1_1, %[[SRC1F]]
111+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
111112
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
112113
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_addft(float %v1_1, i16 zeroext %v2_1) #2
113114
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -120,8 +121,9 @@ entry:
120121
define spir_kernel void @test_add_v3(i16 addrspace(1)* %out1, i16 zeroext %v1_1, float %v2_1) #1 {
121122
entry:
122123
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
123-
; CHECK: %[[SRC1BF:.*]] = fptrunc float %v2_1 to bfloat
124-
; CHECK: %[[RES:.*]] = fadd bfloat %[[SRC0BF]], %[[SRC1BF]]
124+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
125+
; CHECK: %[[RESF:.*]] = fadd float %[[SRC0F]], %v2_1
126+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
125127
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
126128
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_addtf(i16 zeroext %v1_1, float %v2_1) #2
127129
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -133,9 +135,10 @@ entry:
133135
define spir_kernel void @test_addf_v1(float addrspace(1)* %out1, i16 zeroext %v1_1, i16 zeroext %v2_1) #1 {
134136
entry:
135137
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
138+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
136139
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
137-
; CHECK: %[[RES:.*]] = fadd bfloat %[[SRC0BF]], %[[SRC1BF]]
138-
; CHECK: %{{.*}} = fpext bfloat %[[RES]] to float
140+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
141+
; CHECK: %[[RES:.*]] = fadd float %[[SRC0F]], %[[SRC1F]]
139142
%call = call spir_func float @_Z19__builtin_bf16_addftt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
140143
%arrayidx = getelementptr inbounds float, float addrspace(1)* %out1, i64 0
141144
store float %call, float addrspace(1)* %arrayidx, align 4
@@ -186,9 +189,10 @@ entry:
186189
; CHECK: define spir_kernel void @test_sub_v2
187190
define spir_kernel void @test_sub_v2(i16 addrspace(1)* %out1, float %v1_1, i16 zeroext %v2_1) #1 {
188191
entry:
189-
; CHECK: %[[SRC0BF:.*]] = fptrunc float %v1_1 to bfloat
190192
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
191-
; CHECK: %[[RES:.*]] = fsub bfloat %[[SRC0BF]], %[[SRC1BF]]
193+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
194+
; CHECK: %[[RESF:.*]] = fsub float %v1_1, %[[SRC1F]]
195+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
192196
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
193197
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_subft(float %v1_1, i16 zeroext %v2_1) #2
194198
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -201,8 +205,9 @@ entry:
201205
define spir_kernel void @test_sub_v3(i16 addrspace(1)* %out1, i16 zeroext %v1_1, float %v2_1) #1 {
202206
entry:
203207
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
204-
; CHECK: %[[SRC1BF:.*]] = fptrunc float %v2_1 to bfloat
205-
; CHECK: %[[RES:.*]] = fsub bfloat %[[SRC0BF]], %[[SRC1BF]]
208+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
209+
; CHECK: %[[RESF:.*]] = fsub float %[[SRC0F]], %v2_1
210+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
206211
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
207212
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_subtf(i16 zeroext %v1_1, float %v2_1) #2
208213
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -214,9 +219,10 @@ entry:
214219
define spir_kernel void @test_subf_v1(float addrspace(1)* %out1, i16 zeroext %v1_1, i16 zeroext %v2_1) #1 {
215220
entry:
216221
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
222+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
217223
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
218-
; CHECK: %[[RES:.*]] = fsub bfloat %[[SRC0BF]], %[[SRC1BF]]
219-
; CHECK: %{{.*}} = fpext bfloat %[[RES]] to float
224+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
225+
; CHECK: %[[RES:.*]] = fsub float %[[SRC0F]], %[[SRC1F]]
220226
%call = call spir_func float @_Z19__builtin_bf16_subftt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
221227
%arrayidx = getelementptr inbounds float, float addrspace(1)* %out1, i64 0
222228
store float %call, float addrspace(1)* %arrayidx, align 4
@@ -267,9 +273,10 @@ entry:
267273
; CHECK: define spir_kernel void @test_mul_v2
268274
define spir_kernel void @test_mul_v2(i16 addrspace(1)* %out1, float %v1_1, i16 zeroext %v2_1) #1 {
269275
entry:
270-
; CHECK: %[[SRC0BF:.*]] = fptrunc float %v1_1 to bfloat
271276
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
272-
; CHECK: %[[RES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
277+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
278+
; CHECK: %[[RESF:.*]] = fmul float %v1_1, %[[SRC1F]]
279+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
273280
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
274281
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_mulft(float %v1_1, i16 zeroext %v2_1) #2
275282
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -282,8 +289,9 @@ entry:
282289
define spir_kernel void @test_mul_v3(i16 addrspace(1)* %out1, i16 zeroext %v1_1, float %v2_1) #1 {
283290
entry:
284291
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
285-
; CHECK: %[[SRC1BF:.*]] = fptrunc float %v2_1 to bfloat
286-
; CHECK: %[[RES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
292+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
293+
; CHECK: %[[RESF:.*]] = fmul float %[[SRC0F]], %v2_1
294+
; CHECK: %[[RES:.*]] = fptrunc float %[[RESF]] to bfloat
287295
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
288296
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_multf(i16 zeroext %v1_1, float %v2_1) #2
289297
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
@@ -295,9 +303,10 @@ entry:
295303
define spir_kernel void @test_mulf_v1(float addrspace(1)* %out1, i16 zeroext %v1_1, i16 zeroext %v2_1) #1 {
296304
entry:
297305
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
306+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
298307
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
299-
; CHECK: %[[RES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
300-
; CHECK: %{{.*}} = fpext bfloat %[[RES]] to float
308+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
309+
; CHECK: %[[RES:.*]] = fmul float %[[SRC0F]], %[[SRC1F]]
301310
%call = call spir_func float @_Z19__builtin_bf16_mulftt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
302311
%arrayidx = getelementptr inbounds float, float addrspace(1)* %out1, i64 0
303312
store float %call, float addrspace(1)* %arrayidx, align 4
@@ -348,12 +357,14 @@ entry:
348357
; Function Attrs: convergent nounwind
349358
define spir_kernel void @test_mad_v2(i16 addrspace(1)* %out1, float %v1_1, i16 zeroext %v2_1, i16 zeroext %v3_1) #1 {
350359
entry:
351-
; CHECK: %[[SRC0BF:.*]] = fptrunc float %v1_1 to bfloat
352360
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
361+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
353362
; CHECK: %[[SRC2BF:.*]] = bitcast i16 %v3_1 to bfloat
354-
; CHECK: %[[FMULRES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
355-
; CHECK: %[[FADDRES:.*]] = fadd bfloat %[[FMULRES]], %[[SRC2BF]]
356-
; CHECK: %{{.*}} = bitcast bfloat %[[FADDRES]] to i16
363+
; CHECK: %[[SRC2F:.*]] = fpext bfloat %[[SRC2BF]] to float
364+
; CHECK: %[[FMULRES:.*]] = fmul float %v1_1, %[[SRC1F]]
365+
; CHECK: %[[FADDRES:.*]] = fadd float %[[FMULRES]], %[[SRC2F]]
366+
; CHECK: %[[RES:.*]] = fptrunc float %[[FADDRES]] to bfloat
367+
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
357368
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_madftt(float %v1_1, i16 zeroext %v2_1, i16 zeroext %v3_1) #2
358369
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
359370
store i16 %call, i16 addrspace(1)* %arrayidx, align 2
@@ -364,11 +375,13 @@ entry:
364375
define spir_kernel void @test_mad_v3(i16 addrspace(1)* %out1, i16 zeroext %v1_1, i16 zeroext %v2_1, float %v3_1) #1 {
365376
entry:
366377
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
378+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
367379
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
368-
; CHECK: %[[SRC2BF:.*]] = fptrunc float %v3_1 to bfloat
369-
; CHECK: %[[FMULRES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
370-
; CHECK: %[[FADDRES:.*]] = fadd bfloat %[[FMULRES]], %[[SRC2BF]]
371-
; CHECK: %{{.*}} = bitcast bfloat %[[FADDRES]] to i16
380+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
381+
; CHECK: %[[FMULRES:.*]] = fmul float %[[SRC0F]], %[[SRC1F]]
382+
; CHECK: %[[FADDRES:.*]] = fadd float %[[FMULRES]], %v3_1
383+
; CHECK: %[[RES:.*]] = fptrunc float %[[FADDRES]] to bfloat
384+
; CHECK: %{{.*}} = bitcast bfloat %[[RES]] to i16
372385
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_madttf(i16 zeroext %v1_1, i16 zeroext %v2_1, float %v3_1) #2
373386
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
374387
store i16 %call, i16 addrspace(1)* %arrayidx, align 2
@@ -379,11 +392,13 @@ entry:
379392
define spir_kernel void @test_madf_v1(float addrspace(1)* %out1, i16 zeroext %v1_1, i16 zeroext %v2_1, i16 zeroext %v3_1) #1 {
380393
entry:
381394
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
395+
; CHECK: %[[SRC0F:.*]] = fpext bfloat %[[SRC0BF]] to float
382396
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
397+
; CHECK: %[[SRC1F:.*]] = fpext bfloat %[[SRC1BF]] to float
383398
; CHECK: %[[SRC2BF:.*]] = bitcast i16 %v3_1 to bfloat
384-
; CHECK: %[[FMULRES:.*]] = fmul bfloat %[[SRC0BF]], %[[SRC1BF]]
385-
; CHECK: %[[FADDRES:.*]] = fadd bfloat %[[FMULRES]], %[[SRC2BF]]
386-
; CHECK: %{{.*}} = fpext bfloat %[[FADDRES]] to float
399+
; CHECK: %[[SRC2F:.*]] = fpext bfloat %[[SRC2BF]] to float
400+
; CHECK: %[[FMULRES:.*]] = fmul float %[[SRC0F]], %[[SRC1F]]
401+
; CHECK: %[[FADDRES:.*]] = fadd float %[[FMULRES]], %[[SRC2F]]
387402
%call = call spir_func float @_Z19__builtin_bf16_madfttt(i16 zeroext %v1_1, i16 zeroext %v2_1, i16 zeroext %v3_1) #2
388403
%arrayidx = getelementptr inbounds float, float addrspace(1)* %out1, i64 0
389404
store float %call, float addrspace(1)* %arrayidx, align 4

IGC/Compiler/tests/BfloatFuncsResolution/compare.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ define spir_kernel void @test_isnotequal(i32 addrspace(1)* %out1, i16 zeroext %v
120120
entry:
121121
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
122122
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
123-
; CHECK: %[[FCMPRES:.*]] = fcmp one bfloat %[[SRC0BF]], %[[SRC1BF]]
123+
; CHECK: %[[FCMPRES:.*]] = fcmp une bfloat %[[SRC0BF]], %[[SRC1BF]]
124124
; CHECK: %{{.*}} = zext i1 %[[FCMPRES]] to i32
125125
%call = call spir_func i32 @_Z25__builtin_bf16_isnotequaltt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
126126
%arrayidx = getelementptr inbounds i32, i32 addrspace(1)* %out1, i64 0

0 commit comments

Comments
 (0)