Skip to content

Commit c8cd2f6

Browse files
authored
Fix integer dot product translation (#1174)
Fix translation of IR from source when dot function is called with integer arguments (ints, vectors of chars or shorts) to properly support cl_khr_integer_dot_product extension. Previously it translated calls such as dot(uchar4, uchar4) into OpDot %i8, which is wrong because OpDot only operates on floating point types.
1 parent 0eb9a7d commit c8cd2f6

6 files changed

+373
-1
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,23 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
323323
return;
324324
}
325325
if (DemangledName == kOCLBuiltinName::Dot &&
326-
!(CI.getOperand(0)->getType()->isVectorTy())) {
326+
(CI.getOperand(0)->getType()->isFloatTy() ||
327+
CI.getOperand(1)->getType()->isDoubleTy())) {
327328
visitCallDot(&CI);
328329
return;
329330
}
331+
if (DemangledName == kOCLBuiltinName::Dot ||
332+
DemangledName == kOCLBuiltinName::DotAccSat) {
333+
if (CI.getOperand(0)->getType()->isVectorTy()) {
334+
auto *VT = (VectorType *)(CI.getOperand(0)->getType());
335+
if (!isa<llvm::IntegerType>(VT->getElementType())) {
336+
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
337+
return;
338+
}
339+
}
340+
visitCallDot(&CI, MangledName, DemangledName);
341+
return;
342+
}
330343
if (DemangledName == kOCLBuiltinName::FMin ||
331344
DemangledName == kOCLBuiltinName::FMax ||
332345
DemangledName == kOCLBuiltinName::Min ||
@@ -1306,6 +1319,105 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI) {
13061319
CI->eraseFromParent();
13071320
}
13081321

