Skip to content

Commit f717bd2

Browse files
qichaogujsji
authored andcommitted
Translate integer dot product SPIR-V builtins to OCL builtins (#2794)
#1174 implements translating integer dot product OCL builtins to SPIR-V builtins. This pull request is to do the reverse translation. Original commit: KhronosGroup/SPIRV-LLVM-Translator@925255cb1982896
1 parent 2a0693c commit f717bd2

File tree

5 files changed

+96
-18
lines changed

5 files changed

+96
-18
lines changed

llvm-spirv/lib/SPIRV/OCLUtil.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,20 @@ class OCLBuiltinFuncMangleInfo : public SPIRV::BuiltinFuncMangleInfo {
13271327
} else if (NameRef.starts_with("bitfield_extract_signed") ||
13281328
NameRef.starts_with("bitfield_extract_unsigned")) {
13291329
addUnsignedArgs(1, 2);
1330+
} else if (NameRef.starts_with("dot_")) {
1331+
if (NameRef.contains("4x8packed")) {
1332+
addUnsignedArgs(0, 1);
1333+
if (NameRef == "dot_acc_sat_4x8packed_uu_uint")
1334+
addUnsignedArg(2);
1335+
} else {
1336+
if (NameRef.ends_with("_uu")) {
1337+
addUnsignedArgs(0, 1);
1338+
if (NameRef.starts_with("dot_acc_sat"))
1339+
addUnsignedArg(2);
1340+
} else if (NameRef.ends_with("_su"))
1341+
addUnsignedArg(1);
1342+
NameRef = NameRef.drop_back(std::string("_uu").length());
1343+
}
13301344
}
13311345

13321346
// Store the final version of a function name

llvm-spirv/lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
224224
visitCallSPIRVBFloat16Conversions(&CI, OC);
225225
return;
226226
}
227+
if (OC == OpSDot || OC == OpUDot || OC == OpSUDot || OC == OpSDotAccSat ||
228+
OC == OpUDotAccSat || OC == OpSUDotAccSat) {
229+
visitCallSPIRVDot(&CI, OC, DemangledName);
230+
return;
231+
}
227232
if (OCLSPIRVBuiltinMap::rfind(OC))
228233
visitCallSPIRVBuiltin(&CI, OC);
229234
}
@@ -935,6 +940,61 @@ void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
935940
mutateCallInst(CI, Name);
936941
}
937942

943+
void SPIRVToOCLBase::visitCallSPIRVDot(CallInst *CI, Op OC,
944+
StringRef DemangledName) {
945+
// OpenCL only supports integer dot product builtins that have return types
946+
// of int and uint.
947+
if (!(DemangledName.contains("_Rint") || DemangledName.contains("_Ruint")))
948+
return;
949+
950+
bool IsPacked = !CI->getOperand(0)->getType()->isVectorTy();
951+
std::stringstream Name;
952+
switch (OC) {
953+
case OpSDot:
954+
if (IsPacked)
955+
Name << kOCLBuiltinName::Dot4x8PackedPrefix << "ss_int";
956+
else
957+
// Add an extra suffix to help determine signed/unsigned arguments
958+
Name << kOCLBuiltinName::Dot << "_ss";
959+
break;
960+
case OpUDot:
961+
if (IsPacked)
962+
Name << kOCLBuiltinName::Dot4x8PackedPrefix << "uu_uint";
963+
else
964+
Name << kOCLBuiltinName::Dot << "_uu";
965+
break;
966+
case OpSUDot:
967+
if (IsPacked)
968+
Name << kOCLBuiltinName::Dot4x8PackedPrefix << "su_int";
969+
else
970+
Name << kOCLBuiltinName::Dot << "_su";
971+
break;
972+
case OpSDotAccSat:
973+
if (IsPacked)
974+
Name << kOCLBuiltinName::DotAccSat4x8PackedPrefix << "ss_int";
975+
else
976+
Name << kOCLBuiltinName::DotAccSat << "_ss";
977+
break;
978+
case OpUDotAccSat:
979+
if (IsPacked)
980+
Name << kOCLBuiltinName::DotAccSat4x8PackedPrefix << "uu_uint";
981+
else
982+
Name << kOCLBuiltinName::DotAccSat << "_uu";
983+
break;
984+
case OpSUDotAccSat:
985+
if (IsPacked)
986+
Name << kOCLBuiltinName::DotAccSat4x8PackedPrefix << "su_int";
987+
else
988+
Name << kOCLBuiltinName::DotAccSat << "_su";
989+
break;
990+
default:
991+
break; // do nothing
992+
}
993+
auto Mutator = mutateCallInst(CI, Name.str());
994+
if (IsPacked)
995+
Mutator.removeArg(CI->arg_size() - 1);
996+
}
997+
938998
void SPIRVToOCLBase::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
939999
mutateCallInst(CI, OCLSPIRVBuiltinMap::rmap(OC));
9401000
}