1322+
void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
1323+
StringRef DemangledName) {
1324+
// translation for dot function calls,
1325+
// to differentiate between integer dot products
1326+
1327+
SmallVector<Value *, 3> Args;
1328+
Args.push_back(CI->getOperand(0));
1329+
Args.push_back(CI->getOperand(1));
1330+
bool IsFirstSigned, IsSecondSigned;
1331+
bool IsDot = DemangledName == kOCLBuiltinName::Dot;
1332+
std::string FunName = (IsDot) ? "DotKHR" : "DotAccSatKHR";
1333+
if (CI->arg_size() > 2) {
1334+
Args.push_back(CI->getOperand(2));
1335+
}
1336+
if (CI->arg_size() > 3) {
1337+
Args.push_back(CI->getOperand(3));
1338+
}
1339+
if (CI->getOperand(0)->getType()->isVectorTy()) {
1340+
if (IsDot) {
1341+
// dot(char4, char4) _Z3dotDv4_cS_
1342+
// dot(char4, uchar4) _Z3dotDv4_cDv4_h
1343+
// dot(uchar4, char4) _Z3dotDv4_hDv4_c
1344+
// dot(uchar4, uchar4) _Z3dotDv4_hS_
1345+
// or
1346+
// dot(short2, short2) _Z3dotDv2_sS_
1347+
// dot(short2, ushort2) _Z3dotDv2_sDv2_t
1348+
// dot(ushort2, short2) _Z3dotDv2_tDv2_s
1349+
// dot(ushort2, ushort2) _Z3dotDv2_tS_
1350+
assert(MangledName.startswith("_Z3dotDv"));
1351+
if (MangledName[MangledName.size() - 1] == '_') {
1352+
IsFirstSigned = ((MangledName[MangledName.size() - 3] == 'c') ||
1353+
(MangledName[MangledName.size() - 3] == 's'));
1354+
IsSecondSigned = IsFirstSigned;
1355+
} else {
1356+
IsFirstSigned = ((MangledName[MangledName.size() - 6] == 'c') ||
1357+
(MangledName[MangledName.size() - 6] == 's'));
1358+
IsSecondSigned = ((MangledName[MangledName.size() - 1] == 'c') ||
1359+
(MangledName[MangledName.size() - 1] == 's'));
1360+
}
1361+
} else {
1362+
// dot_acc_sat(char4, char4, int) _Z11dot_acc_satDv4_cS_i
1363+
// dot_acc_sat(char4, uchar4, int) _Z11dot_acc_satDv4_cDv4_hi
1364+
// dot_acc_sat(uchar4, char4, int) _Z11dot_acc_satDv4_hDv4_ci
1365+
// dot_acc_sat(uchar4, uchar4, uint) _Z11dot_acc_satDv4_hS_j
1366+
// or
1367+
// dot_acc_sat(short2, short2, int) _Z11dot_acc_satDv4_sS_i
1368+
// dot_acc_sat(short2, ushort2, int) _Z11dot_acc_satDv4_sDv4_ti
1369+
// dot_acc_sat(ushort2, short2, int) _Z11dot_acc_satDv4_tDv4_si
1370+
// dot_acc_sat(ushort2, ushort2, uint) _Z11dot_acc_satDv4_tS_j
1371+
assert(MangledName.startswith("_Z11dot_acc_satDv"));
1372+
IsFirstSigned = ((MangledName[19] == 'c') || (MangledName[19] == 's'));
1373+
IsSecondSigned = (MangledName[20] == 'S'
1374+
? IsFirstSigned
1375+
: ((MangledName[MangledName.size() - 2] == 'c') ||
1376+
(MangledName[MangledName.size() - 2] == 's')));
1377+
}
1378+
} else {
1379+
// for packed format
1380+
// dot(int, int, int) _Z3dotiii
1381+
// dot(int, uint, int) _Z3dotiji
1382+
// dot(uint, int, int) _Z3dotjii
1383+
// dot(uint, uint, int) _Z3dotjji
1384+
// or
1385+
// dot_acc_sat(int, int, int, int) _Z11dot_acc_satiiii
1386+
// dot_acc_sat(int, uint, int, int) _Z11dot_acc_satijii
1387+
// dot_acc_sat(uint, int, int, int) _Z11dot_acc_satjiii
1388+
// dot_acc_sat(uint, uint, int, int) _Z11dot_acc_satjjii
1389+
assert(MangledName.startswith("_Z3dot") ||
1390+
MangledName.startswith("_Z11dot_acc_sat"));
1391+
IsFirstSigned = (IsDot) ? (MangledName[MangledName.size() - 3] == 'i')
1392+
: (MangledName[MangledName.size() - 4] == 'i');
1393+
IsSecondSigned = (IsDot) ? (MangledName[MangledName.size() - 2] == 'i')
1394+
: (MangledName[MangledName.size() - 3] == 'i');
1395+
}
1396+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1397+
mutateCallInstSPIRV(
1398+
M, CI,
1399+
[=](CallInst *, std::vector<Value *> &Args) {
1400+
// If arguments are in order unsigned -> signed
1401+
// then the translator should swap them,
1402+
// so that the OpSUDotKHR can be used properly
1403+
if (IsFirstSigned == false && IsSecondSigned == true) {
1404+
std::swap(Args[0], Args[1]);
1405+
}
1406+
Op OC;
1407+
if (IsDot) {
1408+
OC = (IsFirstSigned != IsSecondSigned
1409+
? OpSUDot
1410+
: ((IsFirstSigned) ? OpSDot : OpUDot));
1411+
} else {
1412+
OC = (IsFirstSigned != IsSecondSigned
1413+
? OpSUDotAccSat
1414+
: ((IsFirstSigned) ? OpSDotAccSat : OpUDotAccSat));
1415+
}
1416+
return getSPIRVFuncName(OC);
1417+
},
1418+
&Attrs);
1419+
}
1420+
13091421
void OCLToSPIRVBase::visitCallScalToVec(CallInst *CI, StringRef MangledName,
13101422
StringRef DemangledName) {
13111423
// Check if all arguments have the same type - it's simple case.

lib/SPIRV/OCLToSPIRV.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
210210
/// Transforms OpDot instructions with a scalar type to a fmul instruction
211211
void visitCallDot(CallInst *CI);
212212

213+
/// Transforms OpDot instructions with a vector or scalar (packed vector) type
214+
/// to dot or dot_acc_sat instructions
215+
void visitCallDot(CallInst *CI, StringRef MangledName,
216+
StringRef DemangledName);
217+
213218
/// Fixes for built-in functions with vector+scalar arguments that are
214219
/// translated to the SPIR-V instructions where all arguments must have the
215220
/// same type.

lib/SPIRV/OCLUtil.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ const static char Barrier[] = "barrier";
233233
const static char Clamp[] = "clamp";
234234
const static char ConvertPrefix[] = "convert_";
235235
const static char Dot[] = "dot";
236+
const static char DotAccSat[] = "dot_acc_sat";
236237
const static char EnqueueKernel[] = "enqueue_kernel";
237238
const static char FixedSqrtINTEL[] = "intel_arbitrary_fixed_sqrt";
238239
const static char FixedRecipINTEL[] = "intel_arbitrary_fixed_recip";
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc
3+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv
4+
; RUN: spirv-val %t-spirv.spv
5+
; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
6+
; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV
7+
8+
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_SDotKHR
9+
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
10+
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
11+
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_UDotKHR
12+
13+
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_SDotAccSatKHR
14+
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
15+
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
16+
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotAccSatKHR
17+
18+
;CHECK-SPIRV: SDotKHR
19+
;CHECK-SPIRV: SUDotKHR
20+
;CHECK-SPIRV: SUDotKHR
21+
;CHECK-SPIRV: UDotKHR
22+
23+
;CHECK-SPIRV: SDotAccSatKHR
24+
;CHECK-SPIRV: SUDotAccSatKHR
25+
;CHECK-SPIRV: SUDotAccSatKHR
26+
;CHECK-SPIRV: UDotAccSatKHR
27+
28+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
29+
target triple = "spir"
30+
31+
; Function Attrs: convergent norecurse nounwind
32+
define spir_kernel void @test1(<4 x i8> %ia, <4 x i8> %ua, <4 x i8> %ib, <4 x i8> %ub, <4 x i8> %ires, <4 x i8> %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
33+
entry:
34+
%call = tail call spir_func i32 @_Z3dotDv4_cS_(<4 x i8> %ia, <4 x i8> %ib) #2
35+
%call1 = tail call spir_func i32 @_Z3dotDv4_cDv4_h(<4 x i8> %ia, <4 x i8> %ub) #2
36+
%call2 = tail call spir_func i32 @_Z3dotDv4_hDv4_c(<4 x i8> %ua, <4 x i8> %ib) #2
37+
%call3 = tail call spir_func i32 @_Z3dotDv4_hS_(<4 x i8> %ua, <4 x i8> %ub) #2
38+
%call4 = tail call spir_func i32 @_Z11dot_acc_satDv4_cS_i(<4 x i8> %ia, <4 x i8> %ib, i32 %call2) #2
39+
%call5 = tail call spir_func i32 @_Z11dot_acc_satDv4_cDv4_hi(<4 x i8> %ia, <4 x i8> %ub, i32 %call4) #2
40+
%call6 = tail call spir_func i32 @_Z11dot_acc_satDv4_hDv4_ci(<4 x i8> %ua, <4 x i8> %ib, i32 %call5) #2
41+
%call7 = tail call spir_func i32 @_Z11dot_acc_satDv4_hS_j(<4 x i8> %ua, <4 x i8> %ub, i32 %call3) #2
42+
ret void
43+
}
44+
45+
; Function Attrs: convergent
46+
declare spir_func i32 @_Z3dotDv4_cS_(<4 x i8>, <4 x i8>) local_unnamed_addr #1
47+
48+
; Function Attrs: convergent
49+
declare spir_func i32 @_Z3dotDv4_cDv4_h(<4 x i8>, <4 x i8>) local_unnamed_addr #1
50+
51+
; Function Attrs: convergent
52+
declare spir_func i32 @_Z3dotDv4_hDv4_c(<4 x i8>, <4 x i8>) local_unnamed_addr #1
53+
54+
; Function Attrs: convergent
55+
declare spir_func i32 @_Z3dotDv4_hS_(<4 x i8>, <4 x i8>) local_unnamed_addr #1
56+
57+
; Function Attrs: convergent
58+
declare spir_func i32 @_Z11dot_acc_satDv4_cS_i(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1
59+
60+
; Function Attrs: convergent
61+
declare spir_func i32 @_Z11dot_acc_satDv4_cDv4_hi(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1
62+
63+
; Function Attrs: convergent
64+
declare spir_func i32 @_Z11dot_acc_satDv4_hDv4_ci(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1
65+
66+
; Function Attrs: convergent
67+
declare spir_func i32 @_Z11dot_acc_satDv4_hS_j(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1
68+
69+
attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
70+
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
71+
attributes #2 = { convergent nounwind }
72+
73+
!llvm.module.flags = !{!0}
74+
!opencl.ocl.version = !{!1}
75+
!opencl.spir.version = !{!1}
76+
!llvm.ident = !{!2}
77+
78+
!0 = !{i32 1, !"wchar_size", i32 4}
79+
!1 = !{i32 2, i32 0}
80+
!2 = !{!"clang version 11.0.0 (https://github.com/c199914007/llvm.git 8b94769313ca84cb9370b60ed008501ff692cb71)"}
81+
!3 = !{i32 0, i32 0, i32 0, i32 0, i32 0, i32 0}
82+
!4 = !{!"none", !"none", !"none", !"none", !"none", !"none"}
83+
!5 = !{!"char4", !"uchar4", !"char4", !"uchar4", !"char4", !"uchar4"}
84+
!6 = !{!"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))", !"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))", !"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))"}
85+
!7 = !{!"", !"", !"", !"", !"", !""}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc
3+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv
4+
; RUN: spirv-val %t-spirv.spv
5+
; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
6+
; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV
7+
8+
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_SDotKHR
9+
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
10+
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
11+
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_UDotKHR
12+
13+
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_SDotAccSatKHR
14+
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
15+
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
16+
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotAccSatKHR
17+
18+
;CHECK-SPIRV: SDotKHR
19+
;CHECK-SPIRV: SUDotKHR
20+
;CHECK-SPIRV: SUDotKHR
21+
;CHECK-SPIRV: UDotKHR
22+
23+
;CHECK-SPIRV: SDotAccSatKHR
24+
;CHECK-SPIRV: SUDotAccSatKHR
25+
;CHECK-SPIRV: SUDotAccSatKHR
26+
;CHECK-SPIRV: UDotAccSatKHR
27+
28+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
29+
target triple = "spir"
30+
31+
; Function Attrs: convergent norecurse nounwind
32+
define spir_kernel void @test1(i32 %ia, i32 %ua, i32 %ib, i32 %ub, i32 %ires, i32 %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !5 !kernel_arg_type_qual !6 {
33+
entry:
34+
%call = tail call spir_func i32 @_Z3dotiii(i32 %ia, i32 %ib, i32 0) #2
35+
%call1 = tail call spir_func i32 @_Z3dotiji(i32 %ia, i32 %ub, i32 0) #2
36+
%call2 = tail call spir_func i32 @_Z3dotjii(i32 %ua, i32 %ib, i32 0) #2
37+
%call3 = tail call spir_func i32 @_Z3dotjji(i32 %ua, i32 %ub, i32 0) #2
38+
%call4 = tail call spir_func i32 @_Z11dot_acc_satiiii(i32 %ia, i32 %ib, i32 %ires, i32 0) #2
39+
%call5 = tail call spir_func i32 @_Z11dot_acc_satijii(i32 %ia, i32 %ub, i32 %ires, i32 0) #2
40+
%call6 = tail call spir_func i32 @_Z11dot_acc_satjiii(i32 %ua, i32 %ib, i32 %ires, i32 0) #2
41+
%call7 = tail call spir_func i32 @_Z11dot_acc_satjjji(i32 %ua, i32 %ub, i32 %ures, i32 0) #2
42+
ret void
43+
}
44+
45+
; Function Attrs: convergent
46+
declare spir_func i32 @_Z3dotiii(i32, i32, i32) local_unnamed_addr #1
47+
48+
; Function Attrs: convergent
49+
declare spir_func i32 @_Z3dotiji(i32, i32, i32) local_unnamed_addr #1
50+
51+
; Function Attrs: convergent
52+
declare spir_func i32 @_Z3dotjii(i32, i32, i32) local_unnamed_addr #1
53+
54+
; Function Attrs: convergent
55+
declare spir_func i32 @_Z3dotjji(i32, i32, i32) local_unnamed_addr #1
56+
57+
; Function Attrs: convergent
58+
declare spir_func i32 @_Z11dot_acc_satiiii(i32, i32, i32, i32) local_unnamed_addr #1
59+
60+
; Function Attrs: convergent
61+
declare spir_func i32 @_Z11dot_acc_satijii(i32, i32, i32, i32) local_unnamed_addr #1
62+
63+
; Function Attrs: convergent
64+
declare spir_func i32 @_Z11dot_acc_satjiii(i32, i32, i32, i32) local_unnamed_addr #1
65+
66+
; Function Attrs: convergent
67+
declare spir_func i32 @_Z11dot_acc_satjjji(i32, i32, i32, i32) local_unnamed_addr #1
68+
69+
attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
70+
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
71+
attributes #2 = { convergent nounwind }
72+
73+
!llvm.module.flags = !{!0}
74+
!opencl.ocl.version = !{!1}
75+
!opencl.spir.version = !{!1}
76+
!llvm.ident = !{!2}
77+
78+
!0 = !{i32 1, !"wchar_size", i32 4}
79+
!1 = !{i32 2, i32 0}
80+
!2 = !{!"clang version 11.0.0 (https://github.com/c199914007/llvm.git f2b7028a3598d4d88ddf1f76b50946da4e135845)"}
81+
!3 = !{i32 0, i32 0, i32 0, i32 0, i32 0, i32 0}
82+
!4 = !{!"none", !"none", !"none", !"none", !"none", !"none"}
83+
!5 = !{!"int", !"uint", !"int", !"uint", !"int", !"uint"}
84+
!6 = !{!"", !"", !"", !"", !"", !""}

0 commit comments

Comments
 (0)