llvm-spirv/lib/SPIRV/SPIRVToOCL.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,10 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase>,
258258
// Transform FP atomic opcode to corresponding OpenCL function name
259259
virtual std::string mapFPAtomicName(Op OC) = 0;
260260

261+
/// Transform integer dot product builtins to corresponding OpenCL builtins
262+
/// examples: __spirv_SDotKHR => dot, __spirv_SDotAccSatKHR => dot_acc_sat
263+
void visitCallSPIRVDot(CallInst *CI, Op OC, StringRef DemangledName);
264+
261265
void translateOpaqueTypes();
262266

263267
private:

llvm-spirv/test/extensions/KHR/SPV_KHR_integer_dot_product/SPV_KHR_integer_dot_product-nonsat.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,17 @@ target triple = "spir-unknown-unknown"
4444

4545
; CHECK-LLVM: call spir_func i8 @_Z21__spirv_SDotKHR_Rchariii(
4646
; CHECK-LLVM: call spir_func i16 @_Z22__spirv_SDotKHR_Rshortiii(
47-
; CHECK-LLVM: call spir_func i32 @_Z20__spirv_SDotKHR_Rintiii(
47+
; CHECK-LLVM: call spir_func i32 @_Z20dot_4x8packed_ss_intjj(
4848
; CHECK-LLVM: call spir_func i64 @_Z21__spirv_SDotKHR_Rlongiii(
4949

5050
; CHECK-LLVM: call spir_func i8 @_Z22__spirv_UDotKHR_Ruchariii(
5151
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_UDotKHR_Rushortiii(
52-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotKHR_Ruintiii(
52+
; CHECK-LLVM: call spir_func i32 @_Z21dot_4x8packed_uu_uintjj(
5353
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_UDotKHR_Rulongiii(
5454

5555
; CHECK-LLVM: call spir_func i8 @_Z22__spirv_SUDotKHR_Rchariii(
5656
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_SUDotKHR_Rshortiii(
57-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_SUDotKHR_Rintiii(
57+
; CHECK-LLVM: call spir_func i32 @_Z20dot_4x8packed_su_intjj(
5858
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_SUDotKHR_Rlongiii(
5959

6060
; CHECK-SPV-IR: call spir_func i8 @_Z21__spirv_SDotKHR_Rchariii(
@@ -112,17 +112,17 @@ define spir_kernel void @TestNonSatPacked(i32 %0, i32 %1) #0 !kernel_arg_addr_sp
112112

113113
; CHECK-LLVM: call spir_func i8 @_Z21__spirv_SDotKHR_RcharDv4_cS_(
114114
; CHECK-LLVM: call spir_func i16 @_Z22__spirv_SDotKHR_RshortDv4_cS_(
115-
; CHECK-LLVM: call spir_func i32 @_Z20__spirv_SDotKHR_RintDv4_cS_(
115+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv4_cS_(
116116
; CHECK-LLVM: call spir_func i64 @_Z21__spirv_SDotKHR_RlongDv4_cS_(
117117

118118
; CHECK-LLVM: call spir_func i8 @_Z22__spirv_UDotKHR_RucharDv4_cS_(
119119
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_UDotKHR_RushortDv4_cS_(
120-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotKHR_RuintDv4_cS_(
120+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv4_hS_(
121121
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_UDotKHR_RulongDv4_cS_(
122122

123123
; CHECK-LLVM: call spir_func i8 @_Z22__spirv_SUDotKHR_RcharDv4_cS_(
124124
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_SUDotKHR_RshortDv4_cS_(
125-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_SUDotKHR_RintDv4_cS_(
125+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv4_cDv4_h(
126126
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_SUDotKHR_RlongDv4_cS_(
127127

128128
; CHECK-SPV-IR: call spir_func i8 @_Z21__spirv_SDotKHR_RcharDv4_cS_(
@@ -179,15 +179,15 @@ define spir_kernel void @TestNonSatVec(<4 x i8> %0, <4 x i8> %1) #0 !kernel_arg_
179179
; CHECK-SPIRV: Function
180180

181181
; CHECK-LLVM: call spir_func i16 @_Z22__spirv_SDotKHR_RshortDv2_sS_(
182-
; CHECK-LLVM: call spir_func i32 @_Z20__spirv_SDotKHR_RintDv2_sS_(
182+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv2_sS_(
183183
; CHECK-LLVM: call spir_func i64 @_Z21__spirv_SDotKHR_RlongDv2_sS_(
184184

185185
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_UDotKHR_RushortDv2_sS_(
186-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotKHR_RuintDv2_sS_(
186+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv2_tS_(
187187
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_UDotKHR_RulongDv2_sS_(
188188

189189
; CHECK-LLVM: call spir_func i16 @_Z23__spirv_SUDotKHR_RshortDv2_sS_(
190-
; CHECK-LLVM: call spir_func i32 @_Z21__spirv_SUDotKHR_RintDv2_sS_(
190+
; CHECK-LLVM: call spir_func i32 @_Z3dotDv2_sDv2_t(
191191
; CHECK-LLVM: call spir_func i64 @_Z22__spirv_SUDotKHR_RlongDv2_sS_(
192192

193193
; CHECK-SPV-IR: call spir_func i16 @_Z22__spirv_SDotKHR_RshortDv2_sS_(

llvm-spirv/test/extensions/KHR/SPV_KHR_integer_dot_product/SPV_KHR_integer_dot_product-sat.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,17 @@ target triple = "spir-unknown-unknown"
4444

4545
; CHECK-LLVM: call spir_func i8 @_Z27__spirv_SDotAccSatKHR_Rchariici(
4646
; CHECK-LLVM: call spir_func i16 @_Z28__spirv_SDotAccSatKHR_Rshortiisi(
47-
; CHECK-LLVM: call spir_func i32 @_Z26__spirv_SDotAccSatKHR_Rintiiii(
47+
; CHECK-LLVM: call spir_func i32 @_Z28dot_acc_sat_4x8packed_ss_intjji(
4848
; CHECK-LLVM: call spir_func i64 @_Z27__spirv_SDotAccSatKHR_Rlongiili(
4949

5050
; CHECK-LLVM: call spir_func i8 @_Z28__spirv_UDotAccSatKHR_Ruchariici(
5151
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_UDotAccSatKHR_Rushortiisi(
52-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_UDotAccSatKHR_Ruintiiii(
52+
; CHECK-LLVM: call spir_func i32 @_Z29dot_acc_sat_4x8packed_uu_uintjjj(
5353
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_UDotAccSatKHR_Rulongiili(
5454

5555
; CHECK-LLVM: call spir_func i8 @_Z28__spirv_SUDotAccSatKHR_Rchariici(
5656
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_SUDotAccSatKHR_Rshortiisi(
57-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_SUDotAccSatKHR_Rintiiii(
57+
; CHECK-LLVM: call spir_func i32 @_Z28dot_acc_sat_4x8packed_su_intjji(
5858
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_SUDotAccSatKHR_Rlongiili(
5959

6060
; CHECK-SPV-IR: call spir_func i8 @_Z27__spirv_SDotAccSatKHR_Rchariici(
@@ -112,17 +112,17 @@ define spir_kernel void @TestSatPacked(i32 %0, i32 %1, i8 %acc8, i16 %acc16, i32
112112

113113
; CHECK-LLVM: call spir_func i8 @_Z27__spirv_SDotAccSatKHR_RcharDv4_cS_c(
114114
; CHECK-LLVM: call spir_func i16 @_Z28__spirv_SDotAccSatKHR_RshortDv4_cS_s(
115-
; CHECK-LLVM: call spir_func i32 @_Z26__spirv_SDotAccSatKHR_RintDv4_cS_i(
115+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv4_cS_i(
116116
; CHECK-LLVM: call spir_func i64 @_Z27__spirv_SDotAccSatKHR_RlongDv4_cS_l(
117117

118118
; CHECK-LLVM: call spir_func i8 @_Z28__spirv_UDotAccSatKHR_RucharDv4_cS_c(
119119
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_UDotAccSatKHR_RushortDv4_cS_s(
120-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_UDotAccSatKHR_RuintDv4_cS_i(
120+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv4_hS_j
121121
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_UDotAccSatKHR_RulongDv4_cS_l(
122122

123123
; CHECK-LLVM: call spir_func i8 @_Z28__spirv_SUDotAccSatKHR_RcharDv4_cS_c(
124124
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_SUDotAccSatKHR_RshortDv4_cS_s(
125-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_SUDotAccSatKHR_RintDv4_cS_i(
125+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv4_cDv4_hi(
126126
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_SUDotAccSatKHR_RlongDv4_cS_l(
127127

128128
; CHECK-SPV-IR: call spir_func i8 @_Z27__spirv_SDotAccSatKHR_RcharDv4_cS_c(
@@ -179,15 +179,15 @@ define spir_kernel void @TestSatVec(<4 x i8> %0, <4 x i8> %1, i8 %acc8, i16 %acc
179179
; CHECK-SPIRV: Function
180180

181181
; CHECK-LLVM: call spir_func i16 @_Z28__spirv_SDotAccSatKHR_RshortDv2_sS_s(
182-
; CHECK-LLVM: call spir_func i32 @_Z26__spirv_SDotAccSatKHR_RintDv2_sS_i(
182+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv2_sS_i(
183183
; CHECK-LLVM: call spir_func i64 @_Z27__spirv_SDotAccSatKHR_RlongDv2_sS_l(
184184

185185
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_UDotAccSatKHR_RushortDv2_sS_s(
186-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_UDotAccSatKHR_RuintDv2_sS_i(
186+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv2_tS_j(
187187
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_UDotAccSatKHR_RulongDv2_sS_l(
188188

189189
; CHECK-LLVM: call spir_func i16 @_Z29__spirv_SUDotAccSatKHR_RshortDv2_sS_s(
190-
; CHECK-LLVM: call spir_func i32 @_Z27__spirv_SUDotAccSatKHR_RintDv2_sS_i(
190+
; CHECK-LLVM: call spir_func i32 @_Z11dot_acc_satDv2_sDv2_ti(
191191
; CHECK-LLVM: call spir_func i64 @_Z28__spirv_SUDotAccSatKHR_RlongDv2_sS_l(
192192

193193
; CHECK-SPV-IR: call spir_func i16 @_Z28__spirv_SDotAccSatKHR_RshortDv2_sS_s(

0 commit comments

Comments
 (0)